# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the experimental input pipeline ops (`Dataset.cache()`)."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from os import path
import shutil
import tempfile

import numpy as np

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


class FilesystemCacheDatasetTest(test.TestCase):
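  """Tests for file-backed caching via Dataset.cache(filename)."""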

  def setUp(self):
    self.tmp_dir = tempfile.mkdtemp()
    self.cache_prefix = path.join(self.tmp_dir, "cache")

  def tearDown(self):
    if self.tmp_dir:
      shutil.rmtree(self.tmp_dir, ignore_errors=True)

  def testCacheDatasetPassthrough(self):
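    """A file-backed cache yields the input elements and can replay them."""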
    components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
                  np.array([9.0, 10.0, 11.0, 12.0]))
    count_placeholder = array_ops.placeholder_with_default(
        constant_op.constant(5, dtypes.int64), shape=[])
    filename_placeholder = array_ops.placeholder(dtypes.string, shape=[])

    repeat_dataset = (dataset_ops.Dataset.from_tensor_slices(components)
                      .repeat(count_placeholder))

    cache_dataset = repeat_dataset.cache(filename_placeholder)

    self.assertEqual(
        tuple([c.shape[1:] for c in components]), cache_dataset.output_shapes)

    # Create initialization ops for iterators without and with
    # caching, respectively.
    iterator = iterator_ops.Iterator.from_structure(cache_dataset.output_types,
                                                    cache_dataset.output_shapes)
    init_fifo_op = iterator.make_initializer(repeat_dataset)
    init_cache_op = iterator.make_initializer(cache_dataset)

    get_next = iterator.get_next()

    with self.test_session() as sess:
      # First run without caching to collect the "ground truth".
      sess.run(init_fifo_op)
      elements = []
      for _ in range(20):
        elements.append(sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      # Assert that the cached dataset has the same elements as the
      # "ground truth".
      sess.run(
          init_cache_op, feed_dict={filename_placeholder: self.cache_prefix})
      cached_elements = []
      for _ in range(20):
        cached_elements.append(sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)
      self.assertAllEqual(elements, cached_elements)

      # Re-initialize with an empty upstream (to throw errors.OutOfRangeError
      # if we didn't use the cache).
      sess.run(
          init_cache_op,
          feed_dict={
              count_placeholder: 0,
              filename_placeholder: self.cache_prefix
          })
      replayed_elements = []
      for _ in range(20):
        replayed_elements.append(sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)
      self.assertEqual(cached_elements, replayed_elements)

      # Re-initialize with an empty upstream and a missing cache file (should
      # throw errors.OutOfRangeError immediately).
      sess.run(
          init_cache_op,
          feed_dict={
              count_placeholder: 0,
              filename_placeholder: self.cache_prefix + "nonsense"
          })
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testConcurrentWriters(self):
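    """Concurrent writers to one cache file fail with AlreadyExistsError."""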
    components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
                  np.array([9.0, 10.0, 11.0, 12.0]))
    filename_placeholder = array_ops.placeholder(dtypes.string, shape=[])

    cache_dataset1 = (dataset_ops.Dataset.from_tensor_slices(components)
                      .cache(filename_placeholder))
    cache_dataset2 = (dataset_ops.Dataset.from_tensor_slices(components)
                      .cache(filename_placeholder))

    iterator1 = cache_dataset1.make_initializable_iterator()
    iterator2 = cache_dataset2.make_initializable_iterator()
    init_cache_op1 = iterator1.initializer
    init_cache_op2 = iterator2.initializer

    get_next1 = iterator1.get_next()
    get_next2 = iterator2.get_next()

    with self.test_session() as sess:
      sess.run(
          init_cache_op1, feed_dict={filename_placeholder: self.cache_prefix})
      sess.run(get_next1)  # this should succeed

      sess.run(
          init_cache_op2, feed_dict={filename_placeholder: self.cache_prefix})
      with self.assertRaises(errors.AlreadyExistsError):
        sess.run(get_next2)

      sess.run(get_next1)  # this should continue to succeed

  def testConcurrentReaders(self):
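    """A completed cache file can be read by multiple iterators at once."""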
    components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
                  np.array([9.0, 10.0, 11.0, 12.0]))
    filename_placeholder = array_ops.placeholder(dtypes.string, shape=[])

    cache_dataset1 = (dataset_ops.Dataset.from_tensor_slices(components)
                      .cache(filename_placeholder))
    cache_dataset2 = (dataset_ops.Dataset.from_tensor_slices(components)
                      .cache(filename_placeholder))

    iterator1 = cache_dataset1.make_initializable_iterator()
    iterator2 = cache_dataset2.make_initializable_iterator()
    init_cache_op1 = iterator1.initializer
    init_cache_op2 = iterator2.initializer

    get_next1 = iterator1.get_next()
    get_next2 = iterator2.get_next()

    with self.test_session() as sess:
      sess.run(
          init_cache_op1, feed_dict={filename_placeholder: self.cache_prefix})
      elements = []
      for _ in range(4):
        elements.append(sess.run(get_next1))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next1)

      # Re-initialize
      sess.run(
          init_cache_op1, feed_dict={filename_placeholder: self.cache_prefix})
      sess.run(
          init_cache_op2, feed_dict={filename_placeholder: self.cache_prefix})

      # Reading concurrently should succeed.
      elements_itr1 = []
      elements_itr2 = []
      elements_itr2.append(sess.run(get_next2))
      elements_itr1.append(sess.run(get_next1))
      elements_itr2.append(sess.run(get_next2))
      elements_itr1.append(sess.run(get_next1))
      # Intentionally reversing the order
      elements_itr1.append(sess.run(get_next1))
      elements_itr2.append(sess.run(get_next2))
      elements_itr1.append(sess.run(get_next1))
      elements_itr2.append(sess.run(get_next2))

      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next2)

      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next1)

      self.assertAllEqual(elements, elements_itr1)
      self.assertAllEqual(elements, elements_itr2)


class MemoryCacheDatasetTest(test.TestCase):
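  """Tests for in-memory caching via Dataset.cache() with no filename."""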

  def testCacheDatasetPassthrough(self):
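    """An in-memory cache replays its elements after the upstream is empty."""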
    with ops.device("cpu:0"):
      repeat_count = variables.Variable(constant_op.constant(10, dtypes.int64))
      dataset = dataset_ops.Dataset.range(3).flat_map(
          lambda x: dataset_ops.Dataset.from_tensors(x).repeat(repeat_count))

      cached_dataset = dataset.cache().repeat(2)
      uncached_dataset = dataset.repeat(2)

      # Needs to be initializable to capture the variable.
      cached_iterator = cached_dataset.make_initializable_iterator()
      cached_next = cached_iterator.get_next()
      uncached_iterator = uncached_dataset.make_initializable_iterator()
      uncached_next = uncached_iterator.get_next()

      with self.test_session() as sess:

        sess.run(repeat_count.initializer)
        sess.run(cached_iterator.initializer)
        sess.run(uncached_iterator.initializer)

        for i in range(3):
          for _ in range(10):
            self.assertEqual(sess.run(cached_next), i)
            self.assertEqual(sess.run(uncached_next), i)

        sess.run(repeat_count.assign(0))

        # The uncached iterator should now be empty.
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(uncached_next)

        # The cached iterator replays from cache.
        for i in range(3):
          for _ in range(10):
            self.assertEqual(sess.run(cached_next), i)

        # The cached iterator should now be empty.
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(cached_next)

  def testEmptyCacheReading(self):
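    """Caching an empty upstream dataset yields an empty dataset."""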
    components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
                  np.array([9.0, 10.0, 11.0, 12.0]))
    count_placeholder = array_ops.placeholder_with_default(
        constant_op.constant(5, dtypes.int64), shape=[])

    repeat_dataset = (dataset_ops.Dataset.from_tensor_slices(components)
                      .repeat(count_placeholder))

    cache_dataset = repeat_dataset.cache()

    # Create an initialization op for the cached iterator.
    iterator = cache_dataset.make_initializable_iterator()
    init_cache_op = iterator.initializer

    get_next = iterator.get_next()

    with self.test_session() as sess:
      # Initialize with an empty upstream, so there is nothing to cache
      # (should throw errors.OutOfRangeError immediately).
      sess.run(init_cache_op, feed_dict={count_placeholder: 0})
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testConcurrentReaders(self):
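    """Iterators derived from the same cache() can read concurrently."""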
    count_placeholder = array_ops.placeholder_with_default(
        constant_op.constant(5, dtypes.int64), shape=[])
    dataset = dataset_ops.Dataset.range(count_placeholder).cache()
    d1 = dataset.map(lambda x: x + 1)
    d2 = dataset.map(lambda x: x + 6)

    i1 = d1.make_initializable_iterator()
    i2 = d2.make_initializable_iterator()

    with self.test_session() as sess:
      sess.run(i1.initializer)

      self.assertEqual(1, sess.run(i1.get_next()))
      self.assertEqual(2, sess.run(i1.get_next()))
      self.assertEqual(3, sess.run(i1.get_next()))

      sess.run(i2.initializer, feed_dict={count_placeholder: 3})

      self.assertEqual(6, sess.run(i2.get_next()))
      self.assertEqual(7, sess.run(i2.get_next()))
      self.assertEqual(4, sess.run(i1.get_next()))  # interleave execution
      self.assertEqual([8, 5], sess.run([i2.get_next(), i1.get_next()]))

      with self.assertRaises(errors.OutOfRangeError):
        sess.run(i1.get_next())
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(i2.get_next())


if __name__ == "__main__":
  test.main()