# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the experimental input pipeline ops (`Dataset.cache()`)."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from os import path
import shutil
import tempfile

import numpy as np

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


class FilesystemCacheDatasetTest(test.TestCase):

  def setUp(self):
    self.tmp_dir = tempfile.mkdtemp()
    self.cache_prefix = path.join(self.tmp_dir, "cache")

  def tearDown(self):
    if self.tmp_dir:
      shutil.rmtree(self.tmp_dir, ignore_errors=True)

  def testCacheDatasetPassthrough(self):
    components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
                  np.array([9.0, 10.0, 11.0, 12.0]))
    count_placeholder = array_ops.placeholder_with_default(
        constant_op.constant(5, dtypes.int64), shape=[])
    filename_placeholder = array_ops.placeholder(dtypes.string, shape=[])

    repeat_dataset = (dataset_ops.Dataset.from_tensor_slices(components)
                      .repeat(count_placeholder))

    cache_dataset = repeat_dataset.cache(filename_placeholder)

    self.assertEqual(
        tuple([c.shape[1:] for c in components]), cache_dataset.output_shapes)

    # Create initialization ops for iterators without and with
    # caching, respectively.
    iterator = iterator_ops.Iterator.from_structure(cache_dataset.output_types,
                                                    cache_dataset.output_shapes)
    init_fifo_op = iterator.make_initializer(repeat_dataset)
    init_cache_op = iterator.make_initializer(cache_dataset)

    get_next = iterator.get_next()

    with self.test_session() as sess:
      # First run without caching to collect the "ground truth".
      sess.run(init_fifo_op)
      elements = []
      for _ in range(20):
        elements.append(sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      # Assert that the cached dataset has the same elements as the
      # "ground truth".
      sess.run(
          init_cache_op, feed_dict={filename_placeholder: self.cache_prefix})
      cached_elements = []
      for _ in range(20):
        cached_elements.append(sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)
      self.assertAllEqual(elements, cached_elements)

      # Re-initialize with an empty upstream (to throw errors.OutOfRangeError
      # if we didn't use the cache).
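      # Feeding count_placeholder as 0 makes the upstream dataset empty, so
      # every element read below must come from the cache file on disk.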
      sess.run(
          init_cache_op,
          feed_dict={
              count_placeholder: 0,
              filename_placeholder: self.cache_prefix
          })
      replayed_elements = []
      for _ in range(20):
        replayed_elements.append(sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)
      self.assertEqual(cached_elements, replayed_elements)

      # Re-initialize with an empty upstream and a missing cache file (should
      # throw errors.OutOfRangeError immediately).
      sess.run(
          init_cache_op,
          feed_dict={
              count_placeholder: 0,
              filename_placeholder: self.cache_prefix + "nonsense"
          })
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testConcurrentWriters(self):
    components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
                  np.array([9.0, 10.0, 11.0, 12.0]))
    filename_placeholder = array_ops.placeholder(dtypes.string, shape=[])

    cache_dataset1 = (dataset_ops.Dataset.from_tensor_slices(components)
                      .cache(filename_placeholder))
    cache_dataset2 = (dataset_ops.Dataset.from_tensor_slices(components)
                      .cache(filename_placeholder))

    iterator1 = cache_dataset1.make_initializable_iterator()
    iterator2 = cache_dataset2.make_initializable_iterator()
    init_cache_op1 = iterator1.initializer
    init_cache_op2 = iterator2.initializer

    get_next1 = iterator1.get_next()
    get_next2 = iterator2.get_next()

    with self.test_session() as sess:
      sess.run(
          init_cache_op1, feed_dict={filename_placeholder: self.cache_prefix})
      sess.run(get_next1)  # this should succeed

      sess.run(
          init_cache_op2, feed_dict={filename_placeholder: self.cache_prefix})
      with self.assertRaises(errors.AlreadyExistsError):
        sess.run(get_next2)

      sess.run(get_next1)  # this should continue to succeed

  def testConcurrentReaders(self):
    components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
                  np.array([9.0, 10.0, 11.0, 12.0]))
    filename_placeholder = array_ops.placeholder(dtypes.string, shape=[])

    cache_dataset1 = (dataset_ops.Dataset.from_tensor_slices(components)
                      .cache(filename_placeholder))
    cache_dataset2 = (dataset_ops.Dataset.from_tensor_slices(components)
                      .cache(filename_placeholder))

    iterator1 = cache_dataset1.make_initializable_iterator()
    iterator2 = cache_dataset2.make_initializable_iterator()
    init_cache_op1 = iterator1.initializer
    init_cache_op2 = iterator2.initializer

    get_next1 = iterator1.get_next()
    get_next2 = iterator2.get_next()

    with self.test_session() as sess:
      sess.run(
          init_cache_op1, feed_dict={filename_placeholder: self.cache_prefix})
      elements = []
      for _ in range(4):
        elements.append(sess.run(get_next1))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next1)

      # Re-initialize both iterators; the cache is now complete, so both
      # should read from it.
      sess.run(
          init_cache_op1, feed_dict={filename_placeholder: self.cache_prefix})
      sess.run(
          init_cache_op2, feed_dict={filename_placeholder: self.cache_prefix})

      # Reading concurrently should succeed.
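      # Interleave get_next() calls on the two iterators; each must still see
      # every cached element exactly once.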
      elements_itr1 = []
      elements_itr2 = []
      elements_itr2.append(sess.run(get_next2))
      elements_itr1.append(sess.run(get_next1))
      elements_itr2.append(sess.run(get_next2))
      elements_itr1.append(sess.run(get_next1))
      # Intentionally reversing the order.
      elements_itr1.append(sess.run(get_next1))
      elements_itr2.append(sess.run(get_next2))
      elements_itr1.append(sess.run(get_next1))
      elements_itr2.append(sess.run(get_next2))

      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next2)

      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next1)

      self.assertAllEqual(elements, elements_itr1)
      self.assertAllEqual(elements, elements_itr2)


class MemoryCacheDatasetTest(test.TestCase):

  def testCacheDatasetPassthrough(self):
    with ops.device("cpu:0"):
      repeat_count = variables.Variable(constant_op.constant(10, dtypes.int64))
      dataset = dataset_ops.Dataset.range(3).flat_map(
          lambda x: dataset_ops.Dataset.from_tensors(x).repeat(repeat_count))

      cached_dataset = dataset.cache().repeat(2)
      uncached_dataset = dataset.repeat(2)

      # Needs to be initializable to capture the variable.
      cached_iterator = cached_dataset.make_initializable_iterator()
      cached_next = cached_iterator.get_next()
      uncached_iterator = uncached_dataset.make_initializable_iterator()
      uncached_next = uncached_iterator.get_next()

      with self.test_session() as sess:

        sess.run(repeat_count.initializer)
        sess.run(cached_iterator.initializer)
        sess.run(uncached_iterator.initializer)

        for i in range(3):
          for _ in range(10):
            self.assertEqual(sess.run(cached_next), i)
            self.assertEqual(sess.run(uncached_next), i)

        sess.run(repeat_count.assign(0))

        # The uncached iterator should now be empty.
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(uncached_next)

        # The cached iterator replays from cache.
        for i in range(3):
          for _ in range(10):
            self.assertEqual(sess.run(cached_next), i)

        # The cached iterator should now be empty.
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(cached_next)

  def testEmptyCacheReading(self):
    components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
                  np.array([9.0, 10.0, 11.0, 12.0]))
    count_placeholder = array_ops.placeholder_with_default(
        constant_op.constant(5, dtypes.int64), shape=[])

    repeat_dataset = (dataset_ops.Dataset.from_tensor_slices(components)
                      .repeat(count_placeholder))

    cache_dataset = repeat_dataset.cache()

    # Create an initialization op for the cached iterator.
    iterator = cache_dataset.make_initializable_iterator()
    init_cache_op = iterator.initializer

    get_next = iterator.get_next()

    with self.test_session() as sess:
      # Initialize with an empty upstream and an empty in-memory cache (should
      # throw errors.OutOfRangeError immediately).
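      # With count_placeholder fed as 0, nothing is ever written to the cache,
      # so the first get_next() has nothing to produce.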
      sess.run(init_cache_op, feed_dict={count_placeholder: 0})
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testConcurrentReaders(self):
    count_placeholder = array_ops.placeholder_with_default(
        constant_op.constant(5, dtypes.int64), shape=[])
    dataset = dataset_ops.Dataset.range(count_placeholder).cache()
    d1 = dataset.map(lambda x: x + 1)
    d2 = dataset.map(lambda x: x + 6)

    i1 = d1.make_initializable_iterator()
    i2 = d2.make_initializable_iterator()

    with self.test_session() as sess:
      sess.run(i1.initializer)

      self.assertEqual(1, sess.run(i1.get_next()))
      self.assertEqual(2, sess.run(i1.get_next()))
      self.assertEqual(3, sess.run(i1.get_next()))

      sess.run(i2.initializer, feed_dict={count_placeholder: 3})

      self.assertEqual(6, sess.run(i2.get_next()))
      self.assertEqual(7, sess.run(i2.get_next()))
      self.assertEqual(4, sess.run(i1.get_next()))  # interleave execution
      self.assertEqual([8, 5], sess.run([i2.get_next(), i1.get_next()]))

      with self.assertRaises(errors.OutOfRangeError):
        sess.run(i1.get_next())
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(i2.get_next())


if __name__ == "__main__":
  test.main()