Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for the experimental input pipeline ops."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import numpy as np
     21 
     22 from tensorflow.python.data.ops import dataset_ops
     23 from tensorflow.python.framework import dtypes
     24 from tensorflow.python.framework import errors
     25 from tensorflow.python.ops import array_ops
     26 from tensorflow.python.platform import test
     27 
     28 
     29 class SequenceDatasetTest(test.TestCase):
     30 
     31   def testRepeatTensorDataset(self):
     32     """Test a dataset that repeats its input multiple times."""
     33     components = (np.array(1), np.array([1, 2, 3]), np.array(37.0))
     34     # This placeholder can be fed when dataset-definition subgraph
     35     # runs (i.e. `init_op` below) to configure the number of
     36     # repetitions used in a particular iterator.
     37     count_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
     38 
     39     iterator = (dataset_ops.Dataset.from_tensors(components)
     40                 .repeat(count_placeholder).make_initializable_iterator())
     41     init_op = iterator.initializer
     42     get_next = iterator.get_next()
     43 
     44     self.assertEqual([c.shape for c in components],
     45                      [t.shape for t in get_next])
     46 
     47     with self.test_session() as sess:
     48       # Test a finite repetition.
     49       sess.run(init_op, feed_dict={count_placeholder: 3})
     50       for _ in range(3):
     51         results = sess.run(get_next)
     52         for component, result_component in zip(components, results):
     53           self.assertAllEqual(component, result_component)
     54 
     55       with self.assertRaises(errors.OutOfRangeError):
     56         sess.run(get_next)
     57 
     58       # Test a different finite repetition.
     59       sess.run(init_op, feed_dict={count_placeholder: 7})
     60       for _ in range(7):
     61         results = sess.run(get_next)
     62         for component, result_component in zip(components, results):
     63           self.assertAllEqual(component, result_component)
     64       with self.assertRaises(errors.OutOfRangeError):
     65         sess.run(get_next)
     66 
     67       # Test an empty repetition.
     68       sess.run(init_op, feed_dict={count_placeholder: 0})
     69       with self.assertRaises(errors.OutOfRangeError):
     70         sess.run(get_next)
     71 
     72       # Test an infinite repetition.
     73       # NOTE(mrry): There's not a good way to test that the sequence
     74       # actually is infinite.
     75       sess.run(init_op, feed_dict={count_placeholder: -1})
     76       for _ in range(17):
     77         results = sess.run(get_next)
     78         for component, result_component in zip(components, results):
     79           self.assertAllEqual(component, result_component)
     80 
     81   def testTakeTensorDataset(self):
     82     components = (np.arange(10),)
     83     count_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
     84 
     85     iterator = (dataset_ops.Dataset.from_tensor_slices(components)
     86                 .take(count_placeholder).make_initializable_iterator())
     87     init_op = iterator.initializer
     88     get_next = iterator.get_next()
     89 
     90     self.assertEqual([c.shape[1:] for c in components],
     91                      [t.shape for t in get_next])
     92 
     93     with self.test_session() as sess:
     94       # Take fewer than input size
     95       sess.run(init_op, feed_dict={count_placeholder: 4})
     96       for i in range(4):
     97         results = sess.run(get_next)
     98         self.assertAllEqual(results, components[0][i:i+1])
     99 
    100       with self.assertRaises(errors.OutOfRangeError):
    101         sess.run(get_next)
    102 
    103       # Take more than input size
    104       sess.run(init_op, feed_dict={count_placeholder: 25})
    105       for i in range(10):
    106         results = sess.run(get_next)
    107         self.assertAllEqual(results, components[0][i:i+1])
    108 
    109       with self.assertRaises(errors.OutOfRangeError):
    110         sess.run(get_next)
    111 
    112       # Take all of input
    113       sess.run(init_op, feed_dict={count_placeholder: -1})
    114       for i in range(10):
    115         results = sess.run(get_next)
    116         self.assertAllEqual(results, components[0][i:i+1])
    117 
    118       with self.assertRaises(errors.OutOfRangeError):
    119         sess.run(get_next)
    120 
    121       # Take nothing
    122       sess.run(init_op, feed_dict={count_placeholder: 0})
    123 
    124       with self.assertRaises(errors.OutOfRangeError):
    125         sess.run(get_next)
    126 
    127   def testSkipTensorDataset(self):
    128     components = (np.arange(10),)
    129     count_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
    130 
    131     iterator = (dataset_ops.Dataset.from_tensor_slices(components)
    132                 .skip(count_placeholder).make_initializable_iterator())
    133     init_op = iterator.initializer
    134     get_next = iterator.get_next()
    135 
    136     self.assertEqual([c.shape[1:] for c in components],
    137                      [t.shape for t in get_next])
    138 
    139     with self.test_session() as sess:
    140       # Skip fewer than input size, we should skip
    141       # the first 4 elements and then read the rest.
    142       sess.run(init_op, feed_dict={count_placeholder: 4})
    143       for i in range(4, 10):
    144         results = sess.run(get_next)
    145         self.assertAllEqual(results, components[0][i:i+1])
    146       with self.assertRaises(errors.OutOfRangeError):
    147         sess.run(get_next)
    148 
    149       # Skip more than input size: get nothing.
    150       sess.run(init_op, feed_dict={count_placeholder: 25})
    151       with self.assertRaises(errors.OutOfRangeError):
    152         sess.run(get_next)
    153 
    154       # Skip exactly input size.
    155       sess.run(init_op, feed_dict={count_placeholder: 10})
    156       with self.assertRaises(errors.OutOfRangeError):
    157         sess.run(get_next)
    158 
    159       # Set -1 for 'count': skip the entire dataset.
    160       sess.run(init_op, feed_dict={count_placeholder: -1})
    161       with self.assertRaises(errors.OutOfRangeError):
    162         sess.run(get_next)
    163 
    164       # Skip nothing
    165       sess.run(init_op, feed_dict={count_placeholder: 0})
    166       for i in range(0, 10):
    167         results = sess.run(get_next)
    168         self.assertAllEqual(results, components[0][i:i+1])
    169       with self.assertRaises(errors.OutOfRangeError):
    170         sess.run(get_next)
    171 
    172   def testRepeatRepeatTensorDataset(self):
    173     """Test the composition of repeat datasets."""
    174     components = (np.array(1), np.array([1, 2, 3]), np.array(37.0))
    175     inner_count = array_ops.placeholder(dtypes.int64, shape=[])
    176     outer_count = array_ops.placeholder(dtypes.int64, shape=[])
    177 
    178     iterator = (dataset_ops.Dataset.from_tensors(components).repeat(inner_count)
    179                 .repeat(outer_count).make_initializable_iterator())
    180     init_op = iterator.initializer
    181     get_next = iterator.get_next()
    182 
    183     self.assertEqual([c.shape for c in components],
    184                      [t.shape for t in get_next])
    185 
    186     with self.test_session() as sess:
    187       sess.run(init_op, feed_dict={inner_count: 7, outer_count: 14})
    188       for _ in range(7 * 14):
    189         results = sess.run(get_next)
    190         for component, result_component in zip(components, results):
    191           self.assertAllEqual(component, result_component)
    192       with self.assertRaises(errors.OutOfRangeError):
    193         sess.run(get_next)
    194 
    195   def testRepeatEmptyDataset(self):
    196     """Test that repeating an empty dataset does not hang."""
    197     iterator = (dataset_ops.Dataset.from_tensors(0).repeat(10).skip(10)
    198                 .repeat(-1).make_initializable_iterator())
    199     init_op = iterator.initializer
    200     get_next = iterator.get_next()
    201 
    202     with self.test_session() as sess:
    203       sess.run(init_op)
    204       with self.assertRaises(errors.OutOfRangeError):
    205         sess.run(get_next)
    206 
    207 
if __name__ == "__main__":
  # When executed directly as a script, hand control to TensorFlow's test
  # runner, which discovers and runs the test* methods defined above.
  test.main()
    210