# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the experimental input pipeline ops that need test_util."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test


class IteratorClusterTest(test.TestCase):
  """Tests `tf.data` iterators that are used across devices of a local cluster."""

  def testRemoteIteratorWithoutRemoteCallFail(self):
    """Consuming another device's iterator handle without `remote_call` fails.

    The iterator is created on cpu:1 of the single worker, while the op that
    rebuilds it from its string handle is placed on cpu:0. Running that op is
    expected to raise `InvalidArgumentError`.
    """
    worker_config = config_pb2.ConfigProto()
    # Expose two CPU devices so the iterator and its consumer can be placed
    # on different devices of the same task.
    worker_config.device_count["CPU"] = 2
    worker, _ = test_util.create_local_cluster(
        1, 1, worker_config=worker_config)

    with ops.device("/job:worker/replica:0/task:0/cpu:1"):
      dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
      iterator_3 = dataset_3.make_one_shot_iterator()
      iterator_3_handle = iterator_3.string_handle()

    with ops.device("/job:worker/replica:0/task:0/cpu:0"):
      remote_it = iterator_ops.Iterator.from_string_handle(
          iterator_3_handle, dataset_3.output_types, dataset_3.output_shapes)
      get_next_op = remote_it.get_next()

    with session.Session(worker[0].target) as sess:
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(get_next_op)

  def _testRemoteIteratorHelper(self, device0, device1, target):
    """Checks that an iterator on `device1` can be driven via `remote_call`.

    Creates a one-shot iterator over [1, 2, 3] on `device1`. It then places a
    `functional_ops.remote_call` op on `device0`. That op's function rebuilds
    the iterator from its string handle and calls `get_next()`. The call
    target is fed at run time, so the same op can be aimed at either device.

    Args:
      device0: Device on which the `remote_call` op is placed. Dispatching the
        call to this device is expected to fail, because the iterator
        resource does not live there.
      device1: Device that owns the iterator resource.
      target: Session target string, e.g. `worker[0].target` as returned by
        `test_util.create_local_cluster`.
    """
    with ops.device(device1):
      dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
      iterator_3 = dataset_3.make_one_shot_iterator()
      iterator_3_handle = iterator_3.string_handle()

    @function.Defun(dtypes.string)
    def _remote_fn(h):
      remote_iterator = iterator_ops.Iterator.from_string_handle(
          h, dataset_3.output_types, dataset_3.output_shapes)
      return remote_iterator.get_next()

    with ops.device(device0):
      target_placeholder = array_ops.placeholder(dtypes.string, shape=[])
      remote_op = functional_ops.remote_call(
          args=[iterator_3_handle],
          Tout=[dtypes.int32],
          f=_remote_fn,
          target=target_placeholder)

    with session.Session(target) as sess:
      elem = sess.run(remote_op, feed_dict={target_placeholder: device1})
      self.assertEqual(elem, [1])
      # Fails when the call is dispatched to `device0`, where the iterator
      # resource is not located.
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(remote_op, feed_dict={target_placeholder: device0})
      # Local and remote `get_next` calls advance the same iterator, so the
      # elements keep arriving in order across both access paths.
      elem = sess.run(iterator_3.get_next())
      self.assertEqual(elem, [2])
      elem = sess.run(remote_op, feed_dict={target_placeholder: device1})
      self.assertEqual(elem, [3])
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(remote_op, feed_dict={target_placeholder: device1})

  def testRemoteIteratorUsingRemoteCallOp(self):
    """`remote_call` reaches an iterator on another CPU of the same task."""
    worker_config = config_pb2.ConfigProto()
    # Two CPU devices in one task: the iterator and the caller are placed on
    # different devices within the same process.
    worker_config.device_count["CPU"] = 2
    worker, _ = test_util.create_local_cluster(
        1, 1, worker_config=worker_config)

    self._testRemoteIteratorHelper("/job:worker/replica:0/task:0/cpu:0",
                                   "/job:worker/replica:0/task:0/cpu:1",
                                   worker[0].target)

  def testRemoteIteratorUsingRemoteCallOpCrossProcess(self):
    """`remote_call` reaches an iterator that lives in another worker task."""
    workers, _ = test_util.create_local_cluster(2, 1)

    self._testRemoteIteratorHelper("/job:worker/replica:0/task:0/cpu:0",
                                   "/job:worker/replica:0/task:1/cpu:0",
                                   workers[0].target)

  def testCaptureHashTableInSharedIterator(self):
    """A shared iterator capturing a shared hash table persists across sessions.

    The first session initializes the table and the iterator and reads one
    element. A second session on the same worker continues from the second
    element and then hits end of input.
    """
    worker, _ = test_util.create_local_cluster(1, 1)

    # NOTE(mrry): We must use the V2 variants of `HashTable`
    # etc. because these produce a `tf.resource`-typed output that is
    # compatible with the in-graph function implementation.
    default_val = -1
    keys = constant_op.constant(["brain", "salad", "surgery"])
    values = constant_op.constant([0, 1, 2], dtypes.int64)
    table = lookup_ops.HashTable(
        lookup_ops.KeyValueTensorInitializer(keys, values),
        default_val,
        shared_name="shared_table")

    input_sentences = dataset_ops.Dataset.from_tensor_slices(
        ["brain brain tank salad surgery", "surgery brain"])

    # Split each sentence into words and map every word to its table id
    # (unknown words such as "tank" map to `default_val`).
    iterator = (
        input_sentences.map(lambda x: string_ops.string_split([x]).values).map(
            table.lookup)
        .make_initializable_iterator(shared_name="shared_iterator"))
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with session.Session(worker[0].target) as sess:
      sess.run(table.init)
      sess.run(init_op)
      self.assertAllEqual([0, 0, -1, 1, 2], sess.run(get_next))

    # A fresh session: the shared iterator resumes where the previous one
    # stopped instead of being re-initialized.
    with session.Session(worker[0].target) as sess:
      self.assertAllEqual([2, 0], sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testImplicitDisposeParallelMapDataset(self):
    # Tests whether a map dataset will be cleaned up correctly when the
    # pipeline does not run it until exhaustion.
    # The pipeline is TensorSliceDataset -> MapDataset(_map_fn, squares each
    # of the 3 components) -> RepeatDataset(None) -> PrefetchDataset(10000).
    # NOTE(review): despite the test name, `map` is called without
    # `num_parallel_calls`, so this presumably builds a sequential MapDataset
    # rather than a ParallelMapDataset -- confirm whether that is intended.
    worker, _ = test_util.create_local_cluster(1, 1)

    components = (np.arange(1000),
                  np.array([[1, 2, 3]]) * np.arange(1000)[:, np.newaxis],
                  np.array(37.0) * np.arange(1000))

    def _map_fn(x, y, z):
      return math_ops.square(x), math_ops.square(y), math_ops.square(z)

    dataset = (
        dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn)
        .repeat(None).prefetch(10000))

    iterator = dataset.make_initializable_iterator()
    init_op = iterator.initializer
    get_next = iterator.get_next()

    # Only a few elements of the infinite (`repeat(None)`) pipeline are
    # consumed. There is no explicit assertion: the test relies on the
    # session closing cleanly while the pipeline is still active.
    with session.Session(worker[0].target) as sess:
      sess.run(init_op)
      for _ in range(3):
        sess.run(get_next)


if __name__ == "__main__":
  test.main()