1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for the experimental input pipeline ops.""" 16 from __future__ import absolute_import 17 from __future__ import division 18 from __future__ import print_function 19 20 import threading 21 22 import numpy as np 23 24 from tensorflow.python.data.ops import dataset_ops 25 from tensorflow.python.framework import dtypes 26 from tensorflow.python.framework import errors 27 from tensorflow.python.platform import test 28 29 30 class DatasetConstructorTest(test.TestCase): 31 32 def _testFromGenerator(self, generator, elem_sequence, num_repeats): 33 iterator = ( 34 dataset_ops.Dataset.from_generator(generator, output_types=dtypes.int64) 35 .repeat(num_repeats) 36 .prefetch(5) 37 .make_initializable_iterator()) 38 init_op = iterator.initializer 39 get_next = iterator.get_next() 40 41 with self.test_session() as sess: 42 for _ in range(2): # Run twice to test reinitialization. 43 sess.run(init_op) 44 for _ in range(num_repeats): 45 for elem in elem_sequence: 46 self.assertAllEqual(elem, sess.run(get_next)) 47 with self.assertRaises(errors.OutOfRangeError): 48 sess.run(get_next) 49 50 def _testFromGeneratorOneShot(self, generator, elem_sequence, num_repeats): 51 iterator = ( 52 dataset_ops.Dataset.from_generator(generator, output_types=dtypes.int64) 53 .repeat(num_repeats) 54 .prefetch(5) 55 .make_one_shot_iterator()) 56 get_next = iterator.get_next() 57 58 with self.test_session() as sess: 59 for _ in range(num_repeats): 60 for elem in elem_sequence: 61 self.assertAllEqual(elem, sess.run(get_next)) 62 with self.assertRaises(errors.OutOfRangeError): 63 sess.run(get_next) 64 65 def testFromGeneratorUsingFunction(self): 66 def generator(): 67 for i in range(1, 100): 68 yield [i] * i 69 elem_sequence = list(generator()) 70 self._testFromGenerator(generator, elem_sequence, 1) 71 self._testFromGenerator(generator, elem_sequence, 5) 72 self._testFromGeneratorOneShot(generator, elem_sequence, 1) 73 self._testFromGeneratorOneShot(generator, elem_sequence, 5) 74 75 def testFromGeneratorUsingList(self): 76 generator = lambda: [[i] * i for i in range(1, 100)] 77 elem_sequence = list(generator()) 78 self._testFromGenerator(generator, elem_sequence, 1) 79 self._testFromGenerator(generator, elem_sequence, 5) 80 81 def testFromGeneratorUsingNdarray(self): 82 generator = lambda: np.arange(100, dtype=np.int64) 83 elem_sequence = list(generator()) 84 self._testFromGenerator(generator, elem_sequence, 1) 85 self._testFromGenerator(generator, elem_sequence, 5) 86 87 def testFromGeneratorUsingGeneratorExpression(self): 88 # NOTE(mrry): Generator *expressions* are not repeatable (or in 89 # general reusable), because they eagerly evaluate the `for` 90 # expression as `iter(range(1, 100))` and discard the means of 91 # reconstructing `range(1, 100)`. Wrapping the generator 92 # expression in a `lambda` makes it repeatable. 93 generator = lambda: ([i] * i for i in range(1, 100)) 94 elem_sequence = list(generator()) 95 self._testFromGenerator(generator, elem_sequence, 1) 96 self._testFromGenerator(generator, elem_sequence, 5) 97 98 def testFromMultipleConcurrentGenerators(self): 99 num_inner_repeats = 5 100 num_outer_repeats = 100 101 102 def generator(): 103 for i in range(1, 10): 104 yield ([i] * i, [i, i ** 2, i ** 3]) 105 input_list = list(generator()) 106 107 # The interleave transformation is essentially a flat map that 108 # draws from multiple input datasets concurrently (in a cyclic 109 # fashion). By placing `Datsaet.from_generator()` inside an 110 # interleave, we test its behavior when multiple iterators are 111 # active at the same time; by additionally prefetching inside the 112 # interleave, we create the possibility of parallel (modulo GIL) 113 # invocations to several iterators created by the same dataset. 114 def interleave_fn(_): 115 return (dataset_ops.Dataset.from_generator( 116 generator, output_types=(dtypes.int64, dtypes.int64), 117 output_shapes=([None], [3])) 118 .repeat(num_inner_repeats).prefetch(5)) 119 120 iterator = ( 121 dataset_ops.Dataset.range(num_outer_repeats) 122 .interleave(interleave_fn, cycle_length=10, 123 block_length=len(input_list)) 124 .make_initializable_iterator()) 125 init_op = iterator.initializer 126 get_next = iterator.get_next() 127 128 with self.test_session() as sess: 129 sess.run(init_op) 130 for _ in range(num_inner_repeats * num_outer_repeats): 131 for elem in input_list: 132 val0, val1 = sess.run(get_next) 133 self.assertAllEqual(elem[0], val0) 134 self.assertAllEqual(elem[1], val1) 135 with self.assertRaises(errors.OutOfRangeError): 136 sess.run(get_next) 137 138 # TODO(b/67868766): Reenable this when the source of flakiness is discovered. 139 def _testFromGeneratorsRunningInParallel(self): 140 num_parallel_iterators = 3 141 142 # Define shared state that multiple iterator instances will access to 143 # demonstrate their concurrent activity. 144 lock = threading.Lock() 145 condition = threading.Condition(lock) 146 next_ticket = [0] # GUARDED_BY(lock) 147 148 def generator(): 149 # NOTE(mrry): We yield one element before the barrier, because 150 # the current implementation of `Dataset.interleave()` must 151 # fetch one element from each incoming dataset to start the 152 # prefetching. 153 yield 0 154 155 # Define a barrier that `num_parallel_iterators` iterators must enter 156 # before any can proceed. Demonstrates that multiple iterators may be 157 # active at the same time. 158 condition.acquire() 159 ticket = next_ticket[0] 160 next_ticket[0] += 1 161 if ticket == num_parallel_iterators - 1: 162 # The last iterator to join the barrier notifies the others. 163 condition.notify_all() 164 else: 165 # Wait until the last iterator enters the barrier. 166 while next_ticket[0] < num_parallel_iterators: 167 condition.wait() 168 condition.release() 169 170 yield 1 171 172 # As in `testFromMultipleConcurrentGenerators()`, we use a combination of 173 # `Dataset.interleave()` and `Dataset.prefetch()` to cause multiple 174 # iterators to be active concurrently. 175 def interleave_fn(_): 176 return dataset_ops.Dataset.from_generator( 177 generator, output_types=dtypes.int64, output_shapes=[]).prefetch(2) 178 179 iterator = ( 180 dataset_ops.Dataset.range(num_parallel_iterators) 181 .interleave( 182 interleave_fn, cycle_length=num_parallel_iterators, block_length=1) 183 .make_initializable_iterator()) 184 init_op = iterator.initializer 185 get_next = iterator.get_next() 186 187 with self.test_session() as sess: 188 sess.run(init_op) 189 for elem in [0, 1]: 190 for _ in range(num_parallel_iterators): 191 self.assertAllEqual(elem, sess.run(get_next)) 192 with self.assertRaises(errors.OutOfRangeError): 193 sess.run(get_next) 194 195 def testFromGeneratorImplicitConversion(self): 196 def generator(): 197 yield [1] 198 yield [2] 199 yield [3] 200 201 for dtype in [dtypes.int8, dtypes.int32, dtypes.int64]: 202 iterator = (dataset_ops.Dataset.from_generator( 203 generator, output_types=dtype, output_shapes=[1]) 204 .make_initializable_iterator()) 205 init_op = iterator.initializer 206 get_next = iterator.get_next() 207 208 self.assertEqual(dtype, get_next.dtype) 209 210 with self.test_session() as sess: 211 sess.run(init_op) 212 for expected in [[1], [2], [3]]: 213 next_val = sess.run(get_next) 214 self.assertEqual(dtype.as_numpy_dtype, next_val.dtype) 215 self.assertAllEqual(expected, next_val) 216 with self.assertRaises(errors.OutOfRangeError): 217 sess.run(get_next) 218 219 def testFromGeneratorString(self): 220 def generator(): 221 yield "foo" 222 yield b"bar" 223 yield u"baz" 224 225 iterator = (dataset_ops.Dataset.from_generator( 226 generator, output_types=dtypes.string, output_shapes=[]) 227 .make_initializable_iterator()) 228 init_op = iterator.initializer 229 get_next = iterator.get_next() 230 231 with self.test_session() as sess: 232 sess.run(init_op) 233 for expected in [b"foo", b"bar", b"baz"]: 234 next_val = sess.run(get_next) 235 self.assertAllEqual(expected, next_val) 236 with self.assertRaises(errors.OutOfRangeError): 237 sess.run(get_next) 238 239 def testFromGeneratorTypeError(self): 240 def generator(): 241 yield np.array([1, 2, 3], dtype=np.int64) 242 yield np.array([4, 5, 6], dtype=np.int64) 243 yield "ERROR" 244 yield np.array([7, 8, 9], dtype=np.int64) 245 246 iterator = (dataset_ops.Dataset.from_generator( 247 generator, output_types=dtypes.int64, output_shapes=[3]) 248 .make_initializable_iterator()) 249 init_op = iterator.initializer 250 get_next = iterator.get_next() 251 252 with self.test_session() as sess: 253 sess.run(init_op) 254 self.assertAllEqual([1, 2, 3], sess.run(get_next)) 255 self.assertAllEqual([4, 5, 6], sess.run(get_next)) 256 # NOTE(mrry): Type name in message differs between Python 2 (`long`) and 257 # 3 (`int`). 258 with self.assertRaisesOpError(r"invalid literal for"): 259 sess.run(get_next) 260 self.assertAllEqual([7, 8, 9], sess.run(get_next)) 261 with self.assertRaises(errors.OutOfRangeError): 262 sess.run(get_next) 263 264 def testFromGeneratorShapeError(self): 265 def generator(): 266 yield np.array([1, 2, 3], dtype=np.int64) 267 yield np.array([4, 5, 6], dtype=np.int64) 268 yield np.array([7, 8, 9, 10], dtype=np.int64) 269 yield np.array([11, 12, 13], dtype=np.int64) 270 271 iterator = (dataset_ops.Dataset.from_generator( 272 generator, output_types=dtypes.int64, output_shapes=[3]) 273 .make_initializable_iterator()) 274 init_op = iterator.initializer 275 get_next = iterator.get_next() 276 277 with self.test_session() as sess: 278 sess.run(init_op) 279 self.assertAllEqual([1, 2, 3], sess.run(get_next)) 280 self.assertAllEqual([4, 5, 6], sess.run(get_next)) 281 with self.assertRaisesOpError(r"element of shape \(3,\) was expected"): 282 sess.run(get_next) 283 self.assertAllEqual([11, 12, 13], sess.run(get_next)) 284 with self.assertRaises(errors.OutOfRangeError): 285 sess.run(get_next) 286 287 def testFromGeneratorHeterogeneous(self): 288 def generator(): 289 yield 1 290 yield [2, 3] 291 292 iterator = ( 293 dataset_ops.Dataset.from_generator( 294 generator, output_types=dtypes.int64).make_initializable_iterator()) 295 init_op = iterator.initializer 296 get_next = iterator.get_next() 297 298 with self.test_session() as sess: 299 sess.run(init_op) 300 self.assertAllEqual(1, sess.run(get_next)) 301 self.assertAllEqual([2, 3], sess.run(get_next)) 302 with self.assertRaises(errors.OutOfRangeError): 303 sess.run(get_next) 304 305 306 if __name__ == "__main__": 307 test.main() 308