# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import threading
import time

import numpy as np

from tensorflow.contrib import lookup
from tensorflow.contrib.eager.python import datasets
from tensorflow.python.data import Dataset
from tensorflow.python.data.experimental.ops import threadpool
from tensorflow.python.data.experimental.ops import unique
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training.tracking import util as trackable_utils


class IteratorTest(test.TestCase):

  def testBasic(self):
    got = []
    for t in datasets.Iterator(Dataset.range(4)):
      got.append(t.numpy())
    self.assertAllEqual([0, 1, 2, 3], got)

  def testBasicOneShotIterator(self):
    got = []
    for t in Dataset.range(4).make_one_shot_iterator():
      got.append(t.numpy())
    self.assertAllEqual([0, 1, 2, 3], got)

  def testBasicImplicitIterator(self):
    got = []
    for t in Dataset.range(4):
      got.append(t.numpy())
    self.assertAllEqual([0, 1, 2, 3], got)

  def testGetNext(self):
    iterator = datasets.Iterator(Dataset.range(4))
    self.assertEqual(0, iterator.get_next().numpy())
    self.assertEqual(1, iterator.get_next().numpy())
    self.assertEqual(2, iterator.get_next().numpy())
    self.assertEqual(3, iterator.get_next().numpy())
    with self.assertRaises(errors.OutOfRangeError):
      iterator.get_next()

  def testGetNextOneShotIterator(self):
    iterator = Dataset.range(4).make_one_shot_iterator()
    self.assertEqual(0, iterator.get_next().numpy())
    self.assertEqual(1, iterator.get_next().numpy())
    self.assertEqual(2, iterator.get_next().numpy())
    self.assertEqual(3, iterator.get_next().numpy())
    with self.assertRaises(errors.OutOfRangeError):
      iterator.get_next()

  def testMultipleIteratorsOnTheSameDataset(self):
    ds = Dataset.range(4)
    it1 = datasets.Iterator(ds)
    it2 = datasets.Iterator(ds)
    got = [x.numpy() for x in it1]
    self.assertAllEqual([0, 1, 2, 3], got)

    got = [x.numpy() for x in it2]
    self.assertAllEqual([0, 1, 2, 3], got)

  def testNestedOutputs(self):
    ds = Dataset.zip((Dataset.range(4), Dataset.zip((Dataset.range(4),
                                                     Dataset.range(4)))))
    total = 0
    # The Iterator will return a nested structure of Tensor objects.
    # Some funkiness to compare against simple integers.
    for (i, x) in enumerate(datasets.Iterator(ds)):
      want = (i, (i, i))
      got = (x[0].numpy(), (x[1][0].numpy(), x[1][1].numpy()))
      self.assertEqual(got, want)
      total += 1
    self.assertEqual(4, total)

  def testMapAndFilter(self):
    def even(x):
      return math_ops.equal(math_ops.mod(x, 2), 0)

    it = datasets.Iterator(Dataset.range(8).map(math_ops.square).filter(even))
    got = [x.numpy() for x in it]
    self.assertAllEqual([0, 4, 16, 36], got)
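
  # `Dataset.map` can capture stateful resources, such as lookup tables,
  # defined outside the map function. The test below maps `table.lookup`
  # over a dataset of strings and checks that the captured table is
  # usable from an eager iterator.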
  def testMapCaptureLookupTable(self):
    default_val = -1
    keys = constant_op.constant(['brain', 'salad', 'surgery'])
    values = constant_op.constant([0, 1, 2], dtypes.int64)
    table = lookup.HashTable(
        lookup.KeyValueTensorInitializer(keys, values), default_val)
    dataset = Dataset.from_tensor_slices(['brain', 'salad', 'surgery'])
    dataset = dataset.map(table.lookup)
    it = datasets.Iterator(dataset)
    got = [x.numpy() for x in it]
    self.assertAllEqual([0, 1, 2], got)

  def testMultipleIteratorsOnADatasetThatUsesFunctions(self):
    ds = Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6]).map(math_ops.square)

    got1 = [x.numpy() for x in datasets.Iterator(ds)]
    self.assertAllEqual([1, 4, 9, 16, 25, 36], got1)
    got2 = [x.numpy() for x in datasets.Iterator(ds)]
    self.assertAllEqual(got1, got2)

  def assertSparseValuesEqual(self, a, b):
    self.assertAllEqual(a.indices, b.indices)
    self.assertAllEqual(a.values, b.values)
    self.assertAllEqual(a.dense_shape, b.dense_shape)

  def testSparseTensorElements(self):
    components = (sparse_tensor.SparseTensorValue(
        indices=np.array([[0, 0], [1, 0], [2, 0]]),
        values=np.array([0, 0, 0]),
        dense_shape=np.array([3, 1])),
                  sparse_tensor.SparseTensorValue(
                      indices=np.array([[0, 0], [1, 1], [2, 2]]),
                      values=np.array([1, 2, 3]),
                      dense_shape=np.array([3, 3])))

    expected = [
        (sparse_tensor.SparseTensorValue(
            indices=np.array([[0]]),
            values=np.array([0]),
            dense_shape=np.array([1])),
         sparse_tensor.SparseTensorValue(
             indices=np.array([[0]]),
             values=np.array([1]),
             dense_shape=np.array([3]))),
        (sparse_tensor.SparseTensorValue(
            indices=np.array([[0]]),
            values=np.array([0]),
            dense_shape=np.array([1])),
         sparse_tensor.SparseTensorValue(
             indices=np.array([[1]]),
             values=np.array([2]),
             dense_shape=np.array([3]))),
        (sparse_tensor.SparseTensorValue(
            indices=np.array([[0]]),
            values=np.array([0]),
            dense_shape=np.array([1])),
         sparse_tensor.SparseTensorValue(
             indices=np.array([[2]]),
             values=np.array([3]),
             dense_shape=np.array([3]))),
    ]

    for i, result in enumerate(
        datasets.Iterator(Dataset.from_tensor_slices(components))):
      self.assertSparseValuesEqual(expected[i][0], result[0])
      self.assertSparseValuesEqual(expected[i][1], result[1])

  def testPyFunc(self):

    def my_map(inp):
      return [[x + 1 for x in inp]]

    ds = Dataset.range(4).map(
        lambda x: script_ops.py_func(my_map, [[x]], dtypes.int64))
    got = [x.numpy() for x in datasets.Iterator(ds)]
    self.assertAllEqual([[1], [2], [3], [4]], got)

  def testTensorsPlacedOnDevice(self):
    ds = Dataset.from_tensors([0., 1.])
    with ops.device(test.gpu_device_name()):
      x = datasets.Iterator(ds).next()
      x = math_ops.add(x, x)
    self.assertAllEqual([0., 2.], x.numpy())

  def testGpuTensor(self):
    ds = Dataset.from_tensors([0., 1.])
    with ops.device(test.gpu_device_name()):
      for x in ds:
        y = math_ops.add(x, x)
    self.assertAllEqual([0., 2.], y.numpy())

  def testOverrideThreadPool(self):

    def get_thread_id(_):
      # Python creates a dummy thread object to represent the current
      # thread when called from an "alien" thread (such as a
      # `PrivateThreadPool` thread in this case). It does not include
      # the TensorFlow-given display name, but it has a unique
      # identifier that maps one-to-one with the underlying OS thread.
      return np.array(threading.current_thread().ident).astype(np.int64)

    for num_threads in [1, 2, 4, 8, 16]:

      dataset = (
          Dataset.range(1000).map(
              lambda x: script_ops.py_func(get_thread_id, [x], dtypes.int64),
              num_parallel_calls=32).apply(unique.unique()))

      dataset = threadpool.override_threadpool(
          dataset,
          threadpool.PrivateThreadPool(
              num_threads, display_name='private_thread_pool_%d' % num_threads))

      thread_ids = []
      for next_element in datasets.Iterator(dataset):
        thread_ids.append(next_element)
      self.assertEqual(len(thread_ids), len(set(thread_ids)))
      self.assertGreater(len(thread_ids), 0)
      # NOTE(mrry): We don't control the thread pool scheduling, and
      # so cannot guarantee that all of the threads in the pool will
      # perform work.
      self.assertLessEqual(len(thread_ids), num_threads)
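
  # The save/restore tests below exercise checkpointing of iterator state:
  # saving records the iterator's current position in the dataset, and
  # restoring rewinds the iterator to the element that followed the save.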
  def testSaveRestore(self):
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
    dataset = Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
    dataset = dataset.map(math_ops.square).batch(2)
    iterator = datasets.Iterator(dataset)
    checkpoint = trackable_utils.Checkpoint(iterator=iterator)
    self.assertAllEqual([1, 4], iterator.get_next().numpy())
    save_path = checkpoint.save(checkpoint_prefix)
    self.assertAllEqual([9, 16], iterator.get_next().numpy())
    self.assertAllEqual([25, 36], iterator.get_next().numpy())
    checkpoint.restore(save_path)
    self.assertAllEqual([9, 16], iterator.get_next().numpy())
    self.assertAllEqual([25, 36], iterator.get_next().numpy())

  def testSaveRestoreMultipleIterator(self):
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
    dataset = Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
    dataset = dataset.map(math_ops.square).batch(2)
    iterator_1 = datasets.Iterator(dataset)
    iterator_2 = datasets.Iterator(dataset)
    dataset_2 = Dataset.range(10)
    iterator_3 = datasets.Iterator(dataset_2)

    checkpoint = trackable_utils.Checkpoint(
        iterator_1=iterator_1, iterator_2=iterator_2, iterator_3=iterator_3)
    self.assertAllEqual([1, 4], iterator_1.get_next().numpy())
    self.assertEqual(0, iterator_3.get_next().numpy())
    self.assertEqual(1, iterator_3.get_next().numpy())
    self.assertEqual(2, iterator_3.get_next().numpy())

    save_path = checkpoint.save(checkpoint_prefix)
    self.assertAllEqual([1, 4], iterator_2.get_next().numpy())
    self.assertAllEqual([9, 16], iterator_2.get_next().numpy())
    self.assertEqual(3, iterator_3.get_next().numpy())
    checkpoint.restore(save_path)
    self.assertAllEqual([9, 16], iterator_1.get_next().numpy())
    self.assertAllEqual([1, 4], iterator_2.get_next().numpy())
    self.assertEqual(3, iterator_3.get_next().numpy())

  def testRestoreExhaustedIterator(self):
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
    dataset = Dataset.range(3)
    iterator = datasets.Iterator(dataset)

    checkpoint = trackable_utils.Checkpoint(iterator=iterator)
    self.assertEqual(0, iterator.get_next().numpy())
    self.assertEqual(1, iterator.get_next().numpy())
    save_path = checkpoint.save(checkpoint_prefix)
    self.assertEqual(2, iterator.get_next().numpy())
    checkpoint.restore(save_path)
    self.assertEqual(2, iterator.get_next().numpy())

  def testRestoreInReconstructedIterator(self):
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
    dataset = Dataset.range(10)
    for i in range(5):
      iterator = datasets.Iterator(dataset)
      checkpoint = trackable_utils.Checkpoint(iterator=iterator)
      checkpoint.restore(checkpoint_management.latest_checkpoint(
          checkpoint_directory))
      for j in range(2):
        self.assertEqual(i * 2 + j, iterator.get_next().numpy())
      checkpoint.save(file_prefix=checkpoint_prefix)
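

# The benchmarks below time eager iteration by recording a timestamp after
# each element is produced and reporting the median inter-element delta as
# the wall time per element.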
class DatasetConstructorBenchmark(test.Benchmark):

  def benchmarkSliceRepeatBatchEager(self):
    input_size = 10000
    batch_size = 100
    num_epochs = 100

    input_data = np.random.randn(input_size)

    dataset = (
        Dataset.from_tensor_slices(input_data).repeat(num_epochs)
        .batch(batch_size))
    iterator = datasets.Iterator(dataset)

    ends = [time.time()]
    for _ in iterator:
      ends.append(time.time())

    deltas = np.ediff1d(ends)
    median_wall_time = np.median(deltas)
    print(
        'Slice/repeat/batch eager input size: %d batch size: %d Median wall '
        'time per element: %f'
        % (input_size, batch_size, median_wall_time))
    self.report_benchmark(
        iters=len(deltas),
        wall_time=median_wall_time,
        name='benchmark_slice_repeat_batch_eager_input_%d_batch_%d' %
        (input_size, batch_size))

  def benchmarkSliceBatchCacheRepeatCallable(self):
    input_size = 10000
    batch_size = 100
    num_epochs = 100

    input_data = np.random.randn(input_size)

    dataset = (
        Dataset.from_tensor_slices(input_data).batch(batch_size).cache()
        .repeat(num_epochs))
    iterator = datasets.Iterator(dataset)

    ends = [time.time()]
    for _ in iterator:
      ends.append(time.time())

    deltas = np.ediff1d(ends)
    median_wall_time = np.median(deltas)
    print(
        'Slice/batch/cache/repeat eager input size: %d batch size: %d Median '
        'wall time per element: %f'
        % (input_size, batch_size, median_wall_time))
    self.report_benchmark(
        iters=len(deltas),
        wall_time=median_wall_time,
        name='benchmark_slice_batch_cache_repeat_eager_input_%d_batch_%d' %
        (input_size, batch_size))


if __name__ == '__main__':
  test.main()