# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.Dataset.cache()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from os import path
import shutil
import tempfile

import numpy as np

from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


@test_util.run_all_in_graph_and_eager_modes
class FileCacheTest(test_base.DatasetTestBase):

  def setUp(self):
    super(FileCacheTest, self).setUp()
    self.tmp_dir = tempfile.mkdtemp()
    self.cache_prefix = path.join(self.tmp_dir, "cache")

  def tearDown(self):
    if self.tmp_dir:
      shutil.rmtree(self.tmp_dir, ignore_errors=True)
    super(FileCacheTest, self).tearDown()

  def testCacheDatasetPassthrough(self):
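    """Verifies that a file cache faithfully reproduces its input dataset.

    The elements are first collected without caching as the ground truth,
    then compared against a run that writes the cache, and finally against
    a replay from the cache file alone (with an empty upstream dataset).
    """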
    components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
                  np.array([9.0, 10.0, 11.0, 12.0]))

    def dataset_fn(count=5, filename=None):
      repeat_dataset = (
          dataset_ops.Dataset.from_tensor_slices(components).repeat(count))
      if filename:
        return repeat_dataset.cache(filename)
      else:
        return repeat_dataset

    self.assertEqual(
        tuple([c.shape[1:] for c in components]),
        dataset_ops.get_legacy_output_shapes(dataset_fn()))

    get_next = self.getNext(dataset_fn())

    # First run without caching to collect the "ground truth".
    elements = []
    for _ in range(20):
      elements.append(self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

    # Assert that the cached dataset has the same elements as the
    # "ground truth".
    get_next = self.getNext(dataset_fn(filename=self.cache_prefix))
    cached_elements = []
    for _ in range(20):
      cached_elements.append(self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
    self.assertAllEqual(elements, cached_elements)

    # Re-initialize with an empty upstream (to throw errors.OutOfRangeError
    # if we didn't use the cache).
    get_next = self.getNext(dataset_fn(count=0, filename=self.cache_prefix))
    replayed_elements = []
    for _ in range(20):
      replayed_elements.append(self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
    self.assertAllEqual(cached_elements, replayed_elements)

    # Re-initialize with an empty upstream and a missing cache file (should
    # throw errors.OutOfRangeError immediately).
    get_next = self.getNext(
        dataset_fn(count=0, filename=self.cache_prefix + "nonsense"))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

  def testConcurrentWriters(self):
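    """Verifies that a cache file cannot be written by two iterators at once.

    While the first iterator is still populating the cache, pulling an
    element through a second iterator over the same cache file should raise
    errors.AlreadyExistsError, and the first iterator should keep working.
    """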
    components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
                  np.array([9.0, 10.0, 11.0, 12.0]))

    cache_dataset1 = (
        dataset_ops.Dataset.from_tensor_slices(components).cache(
            self.cache_prefix))
    cache_dataset2 = (
        dataset_ops.Dataset.from_tensor_slices(components).cache(
            self.cache_prefix))

    get_next1 = self.getNext(cache_dataset1)
    get_next2 = self.getNext(cache_dataset2)

    self.evaluate(get_next1())  # this should succeed

    with self.assertRaises(errors.AlreadyExistsError):
      self.evaluate(get_next2())

    self.evaluate(get_next1())  # this should continue to succeed

  def testConcurrentReaders(self):
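    """Verifies that a completed cache file supports multiple readers.

    A single iterator first writes the cache to completion; two fresh
    iterators over the same cache file should then read back the same
    elements concurrently, even with their reads interleaved.
    """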
    components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
                  np.array([9.0, 10.0, 11.0, 12.0]))

    cache_dataset1 = (
        dataset_ops.Dataset.from_tensor_slices(components).cache(
            self.cache_prefix))
    cache_dataset2 = (
        dataset_ops.Dataset.from_tensor_slices(components).cache(
            self.cache_prefix))

    get_next1 = self.getNext(cache_dataset1)
    get_next2 = self.getNext(cache_dataset2)

    elements = []
    for _ in range(4):
      elements.append(self.evaluate(get_next1()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next1())

    # Re-initialize
    get_next1 = self.getNext(cache_dataset1, requires_initialization=True)
    get_next2 = self.getNext(cache_dataset2, requires_initialization=True)

    # Reading concurrently should succeed.
    elements_itr1 = []
    elements_itr2 = []
    elements_itr2.append(self.evaluate(get_next2()))
    elements_itr1.append(self.evaluate(get_next1()))
    elements_itr2.append(self.evaluate(get_next2()))
    elements_itr1.append(self.evaluate(get_next1()))
    # Intentionally reverse which iterator is advanced first.
    elements_itr1.append(self.evaluate(get_next1()))
    elements_itr2.append(self.evaluate(get_next2()))
    elements_itr1.append(self.evaluate(get_next1()))
    elements_itr2.append(self.evaluate(get_next2()))

    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next2())

    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next1())

    self.assertAllEqual(elements, elements_itr1)
    self.assertAllEqual(elements, elements_itr2)


@test_util.run_all_in_graph_and_eager_modes
class MemoryCacheTest(test_base.DatasetTestBase):

  def testCacheDatasetPassthrough(self):
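    """Verifies that a memory cache replays the elements it captured.

    The input dataset depends on a variable. Once the cache has been
    filled, zeroing the variable empties the uncached dataset, while the
    cached dataset still replays the originally produced elements.
    """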
    with ops.device("cpu:0"):
      repeat_count = variables.Variable(constant_op.constant(10, dtypes.int64))
      dataset = dataset_ops.Dataset.range(3).flat_map(
          lambda x: dataset_ops.Dataset.from_tensors(x).repeat(repeat_count))

      cached_dataset = dataset.cache().repeat(2)
      uncached_dataset = dataset.repeat(2)

      self.evaluate(repeat_count.initializer)
      # Needs to be initializable to capture the variable.
      cached_next = self.getNext(cached_dataset, requires_initialization=True)
      uncached_next = self.getNext(
          uncached_dataset, requires_initialization=True)
      for i in range(3):
        for _ in range(10):
          self.assertEqual(self.evaluate(cached_next()), i)
          self.assertEqual(self.evaluate(uncached_next()), i)

      self.evaluate(repeat_count.assign(0))

      # The uncached iterator should now be empty.
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(uncached_next())

      # The cached iterator replays from cache.
      for i in range(3):
        for _ in range(10):
          self.assertEqual(self.evaluate(cached_next()), i)

      # The cached iterator should now be empty.
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(cached_next())

  def testEmptyCacheReading(self):
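    """Verifies that caching an empty dataset yields an empty dataset."""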
    components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
                  np.array([9.0, 10.0, 11.0, 12.0]))

    repeat_dataset = (
        dataset_ops.Dataset.from_tensor_slices(components).repeat(0))
    cache_dataset = repeat_dataset.cache()

    # Reading through the cache of an empty dataset should itself produce
    # no elements.
    self.assertDatasetProduces(cache_dataset, expected_output=[])

  def testConcurrentReaders(self):
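    """Verifies that one memory cache can feed multiple downstream datasets.

    Two mapped views of the same cached dataset are read in interleaved
    order; both should observe every element of the underlying range.
    """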
    dataset = dataset_ops.Dataset.range(5).cache()
    d1 = dataset.map(lambda x: x + 1)
    d2 = dataset.map(lambda x: x + 6)

    get_next1 = self.getNext(d1)

    self.assertEqual(1, self.evaluate(get_next1()))
    self.assertEqual(2, self.evaluate(get_next1()))
    self.assertEqual(3, self.evaluate(get_next1()))

    get_next2 = self.getNext(d2)

    self.assertEqual(6, self.evaluate(get_next2()))
    self.assertEqual(7, self.evaluate(get_next2()))
    self.assertEqual(4, self.evaluate(get_next1()))  # interleave execution
    self.assertEqual([8, 5],
                     [self.evaluate(get_next2()),
                      self.evaluate(get_next1())])
    self.assertEqual(9, self.evaluate(get_next2()))
    self.assertEqual(10, self.evaluate(get_next2()))

    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next2())
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next1())

  def testCacheTakeRepeat(self):
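    """Verifies caching composed with take() and repeat().

    take(5) ends the first pass before the cache of range(10) is fully
    populated, so the repeated pass must still produce the same first five
    elements from the partially filled cache.
    """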
    dataset = dataset_ops.Dataset.range(10).cache().take(5).repeat(2)

    expected_output = [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
    self.assertDatasetProduces(dataset, expected_output=expected_output)


if __name__ == "__main__":
  test.main()