# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the experimental input pipeline ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import namedtuple
import threading
import time

import numpy as np

from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test


class MapDatasetTest(test.TestCase):
  """Tests for `Dataset.map()`, including parallel and captured-resource use."""

  def _buildMapDataset(self, components, count):
    """Builds TensorSliceDataset -> MapDataset(square) -> RepeatDataset(count)."""
    def _map_fn(x, y, z):
      return math_ops.square(x), math_ops.square(y), math_ops.square(z)
    return (dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn)
            .repeat(count))

  def testMapDataset(self):
    """Test a dataset that maps a TF function across its input elements."""
    # The pipeline is TensorSliceDataset -> MapDataset(square_3) ->
    # RepeatDataset(count).
    components = (np.arange(7),
                  np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
                  np.array(37.0) * np.arange(7))
    count = array_ops.placeholder(dtypes.int64, shape=[])

    dataset = self._buildMapDataset(components, count)
    iterator = dataset.make_initializable_iterator()
    init_op = iterator.initializer
    get_next = iterator.get_next()

    # Mapping with `square` is shape-preserving, so the output shapes should
    # match the per-element shapes of the inputs.
    self.assertEqual([c.shape[1:] for c in components],
                     [t.shape for t in get_next])

    with self.test_session() as sess:
      # Test single-threaded access to the iterator.
      sess.run(init_op, feed_dict={count: 14})
      for _ in range(14):
        for i in range(7):
          result = sess.run(get_next)
          for component, result_component in zip(components, result):
            self.assertAllEqual(component[i]**2, result_component)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      # Test multi-threaded access to the same iterator.
      sess.run(init_op, feed_dict={count: 18})
      results = []
      def iterator_thread():
        # Drain the iterator from several threads until exhaustion.
        while True:
          try:
            results.append(sess.run(get_next))
          except errors.OutOfRangeError:
            return
      threads = [self.checkedThread(target=iterator_thread) for _ in range(8)]
      for t in threads:
        t.start()
      for t in threads:
        t.join()

      # `results` will contain the same elements components**2
      # repeated 18 times, but in a non-deterministic order. Sort the
      # results, and assert that each element of components**2 is
      # produced 18 times.
      results.sort(key=lambda x: x[0])
      for i in range(7):
        for j in range(18):
          for component, result_component in zip(components,
                                                 results[i * 18 + j]):
            self.assertAllEqual(component[i]**2, result_component)

  def _buildParallelMapDataset(self, components, count, num_parallel_calls,
                               output_buffer_size):
    """Like `_buildMapDataset`, but maps in parallel with a prefetch buffer."""
    def _map_fn(x, y, z):
      return math_ops.square(x), math_ops.square(y), math_ops.square(z)
    return (dataset_ops.Dataset.from_tensor_slices(components)
            .map(_map_fn, num_parallel_calls=num_parallel_calls)
            .prefetch(output_buffer_size)
            .repeat(count))

  def testParallelMapDataset(self):
    """Test a dataset that maps a TF function across its input elements."""
    # The pipeline is TensorSliceDataset -> ParallelMapDataset(square_3) ->
    # RepeatDataset(count).
    components = (np.arange(7),
                  np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
                  np.array(37.0) * np.arange(7))
    count = array_ops.placeholder(dtypes.int64, shape=[])
    num_parallel_calls = array_ops.placeholder(dtypes.int32, shape=[])
    output_buffer_size = array_ops.placeholder(dtypes.int64, shape=[])

    dataset = self._buildParallelMapDataset(
        components, count, num_parallel_calls, output_buffer_size)
    iterator = dataset.make_initializable_iterator()
    init_op = iterator.initializer
    get_next = iterator.get_next()

    self.assertEqual([c.shape[1:] for c in components],
                     [t.shape for t in get_next])

    with self.test_session() as sess:
      def do_test(num_parallel_calls_val, output_buffer_size_val):
        # Test single-threaded access to the iterator.
        sess.run(init_op, feed_dict={
            count: 14,
            num_parallel_calls: num_parallel_calls_val,
            output_buffer_size: output_buffer_size_val})
        for _ in range(14):
          for i in range(7):
            result = sess.run(get_next)
            for component, result_component in zip(components, result):
              self.assertAllEqual(component[i]**2, result_component)
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(get_next)

        # Test multi-threaded access to the same iterator.
        sess.run(init_op, feed_dict={
            count: 18,
            num_parallel_calls: num_parallel_calls_val,
            output_buffer_size: output_buffer_size_val})
        results = []
        def iterator_thread():
          while True:
            try:
              results.append(sess.run(get_next))
            except errors.OutOfRangeError:
              return
        threads = [self.checkedThread(target=iterator_thread)
                   for _ in range(64)]
        for t in threads:
          t.start()
        for t in threads:
          t.join()

        # `results` will contain the same elements components**2
        # repeated 18 times, but in a non-deterministic order. Sort the
        # results, and assert that each element of components**2 is
        # produced 18 times.
        results.sort(key=lambda x: x[0])
        for i in range(7):
          for j in range(18):
            for component, result_component in zip(components,
                                                   results[i * 18 + j]):
              self.assertAllEqual(component[i]**2, result_component)

      # Exercise several parallelism / buffer-size combinations with the
      # same graph, varying only the feeds.
      for num_parallel_calls_val, output_buffer_size_val in [
          (1, 1), (1, 2), (2, 2), (2, 4), (8, 8), (8, 16)]:
        do_test(num_parallel_calls_val, output_buffer_size_val)

  def testImplicitDisposeParallelMapDataset(self):
    # Tests whether a parallel map dataset will be cleaned up correctly when
    # the pipeline does not run it until exhaustion.
    # The pipeline is TensorSliceDataset -> MapDataset(square_3) ->
    # RepeatDataset(1000).
    components = (np.arange(1000),
                  np.array([[1, 2, 3]]) * np.arange(1000)[:, np.newaxis],
                  np.array(37.0) * np.arange(1000))

    dataset = self._buildParallelMapDataset(components, 1000, 100, 100)
    # NOTE(mrry): Also test that the prefetching thread is cancelled correctly.
    dataset = dataset.prefetch(100)
    iterator = dataset.make_initializable_iterator()
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      # Consume only a few elements; the iterator is disposed implicitly when
      # the session closes.
      for _ in range(3):
        sess.run(get_next)

  def testParallelMapUnspecifiedOutputSize(self):
    """Tests parallel `map()` without an explicit prefetch buffer size."""
    components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)

    dataset = (dataset_ops.Dataset.from_tensor_slices(components)
               .map(lambda x: array_ops.check_numerics(x, "message"),
                    num_parallel_calls=2))
    iterator = dataset.make_initializable_iterator()
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      # Only the first 3 (finite) elements are consumed here.
      for _ in range(3):
        sess.run(get_next)

  def testParallelMapError(self):
    """Tests that an error raised inside a parallel `map()` is reported."""
    components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)

    dataset = (dataset_ops.Dataset.from_tensor_slices(components)
               .map(lambda x: array_ops.check_numerics(x, "message"),
                    num_parallel_calls=2))
    iterator = dataset.make_initializable_iterator()
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      for _ in range(3):
        sess.run(get_next)
      # The 4th element is NaN, so `array_ops.check_numerics()` should fail.
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(get_next)
      # The iterator should recover and produce the remaining element.
      sess.run(get_next)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testPrefetchError(self):
    """Tests that an error raised upstream of `prefetch()` is reported."""
    components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)

    dataset = (dataset_ops.Dataset.from_tensor_slices(components)
               .map(lambda x: array_ops.check_numerics(x, "message"))
               .prefetch(2))
    iterator = dataset.make_initializable_iterator()
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      for _ in range(3):
        sess.run(get_next)
      # The 4th element is NaN, so `array_ops.check_numerics()` should fail.
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(get_next)
      # The iterator should recover and produce the remaining element.
      sess.run(get_next)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testCaptureHashTable(self):
    """Tests that a `map()` function can capture and use a hash table."""
    # NOTE(mrry): We must use the V2 variants of `HashTable`
    # etc. because these produce a `tf.resource`-typed output that is
    # compatible with the in-graph function implementation.
    default_val = -1
    keys = constant_op.constant(["brain", "salad", "surgery"])
    values = constant_op.constant([0, 1, 2], dtypes.int64)
    table = lookup_ops.HashTable(
        lookup_ops.KeyValueTensorInitializer(keys, values), default_val)

    input_sentences = dataset_ops.Dataset.from_tensor_slices(
        ["brain brain tank salad surgery", "surgery brain"])

    iterator = (input_sentences
                .map(lambda x: string_ops.string_split([x]).values)
                .map(table.lookup)
                .make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      # The table must be initialized before the captured lookup can run.
      sess.run(table.init)
      sess.run(init_op)
      sess.run(get_next)
      sess.run(get_next)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testCaptureQueue(self):
    """Tests that a `map()` function can capture and dequeue from a queue."""
    elements = np.random.randint(100, size=[200])
    queue = data_flow_ops.FIFOQueue(200, dtypes.int64, shapes=[])
    enqueue_op = queue.enqueue_many(elements)
    close_op = queue.close()
    iterator = (dataset_ops.Dataset.from_tensors(0).repeat(-1)
                .map(lambda _: queue.dequeue()).make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(enqueue_op)
      sess.run(close_op)
      sess.run(init_op)
      # Elements come back in FIFO order; the closed queue ends the dataset.
      for element in elements:
        self.assertEqual(element, sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testCaptureSameResourceMultipleTimes(self):
    """Tests capturing two handles to the same shared queue in one `map()`."""
    elements = np.random.randint(100, size=[200])
    queue = data_flow_ops.FIFOQueue(
        200, dtypes.int64, shapes=[], shared_name="shared_queue")
    queue_2 = data_flow_ops.FIFOQueue(
        200, dtypes.int64, shapes=[], shared_name="shared_queue")

    enqueue_op = queue.enqueue_many(elements)
    close_op = queue.close()

    iterator = (dataset_ops.Dataset.from_tensors(0).repeat(-1)
                .map(lambda _: (queue.dequeue(), queue_2.dequeue()))
                .make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(enqueue_op)
      sess.run(close_op)
      sess.run(init_op)
      # Each step dequeues a consecutive pair, but the order within the pair
      # is not guaranteed, so compare sorted pairs.
      for i in range(100):
        self.assertEqual(sorted([elements[i * 2], elements[i * 2 + 1]]),
                         sorted(sess.run(get_next)))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testCaptureVariable(self):
    """Tests that a `map()` function can capture and mutate a variable."""
    counter_var = variable_scope.get_variable(
        "counter", (), dtypes.int32, use_resource=True)
    iterator = (dataset_ops.Dataset.from_tensors(0).repeat(10)
                .map(lambda _: counter_var.assign_add(1))
                .make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(counter_var.initializer)
      sess.run(init_op)
      # Each `get_next` increments the counter exactly once.
      for i in range(10):
        self.assertEqual(i, sess.run(counter_var))
        self.assertEqual(i + 1, sess.run(get_next))
      self.assertEqual(10, sess.run(counter_var))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)
      # Exhaustion must not increment the counter further.
      self.assertEqual(10, sess.run(counter_var))

  def testCaptureUninitializedVariableError(self):
    """Tests that using an uninitialized captured variable raises an error."""
    counter_var = variable_scope.get_variable(
        "counter", (), dtypes.int32, use_resource=True)
    iterator = (dataset_ops.Dataset.from_tensors(0).repeat(10)
                .map(lambda _: counter_var.assign_add(1))
                .make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      # NOTE: `counter_var.initializer` is deliberately not run.
      sess.run(init_op)
      with self.assertRaises(errors.NotFoundError):
        sess.run(get_next)

  def testSeededStatefulOperatorIsProperlyStateful(self):
    """Tests that a seeded random op in `map()` is stateful yet repeatable."""
    iterator = (dataset_ops.Dataset.from_tensors(0).repeat(10)
                .map(lambda _: random_ops.random_uniform((), seed=11)).batch(2)
                .make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      random_values = []
      with self.assertRaises(errors.OutOfRangeError):
        while True:
          random_values.extend(sess.run(get_next))
      self.assertEqual(10, len(random_values))
      # Stateful: successive draws must differ.
      self.assertGreater(np.abs(np.diff(random_values)).max(), 1e-6)
      sess.run(init_op)
      random_values_2 = []
      with self.assertRaises(errors.OutOfRangeError):
        while True:
          random_values_2.extend(sess.run(get_next))

      # Randomness is repeatable given same seed
      self.assertAllClose(random_values, random_values_2)

  def testMapDict(self):
    """Tests that `map()` functions can consume and produce dicts."""
    iterator = (dataset_ops.Dataset.range(10)
                .map(lambda x: {"foo": x * 2, "bar": x ** 2})
                .map(lambda d: d["foo"] + d["bar"])
                .make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      for i in range(10):
        self.assertEqual(i * 2 + i ** 2, sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testMapNamedtuple(self, count=10):
    """Tests that `map()` preserves `namedtuple` element structure."""
    # construct dataset of tuples
    labels = dataset_ops.Dataset.range(count)
    images = labels.map(lambda l: -l)
    dataset_tuple = dataset_ops.Dataset.zip((labels, images))

    # convert dataset of tuples to dataset of namedtuples
    example = namedtuple("Example", ["label", "image"])
    dataset_namedtuple = dataset_tuple.map(example)

    def preprocess_tuple(label, image):
      image = 2 * image
      return label, image

    def preprocess_namedtuple(example):
      return example._replace(image=2 * example.image)

    # preprocess both datasets
    dataset_tuple = dataset_tuple.map(preprocess_tuple)
    dataset_namedtuple = dataset_namedtuple.map(preprocess_namedtuple)

    next_tuple = dataset_tuple.make_one_shot_iterator().get_next()
    next_namedtuple = dataset_namedtuple.make_one_shot_iterator().get_next()

    # make sure both datasets contain the same data
    with self.test_session() as sess:
      for i in range(count):
        tuple_, namedtuple_ = sess.run([next_tuple, next_namedtuple])
        self.assertEqual(tuple_, namedtuple_)
        self.assertEqual(tuple_, (i, -2 * i))

      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_namedtuple)

  def testUseStepContainerInMap(self):
    """Tests using `functional_ops.map_fn` inside a `map()` function."""
    row = np.arange(6)
    iterator = (
        dataset_ops.Dataset.from_tensors(row)
        .map(lambda elems: functional_ops.map_fn(lambda x: x * x, elems))
        .make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      self.assertAllEqual(row ** 2, sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testPrefetch(self):
    """Tests that `prefetch()` runs the map function ahead of consumption."""
    # We will use this event to test that `_map_py_func()` has been
    # invoked a certain number of times (6 times, to be exact) after
    # consuming fewer elements from the iterator.
    ev = threading.Event()

    set_event_during_invocation = 5

    def _map_py_func(x):
      if x == set_event_during_invocation:
        ev.set()
      return x * x

    def _map_fn(x):
      return script_ops.py_func(_map_py_func, [x], x.dtype)

    buffer_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
    iterator = (
        dataset_ops.Dataset.range(100)
        .map(_map_fn)
        .prefetch(buffer_size_placeholder)
        .make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      # Simple test that prefetch yields the expected values in the
      # expected order.
      for buffer_size in [1, 10, 100, 1000]:
        sess.run(init_op, feed_dict={buffer_size_placeholder: buffer_size})
        for i in range(100):
          self.assertEqual(i * i, sess.run(get_next))
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(get_next)

      # We can indirectly observe that varying the buffer size has the
      # intended effect by observing when `ev` is set (on the 6th
      # invocation of `_map_py_func()`).
      # NOTE(mrry): We do not test with `buffer_size ==
      # set_event_during_invocation`, because we must consume at least
      # one element to start the prefetching.
      for buffer_size in range(1, set_event_during_invocation):
        event_will_be_set_after_consuming = (
            set_event_during_invocation - buffer_size + 1)

        ev.clear()
        sess.run(init_op, feed_dict={buffer_size_placeholder: buffer_size})
        for i in range(event_will_be_set_after_consuming):
          self.assertFalse(ev.is_set())
          self.assertEqual(i * i, sess.run(get_next))
        ev.wait()
        for i in range(event_will_be_set_after_consuming, 100):
          self.assertEqual(i * i, sess.run(get_next))
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(get_next)

  def testReturnList(self):
    """Tests that a `map()` function may return its outputs as a list."""
    iterator = (dataset_ops.Dataset.range(10)
                .map(lambda x: [x, constant_op.constant(37.0)])
                .make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      for i in range(10):
        self.assertEqual((i, 37.0), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testMultiOutputPyFunc(self):
    # The `tf.py_func()` op returns a list of tensors for its outputs.
    def _map_fn(x_tensor):
      def _map_py_func(x):
        return x, np.array(37.0, dtype=np.float64)
      return script_ops.py_func(
          _map_py_func, [x_tensor], [dtypes.int64, dtypes.float64])

    iterator = (dataset_ops.Dataset.range(10)
                .map(_map_fn)
                .make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      for i in range(10):
        self.assertEqual((i, 37.0), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def assertSparseValuesEqual(self, a, b):
    """Asserts that two `SparseTensorValue`s have identical contents."""
    self.assertAllEqual(a.indices, b.indices)
    self.assertAllEqual(a.values, b.values)
    self.assertAllEqual(a.dense_shape, b.dense_shape)

  def testSparse(self):
    """Tests that a `map()` function may return sparse tensors."""

    def _sparse(i):
      return sparse_tensor.SparseTensorValue(
          indices=np.array([[0, 0]]),
          values=(i * np.array([1])),
          dense_shape=np.array([1, 1]))

    iterator = (dataset_ops.Dataset.range(10)
                .map(_sparse)
                .make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      for i in range(10):
        actual = sess.run(get_next)
        self.assertTrue(isinstance(actual, sparse_tensor.SparseTensorValue))
        self.assertSparseValuesEqual(actual, _sparse(i))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testSparseChain(self):
    """Tests chaining `map()` calls that consume and produce sparse tensors."""

    def _sparse(i):
      return sparse_tensor.SparseTensorValue(
          indices=np.array([[0, 0]]),
          values=(i * np.array([1])),
          dense_shape=np.array([1, 1]))

    def _check(i):
      # The second `map()` must receive a sparse tensor from the first.
      self.assertTrue(sparse_tensor.is_sparse(i))
      return sparse_ops.sparse_concat(0, [i, i])

    iterator = (
        dataset_ops.Dataset.range(10).map(_sparse).map(_check)
        .make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      for i in range(10):
        actual = sess.run(get_next)
        self.assertTrue(isinstance(actual, sparse_tensor.SparseTensorValue))
        self.assertSparseValuesEqual(actual, _check(_sparse(i)).eval())
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)


class MapDatasetBenchmark(test.Benchmark):
  """Benchmarks measuring the per-element overhead of `Dataset.map()`."""

  def benchmarkChainOfMaps(self):
    """Measures latency as a function of the number of chained `map()`s."""
    chain_lengths = [0, 1, 2, 5, 10, 20, 50]
    for chain_length in chain_lengths:
      with ops.Graph().as_default():
        dataset = dataset_ops.Dataset.from_tensors(0).repeat(None)
        for _ in range(chain_length):
          dataset = dataset.map(lambda x: x)
        iterator = dataset.make_one_shot_iterator()
        next_element = iterator.get_next()

        with session.Session() as sess:
          # Warm up the session before timing.
          for _ in range(5):
            sess.run(next_element.op)
          deltas = []
          for _ in range(100):
            start = time.time()
            for _ in range(100):
              sess.run(next_element.op)
            end = time.time()
            deltas.append(end - start)

          # Each delta covers 100 iterations; divide to get per-element time.
          median_wall_time = np.median(deltas) / 100
          print("Map dataset chain length: %d Median wall time: %f"
                % (chain_length, median_wall_time))
          self.report_benchmark(
              iters=1000, wall_time=median_wall_time,
              name="benchmark_map_dataset_chain_latency_%d" % chain_length)

  def benchmarkMapFanOut(self):
    """Measures latency as a function of the number of `map()` outputs."""
    fan_outs = [1, 2, 5, 10, 20, 50, 100]
    for fan_out in fan_outs:
      with ops.Graph().as_default():
        dataset = dataset_ops.Dataset.from_tensors(
            tuple(0 for _ in range(fan_out))).repeat(None).map(lambda *xs: xs)
        iterator = dataset.make_one_shot_iterator()
        next_element = iterator.get_next()

        with session.Session() as sess:
          # Warm up the session before timing.
          for _ in range(5):
            sess.run(next_element[0].op)
          deltas = []
          for _ in range(100):
            start = time.time()
            for _ in range(100):
              sess.run(next_element[0].op)
            end = time.time()
            deltas.append(end - start)

          # Each delta covers 100 iterations; divide to get per-element time.
          median_wall_time = np.median(deltas) / 100
          print("Map dataset fan out: %d Median wall time: %f"
                % (fan_out, median_wall_time))
          self.report_benchmark(
              iters=1000, wall_time=median_wall_time,
              name="benchmark_map_dataset_fan_out_%d" % fan_out)


if __name__ == "__main__":
  test.main()