Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for the experimental input pipeline ops."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import itertools
     21 
     22 from tensorflow.python.data.ops import dataset_ops
     23 from tensorflow.python.framework import dtypes
     24 from tensorflow.python.framework import errors
     25 from tensorflow.python.framework import sparse_tensor
     26 from tensorflow.python.ops import array_ops
     27 from tensorflow.python.ops import sparse_ops
     28 from tensorflow.python.platform import test
     29 
     30 
     31 class InterleaveDatasetTest(test.TestCase):
     32 
     33   def _interleave(self, lists, cycle_length, block_length):
     34     num_open = 0
     35 
     36     # `all_iterators` acts as a queue of iterators over each element of `lists`.
     37     all_iterators = [iter(l) for l in lists]
     38 
     39     # `open_iterators` are the iterators whose elements are currently being
     40     # interleaved.
     41     open_iterators = []
     42     for i in range(cycle_length):
     43       if all_iterators:
     44         open_iterators.append(all_iterators.pop(0))
     45         num_open += 1
     46       else:
     47         open_iterators.append(None)
     48 
     49     while num_open or all_iterators:
     50       for i in range(cycle_length):
     51         if open_iterators[i] is None:
     52           if all_iterators:
     53             open_iterators[i] = all_iterators.pop(0)
     54             num_open += 1
     55           else:
     56             continue
     57         for _ in range(block_length):
     58           try:
     59             yield next(open_iterators[i])
     60           except StopIteration:
     61             open_iterators[i] = None
     62             num_open -= 1
     63             break
     64 
     65   def testPythonImplementation(self):
     66     input_lists = [[4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6],
     67                    [4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]]
     68 
     69     # Cycle length 1 acts like `Dataset.flat_map()`.
     70     expected_elements = itertools.chain(*input_lists)
     71     for expected, produced in zip(
     72         expected_elements, self._interleave(input_lists, 1, 1)):
     73       self.assertEqual(expected, produced)
     74 
     75     # Cycle length > 1.
     76     expected_elements = [4, 5, 4, 5, 4, 5, 4,
     77                          5, 5, 6, 6,  # NOTE(mrry): When we cycle back
     78                                       # to a list and are already at
     79                                       # the end of that list, we move
     80                                       # on to the next element.
     81                          4, 6, 4, 6, 4, 6, 4, 6, 5, 6, 5, 6, 5, 6, 5, 6, 5]
     82     for expected, produced in zip(
     83         expected_elements, self._interleave(input_lists, 2, 1)):
     84       self.assertEqual(expected, produced)
     85 
     86     # Cycle length > 1 and block length > 1.
     87     expected_elements = [4, 4, 4, 5, 5, 5, 4, 5, 5, 6, 6, 6, 4, 4, 4, 6, 6, 6,
     88                          4, 5, 5, 5, 6, 6, 6, 5, 5, 6, 6, 6]
     89     for expected, produced in zip(
     90         expected_elements, self._interleave(input_lists, 2, 3)):
     91       self.assertEqual(expected, produced)
     92 
     93     # Cycle length > len(input_values).
     94     expected_elements = [4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6,
     95                          4, 4, 5, 5, 6, 6, 5, 6, 6, 5, 6, 6]
     96     for expected, produced in zip(
     97         expected_elements, self._interleave(input_lists, 7, 2)):
     98       self.assertEqual(expected, produced)
     99 
    100   def testInterleaveDataset(self):
    101     input_values = array_ops.placeholder(dtypes.int64, shape=[None])
    102     cycle_length = array_ops.placeholder(dtypes.int64, shape=[])
    103     block_length = array_ops.placeholder(dtypes.int64, shape=[])
    104 
    105     repeat_count = 2
    106 
    107     dataset = (
    108         dataset_ops.Dataset.from_tensor_slices(input_values)
    109         .repeat(repeat_count)
    110         .interleave(lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x),
    111                     cycle_length, block_length))
    112     iterator = dataset.make_initializable_iterator()
    113     init_op = iterator.initializer
    114     next_element = iterator.get_next()
    115 
    116     with self.test_session() as sess:
    117       # Cycle length 1 acts like `Dataset.flat_map()`.
    118       sess.run(init_op, feed_dict={input_values: [4, 5, 6],
    119                                    cycle_length: 1, block_length: 3})
    120 
    121       for expected_element in self._interleave(
    122           [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 1, 3):
    123         self.assertEqual(expected_element, sess.run(next_element))
    124 
    125       # Cycle length > 1.
    126       # expected: [4, 5, 4, 5, 4, 5, 4, 5, 5, 6, 6, 4, 6, 4, 6, 4, 6, 4, 6, 5,
    127       #            6, 5, 6, 5, 6, 5, 6, 5]
    128       sess.run(init_op, feed_dict={input_values: [4, 5, 6],
    129                                    cycle_length: 2, block_length: 1})
    130       for expected_element in self._interleave(
    131           [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 2, 1):
    132         self.assertEqual(expected_element, sess.run(next_element))
    133       with self.assertRaises(errors.OutOfRangeError):
    134         sess.run(next_element)
    135 
    136       # Cycle length > 1 and block length > 1.
    137       # expected: [4, 4, 4, 5, 5, 5, 4, 5, 5, 6, 6, 6, 4, 4, 4, 6, 6, 6, 4, 5,
    138       #            5, 5, 6, 6, 6, 5, 5, 6, 6, 6]
    139       sess.run(init_op, feed_dict={input_values: [4, 5, 6],
    140                                    cycle_length: 2, block_length: 3})
    141       for expected_element in self._interleave(
    142           [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 2, 3):
    143         self.assertEqual(expected_element, sess.run(next_element))
    144       with self.assertRaises(errors.OutOfRangeError):
    145         sess.run(next_element)
    146 
    147       # Cycle length > len(input_values) * repeat_count.
    148       # expected: [4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4,
    149       #            5, 5, 6, 6, 5, 6, 6, 5, 6, 6]
    150       sess.run(init_op, feed_dict={input_values: [4, 5, 6],
    151                                    cycle_length: 7, block_length: 2})
    152       for expected_element in self._interleave(
    153           [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 7, 2):
    154         self.assertEqual(expected_element, sess.run(next_element))
    155       with self.assertRaises(errors.OutOfRangeError):
    156         sess.run(next_element)
    157 
    158       # Empty input.
    159       sess.run(init_op, feed_dict={input_values: [],
    160                                    cycle_length: 2, block_length: 3})
    161       with self.assertRaises(errors.OutOfRangeError):
    162         sess.run(next_element)
    163 
    164       # Non-empty input leading to empty output.
    165       sess.run(init_op, feed_dict={input_values: [0, 0, 0],
    166                                    cycle_length: 2, block_length: 3})
    167       with self.assertRaises(errors.OutOfRangeError):
    168         sess.run(next_element)
    169 
    170       # Mixture of non-empty and empty interleaved datasets.
    171       sess.run(init_op, feed_dict={input_values: [4, 0, 6],
    172                                    cycle_length: 2, block_length: 3})
    173       for expected_element in self._interleave(
    174           [[4] * 4, [], [6] * 6] * repeat_count, 2, 3):
    175         self.assertEqual(expected_element, sess.run(next_element))
    176       with self.assertRaises(errors.OutOfRangeError):
    177         sess.run(next_element)
    178 
    179   def testSparse(self):
    180 
    181     def _map_fn(i):
    182       return sparse_tensor.SparseTensorValue(
    183           indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2])
    184 
    185     def _interleave_fn(x):
    186       return dataset_ops.Dataset.from_tensor_slices(
    187           sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values))
    188 
    189     iterator = (
    190         dataset_ops.Dataset.range(10).map(_map_fn).interleave(
    191             _interleave_fn, cycle_length=1).make_initializable_iterator())
    192     init_op = iterator.initializer
    193     get_next = iterator.get_next()
    194 
    195     with self.test_session() as sess:
    196       sess.run(init_op)
    197       for i in range(10):
    198         for j in range(2):
    199           expected = [i, 0] if j % 2 == 0 else [0, -i]
    200           self.assertAllEqual(expected, sess.run(get_next))
    201       with self.assertRaises(errors.OutOfRangeError):
    202         sess.run(get_next)
    203 
    204   def testEmptyInput(self):
    205     iterator = (
    206         dataset_ops.Dataset.from_tensor_slices([])
    207         .repeat(None)
    208         .interleave(dataset_ops.Dataset.from_tensors, cycle_length=2)
    209         .make_initializable_iterator())
    210     init_op = iterator.initializer
    211     get_next = iterator.get_next()
    212 
    213     with self.test_session() as sess:
    214       sess.run(init_op)
    215       with self.assertRaises(errors.OutOfRangeError):
    216         sess.run(get_next)
    217 
    218 
# When this file is executed directly (rather than imported), hand control to
# the `main()` entry point of `tensorflow.python.platform.test`.
if __name__ == "__main__":
  test.main()
    221