Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for the experimental input pipeline ops."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import threading
     21 
     22 import numpy as np
     23 
     24 from tensorflow.python.data.ops import dataset_ops
     25 from tensorflow.python.framework import dtypes
     26 from tensorflow.python.framework import errors
     27 from tensorflow.python.platform import test
     28 
     29 
     30 class DatasetConstructorTest(test.TestCase):
     31 
     32   def _testFromGenerator(self, generator, elem_sequence, num_repeats):
     33     iterator = (
     34         dataset_ops.Dataset.from_generator(generator, output_types=dtypes.int64)
     35         .repeat(num_repeats)
     36         .prefetch(5)
     37         .make_initializable_iterator())
     38     init_op = iterator.initializer
     39     get_next = iterator.get_next()
     40 
     41     with self.test_session() as sess:
     42       for _ in range(2):  # Run twice to test reinitialization.
     43         sess.run(init_op)
     44         for _ in range(num_repeats):
     45           for elem in elem_sequence:
     46             self.assertAllEqual(elem, sess.run(get_next))
     47         with self.assertRaises(errors.OutOfRangeError):
     48           sess.run(get_next)
     49 
     50   def _testFromGeneratorOneShot(self, generator, elem_sequence, num_repeats):
     51     iterator = (
     52         dataset_ops.Dataset.from_generator(generator, output_types=dtypes.int64)
     53         .repeat(num_repeats)
     54         .prefetch(5)
     55         .make_one_shot_iterator())
     56     get_next = iterator.get_next()
     57 
     58     with self.test_session() as sess:
     59       for _ in range(num_repeats):
     60         for elem in elem_sequence:
     61           self.assertAllEqual(elem, sess.run(get_next))
     62       with self.assertRaises(errors.OutOfRangeError):
     63         sess.run(get_next)
     64 
     65   def testFromGeneratorUsingFunction(self):
     66     def generator():
     67       for i in range(1, 100):
     68         yield [i] * i
     69     elem_sequence = list(generator())
     70     self._testFromGenerator(generator, elem_sequence, 1)
     71     self._testFromGenerator(generator, elem_sequence, 5)
     72     self._testFromGeneratorOneShot(generator, elem_sequence, 1)
     73     self._testFromGeneratorOneShot(generator, elem_sequence, 5)
     74 
     75   def testFromGeneratorUsingList(self):
     76     generator = lambda: [[i] * i for i in range(1, 100)]
     77     elem_sequence = list(generator())
     78     self._testFromGenerator(generator, elem_sequence, 1)
     79     self._testFromGenerator(generator, elem_sequence, 5)
     80 
     81   def testFromGeneratorUsingNdarray(self):
     82     generator = lambda: np.arange(100, dtype=np.int64)
     83     elem_sequence = list(generator())
     84     self._testFromGenerator(generator, elem_sequence, 1)
     85     self._testFromGenerator(generator, elem_sequence, 5)
     86 
     87   def testFromGeneratorUsingGeneratorExpression(self):
     88     # NOTE(mrry): Generator *expressions* are not repeatable (or in
     89     # general reusable), because they eagerly evaluate the `for`
     90     # expression as `iter(range(1, 100))` and discard the means of
     91     # reconstructing `range(1, 100)`. Wrapping the generator
     92     # expression in a `lambda` makes it repeatable.
     93     generator = lambda: ([i] * i for i in range(1, 100))
     94     elem_sequence = list(generator())
     95     self._testFromGenerator(generator, elem_sequence, 1)
     96     self._testFromGenerator(generator, elem_sequence, 5)
     97 
     98   def testFromMultipleConcurrentGenerators(self):
     99     num_inner_repeats = 5
    100     num_outer_repeats = 100
    101 
    102     def generator():
    103       for i in range(1, 10):
    104         yield ([i] * i, [i, i ** 2, i ** 3])
    105     input_list = list(generator())
    106 
    107     # The interleave transformation is essentially a flat map that
    108     # draws from multiple input datasets concurrently (in a cyclic
    109     # fashion). By placing `Datsaet.from_generator()` inside an
    110     # interleave, we test its behavior when multiple iterators are
    111     # active at the same time; by additionally prefetching inside the
    112     # interleave, we create the possibility of parallel (modulo GIL)
    113     # invocations to several iterators created by the same dataset.
    114     def interleave_fn(_):
    115       return (dataset_ops.Dataset.from_generator(
    116           generator, output_types=(dtypes.int64, dtypes.int64),
    117           output_shapes=([None], [3]))
    118               .repeat(num_inner_repeats).prefetch(5))
    119 
    120     iterator = (
    121         dataset_ops.Dataset.range(num_outer_repeats)
    122         .interleave(interleave_fn, cycle_length=10,
    123                     block_length=len(input_list))
    124         .make_initializable_iterator())
    125     init_op = iterator.initializer
    126     get_next = iterator.get_next()
    127 
    128     with self.test_session() as sess:
    129       sess.run(init_op)
    130       for _ in range(num_inner_repeats * num_outer_repeats):
    131         for elem in input_list:
    132           val0, val1 = sess.run(get_next)
    133           self.assertAllEqual(elem[0], val0)
    134           self.assertAllEqual(elem[1], val1)
    135       with self.assertRaises(errors.OutOfRangeError):
    136         sess.run(get_next)
    137 
    138   # TODO(b/67868766): Reenable this when the source of flakiness is discovered.
    139   def _testFromGeneratorsRunningInParallel(self):
    140     num_parallel_iterators = 3
    141 
    142     # Define shared state that multiple iterator instances will access to
    143     # demonstrate their concurrent activity.
    144     lock = threading.Lock()
    145     condition = threading.Condition(lock)
    146     next_ticket = [0]  # GUARDED_BY(lock)
    147 
    148     def generator():
    149       # NOTE(mrry): We yield one element before the barrier, because
    150       # the current implementation of `Dataset.interleave()` must
    151       # fetch one element from each incoming dataset to start the
    152       # prefetching.
    153       yield 0
    154 
    155       # Define a barrier that `num_parallel_iterators` iterators must enter
    156       # before any can proceed. Demonstrates that multiple iterators may be
    157       # active at the same time.
    158       condition.acquire()
    159       ticket = next_ticket[0]
    160       next_ticket[0] += 1
    161       if ticket == num_parallel_iterators - 1:
    162         # The last iterator to join the barrier notifies the others.
    163         condition.notify_all()
    164       else:
    165         # Wait until the last iterator enters the barrier.
    166         while next_ticket[0] < num_parallel_iterators:
    167           condition.wait()
    168       condition.release()
    169 
    170       yield 1
    171 
    172     # As in `testFromMultipleConcurrentGenerators()`, we use a combination of
    173     # `Dataset.interleave()` and `Dataset.prefetch()` to cause multiple
    174     # iterators to be active concurrently.
    175     def interleave_fn(_):
    176       return dataset_ops.Dataset.from_generator(
    177           generator, output_types=dtypes.int64, output_shapes=[]).prefetch(2)
    178 
    179     iterator = (
    180         dataset_ops.Dataset.range(num_parallel_iterators)
    181         .interleave(
    182             interleave_fn, cycle_length=num_parallel_iterators, block_length=1)
    183         .make_initializable_iterator())
    184     init_op = iterator.initializer
    185     get_next = iterator.get_next()
    186 
    187     with self.test_session() as sess:
    188       sess.run(init_op)
    189       for elem in [0, 1]:
    190         for _ in range(num_parallel_iterators):
    191           self.assertAllEqual(elem, sess.run(get_next))
    192       with self.assertRaises(errors.OutOfRangeError):
    193         sess.run(get_next)
    194 
    195   def testFromGeneratorImplicitConversion(self):
    196     def generator():
    197       yield [1]
    198       yield [2]
    199       yield [3]
    200 
    201     for dtype in [dtypes.int8, dtypes.int32, dtypes.int64]:
    202       iterator = (dataset_ops.Dataset.from_generator(
    203           generator, output_types=dtype, output_shapes=[1])
    204                   .make_initializable_iterator())
    205       init_op = iterator.initializer
    206       get_next = iterator.get_next()
    207 
    208       self.assertEqual(dtype, get_next.dtype)
    209 
    210       with self.test_session() as sess:
    211         sess.run(init_op)
    212         for expected in [[1], [2], [3]]:
    213           next_val = sess.run(get_next)
    214           self.assertEqual(dtype.as_numpy_dtype, next_val.dtype)
    215           self.assertAllEqual(expected, next_val)
    216         with self.assertRaises(errors.OutOfRangeError):
    217           sess.run(get_next)
    218 
    219   def testFromGeneratorString(self):
    220     def generator():
    221       yield "foo"
    222       yield b"bar"
    223       yield u"baz"
    224 
    225     iterator = (dataset_ops.Dataset.from_generator(
    226         generator, output_types=dtypes.string, output_shapes=[])
    227                 .make_initializable_iterator())
    228     init_op = iterator.initializer
    229     get_next = iterator.get_next()
    230 
    231     with self.test_session() as sess:
    232       sess.run(init_op)
    233       for expected in [b"foo", b"bar", b"baz"]:
    234         next_val = sess.run(get_next)
    235         self.assertAllEqual(expected, next_val)
    236       with self.assertRaises(errors.OutOfRangeError):
    237         sess.run(get_next)
    238 
    239   def testFromGeneratorTypeError(self):
    240     def generator():
    241       yield np.array([1, 2, 3], dtype=np.int64)
    242       yield np.array([4, 5, 6], dtype=np.int64)
    243       yield "ERROR"
    244       yield np.array([7, 8, 9], dtype=np.int64)
    245 
    246     iterator = (dataset_ops.Dataset.from_generator(
    247         generator, output_types=dtypes.int64, output_shapes=[3])
    248                 .make_initializable_iterator())
    249     init_op = iterator.initializer
    250     get_next = iterator.get_next()
    251 
    252     with self.test_session() as sess:
    253       sess.run(init_op)
    254       self.assertAllEqual([1, 2, 3], sess.run(get_next))
    255       self.assertAllEqual([4, 5, 6], sess.run(get_next))
    256       # NOTE(mrry): Type name in message differs between Python 2 (`long`) and
    257       # 3 (`int`).
    258       with self.assertRaisesOpError(r"invalid literal for"):
    259         sess.run(get_next)
    260       self.assertAllEqual([7, 8, 9], sess.run(get_next))
    261       with self.assertRaises(errors.OutOfRangeError):
    262         sess.run(get_next)
    263 
    264   def testFromGeneratorShapeError(self):
    265     def generator():
    266       yield np.array([1, 2, 3], dtype=np.int64)
    267       yield np.array([4, 5, 6], dtype=np.int64)
    268       yield np.array([7, 8, 9, 10], dtype=np.int64)
    269       yield np.array([11, 12, 13], dtype=np.int64)
    270 
    271     iterator = (dataset_ops.Dataset.from_generator(
    272         generator, output_types=dtypes.int64, output_shapes=[3])
    273                 .make_initializable_iterator())
    274     init_op = iterator.initializer
    275     get_next = iterator.get_next()
    276 
    277     with self.test_session() as sess:
    278       sess.run(init_op)
    279       self.assertAllEqual([1, 2, 3], sess.run(get_next))
    280       self.assertAllEqual([4, 5, 6], sess.run(get_next))
    281       with self.assertRaisesOpError(r"element of shape \(3,\) was expected"):
    282         sess.run(get_next)
    283       self.assertAllEqual([11, 12, 13], sess.run(get_next))
    284       with self.assertRaises(errors.OutOfRangeError):
    285         sess.run(get_next)
    286 
    287   def testFromGeneratorHeterogeneous(self):
    288     def generator():
    289       yield 1
    290       yield [2, 3]
    291 
    292     iterator = (
    293         dataset_ops.Dataset.from_generator(
    294             generator, output_types=dtypes.int64).make_initializable_iterator())
    295     init_op = iterator.initializer
    296     get_next = iterator.get_next()
    297 
    298     with self.test_session() as sess:
    299       sess.run(init_op)
    300       self.assertAllEqual(1, sess.run(get_next))
    301       self.assertAllEqual([2, 3], sess.run(get_next))
    302       with self.assertRaises(errors.OutOfRangeError):
    303         sess.run(get_next)
    304 
    305 
# Run all test cases in this module under the TensorFlow test runner.
if __name__ == "__main__":
  test.main()
    308