# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.Dataset.cache()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from os import path
import shutil
import tempfile

import numpy as np

from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


@test_util.run_all_in_graph_and_eager_modes
class FileCacheTest(test_base.DatasetTestBase):
  """Tests for `Dataset.cache(filename)`, i.e. the file-backed cache.

  Every test gets its own temporary directory; `self.cache_prefix` is the
  cache filename prefix inside that directory.
  """

  def setUp(self):
    """Creates a fresh temporary directory and cache prefix for the test."""
    # Let the base test case perform its own per-test setup first.
    super(FileCacheTest, self).setUp()
    self.tmp_dir = tempfile.mkdtemp()
    self.cache_prefix = path.join(self.tmp_dir, "cache")

  def tearDown(self):
    """Removes the temporary directory created in `setUp`."""
    if self.tmp_dir:
      # Best-effort cleanup: a failure to delete must not fail the test.
      shutil.rmtree(self.tmp_dir, ignore_errors=True)
    # Let the base test case perform its own per-test cleanup last.
    super(FileCacheTest, self).tearDown()

  def testCacheDatasetPassthrough(self):
    """Checks that a file cache reproduces and later replays the upstream."""
    components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
                  np.array([9.0, 10.0, 11.0, 12.0]))

    def dataset_fn(count=5, filename=None):
      """Builds `components` repeated `count` times, cached iff `filename`."""
      repeat_dataset = (
          dataset_ops.Dataset.from_tensor_slices(components).repeat(count))
      if filename:
        return repeat_dataset.cache(filename)
      else:
        return repeat_dataset

    self.assertEqual(
        tuple(c.shape[1:] for c in components),
        dataset_ops.get_legacy_output_shapes(dataset_fn()))

    get_next = self.getNext(dataset_fn())

    # First run without caching to collect the "ground truth".
    # 4 slices repeated 5 times gives 20 elements.
    elements = [self.evaluate(get_next()) for _ in range(20)]
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

    # Assert that the cached dataset has the same elements as the
    # "ground truth".
    get_next = self.getNext(dataset_fn(filename=self.cache_prefix))
    cached_elements = [self.evaluate(get_next()) for _ in range(20)]
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
    self.assertAllEqual(elements, cached_elements)

    # Re-initialize with an empty upstream (to throw errors.OutOfRangeError
    # if we didn't use the cache).
    get_next = self.getNext(dataset_fn(count=0, filename=self.cache_prefix))
    replayed_elements = [self.evaluate(get_next()) for _ in range(20)]
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
    self.assertAllEqual(cached_elements, replayed_elements)

    # Re-initialize with an empty upstream and a missing cache file (should
    # throw errors.OutOfRangeError immediately).
    get_next = self.getNext(
        dataset_fn(count=0, filename=self.cache_prefix + "nonsense"))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

  def testConcurrentWriters(self):
    """Checks that a second writer to the same cache prefix is rejected."""
    components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
                  np.array([9.0, 10.0, 11.0, 12.0]))

    cache_dataset1 = (
        dataset_ops.Dataset.from_tensor_slices(components).cache(
            self.cache_prefix))
    cache_dataset2 = (
        dataset_ops.Dataset.from_tensor_slices(components).cache(
            self.cache_prefix))

    get_next1 = self.getNext(cache_dataset1)
    get_next2 = self.getNext(cache_dataset2)

    self.evaluate(get_next1())  # this should succeed

    with self.assertRaises(errors.AlreadyExistsError):
      self.evaluate(get_next2())

    self.evaluate(get_next1())  # this should continue to succeed

  def testConcurrentReaders(self):
    """Checks that two iterators can read one completed cache interleaved."""
    components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
                  np.array([9.0, 10.0, 11.0, 12.0]))

    cache_dataset1 = (
        dataset_ops.Dataset.from_tensor_slices(components).cache(
            self.cache_prefix))
    cache_dataset2 = (
        dataset_ops.Dataset.from_tensor_slices(components).cache(
            self.cache_prefix))

    get_next1 = self.getNext(cache_dataset1)
    get_next2 = self.getNext(cache_dataset2)

    # Run the first iterator to exhaustion before any concurrent reading.
    elements = [self.evaluate(get_next1()) for _ in range(4)]
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next1())

    # Re-initialize
    get_next1 = self.getNext(cache_dataset1, requires_initialization=True)
    get_next2 = self.getNext(cache_dataset2, requires_initialization=True)

    # Reading concurrently should succeed.
    elements_itr1 = []
    elements_itr2 = []
    elements_itr2.append(self.evaluate(get_next2()))
    elements_itr1.append(self.evaluate(get_next1()))
    elements_itr2.append(self.evaluate(get_next2()))
    elements_itr1.append(self.evaluate(get_next1()))
    # Intentionally reversing the order
    elements_itr1.append(self.evaluate(get_next1()))
    elements_itr2.append(self.evaluate(get_next2()))
    elements_itr1.append(self.evaluate(get_next1()))
    elements_itr2.append(self.evaluate(get_next2()))

    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next2())

    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next1())

    self.assertAllEqual(elements, elements_itr1)
    self.assertAllEqual(elements, elements_itr2)


@test_util.run_all_in_graph_and_eager_modes
class MemoryCacheTest(test_base.DatasetTestBase):
  """Tests for `Dataset.cache()` called without a filename."""

  def testCacheDatasetPassthrough(self):
    """Checks that the cache replays even after the upstream becomes empty."""
    with ops.device("cpu:0"):
      repeat_count = variables.Variable(constant_op.constant(10, dtypes.int64))
      dataset = dataset_ops.Dataset.range(3).flat_map(
          lambda x: dataset_ops.Dataset.from_tensors(x).repeat(repeat_count))

      cached_dataset = dataset.cache().repeat(2)
      uncached_dataset = dataset.repeat(2)

      self.evaluate(repeat_count.initializer)
      # Needs to be initializable to capture the variable.
      cached_next = self.getNext(cached_dataset, requires_initialization=True)
      uncached_next = self.getNext(
          uncached_dataset, requires_initialization=True)
      for i in range(3):
        for _ in range(10):
          self.assertEqual(self.evaluate(cached_next()), i)
          self.assertEqual(self.evaluate(uncached_next()), i)

      # From here on the upstream pipeline repeats every element zero times.
      self.evaluate(repeat_count.assign(0))

      # The uncached iterator should now be empty.
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(uncached_next())

      # The cached iterator replays from cache.
      for i in range(3):
        for _ in range(10):
          self.assertEqual(self.evaluate(cached_next()), i)

      # The cached iterator should now be empty.
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(cached_next())

  def testEmptyCacheReading(self):
    """Checks that caching an empty upstream yields an empty dataset."""
    components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
                  np.array([9.0, 10.0, 11.0, 12.0]))

    # `repeat(0)` makes the upstream empty before it reaches the cache.
    repeat_dataset = (
        dataset_ops.Dataset.from_tensor_slices(components).repeat(0))
    cache_dataset = repeat_dataset.cache()

    # The cached dataset must produce no elements at all.
    self.assertDatasetProduces(cache_dataset, expected_output=[])

  def testConcurrentReaders(self):
    """Checks interleaved reads by two pipelines sharing one cached dataset."""

    dataset = dataset_ops.Dataset.range(5).cache()
    # Different offsets make it obvious which pipeline produced a value.
    d1 = dataset.map(lambda x: x + 1)
    d2 = dataset.map(lambda x: x + 6)

    get_next1 = self.getNext(d1)

    self.assertEqual(1, self.evaluate(get_next1()))
    self.assertEqual(2, self.evaluate(get_next1()))
    self.assertEqual(3, self.evaluate(get_next1()))

    get_next2 = self.getNext(d2)

    self.assertEqual(6, self.evaluate(get_next2()))
    self.assertEqual(7, self.evaluate(get_next2()))
    self.assertEqual(4, self.evaluate(get_next1()))  # interleave execution
    self.assertEqual([8, 5],
                     [self.evaluate(get_next2()),
                      self.evaluate(get_next1())])
    self.assertEqual(9, self.evaluate(get_next2()))
    self.assertEqual(10, self.evaluate(get_next2()))

    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next2())
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next1())

  def testCacheTakeRepeat(self):
    """Checks `cache().take(5).repeat(2)` yields the first 5 elements twice."""
    dataset = dataset_ops.Dataset.range(10).cache().take(5).repeat(2)

    expected_output = [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
    self.assertDatasetProduces(dataset, expected_output=expected_output)


if __name__ == "__main__":
  test.main()