1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for `tf.data.experimental.parallel_interleave()`.""" 16 from __future__ import absolute_import 17 from __future__ import division 18 from __future__ import print_function 19 20 import itertools 21 import math 22 import threading 23 import time 24 25 import numpy as np 26 from six.moves import zip_longest 27 28 from tensorflow.python.data.experimental.ops import interleave_ops 29 from tensorflow.python.data.kernel_tests import test_base 30 from tensorflow.python.data.ops import dataset_ops 31 from tensorflow.python.framework import dtypes 32 from tensorflow.python.framework import errors 33 from tensorflow.python.framework import sparse_tensor 34 from tensorflow.python.framework import test_util 35 from tensorflow.python.ops import math_ops 36 from tensorflow.python.ops import script_ops 37 from tensorflow.python.ops import sparse_ops 38 from tensorflow.python.platform import test 39 40 41 @test_util.run_all_in_graph_and_eager_modes 42 class ParallelInterleaveTest(test_base.DatasetTestBase): 43 44 def setUp(self): 45 46 self.error = None 47 self.repeat_count = 2 48 49 # Set up threading events used to sequence when items are produced that 50 # are subsequently interleaved. These events allow us to deterministically 51 # simulate slowdowns and force sloppiness. 52 self.read_coordination_events = {} 53 self.write_coordination_events = {} 54 # input values [4, 5, 6] are the common case for the tests; set defaults 55 for i in range(4, 7): 56 self.read_coordination_events[i] = threading.Semaphore(0) 57 self.write_coordination_events[i] = threading.Event() 58 59 def dataset_fn(self, input_values, cycle_length, block_length, sloppy, 60 buffer_output_elements, prefetch_input_elements): 61 62 def map_py_fn(x): 63 self.write_coordination_events[x].wait() 64 self.write_coordination_events[x].clear() 65 self.read_coordination_events[x].release() 66 if self.error: 67 err = self.error 68 self.error = None 69 raise err # pylint: disable=raising-bad-type 70 return x * x 71 72 def map_fn(x): 73 return script_ops.py_func(map_py_fn, [x], x.dtype) 74 75 def interleave_fn(x): 76 dataset = dataset_ops.Dataset.from_tensors(x) 77 dataset = dataset.repeat(x) 78 return dataset.map(map_fn) 79 80 return dataset_ops.Dataset.from_tensor_slices(input_values).repeat( 81 self.repeat_count).apply( 82 interleave_ops.parallel_interleave( 83 interleave_fn, cycle_length, block_length, sloppy, 84 buffer_output_elements, prefetch_input_elements)) 85 86 def _interleave(self, lists, cycle_length, block_length): 87 """Python implementation of interleave used for testing.""" 88 num_open = 0 89 90 # `all_iterators` acts as a queue of iterators over each element of `lists`. 91 all_iterators = [iter(l) for l in lists] 92 93 # `open_iterators` are the iterators whose elements are currently being 94 # interleaved. 95 open_iterators = [] 96 for i in range(cycle_length): 97 if all_iterators: 98 open_iterators.append(all_iterators.pop(0)) 99 num_open += 1 100 else: 101 open_iterators.append(None) 102 103 while num_open or all_iterators: 104 for i in range(cycle_length): 105 if open_iterators[i] is None: 106 if all_iterators: 107 open_iterators[i] = all_iterators.pop(0) 108 num_open += 1 109 else: 110 continue 111 for _ in range(block_length): 112 try: 113 yield next(open_iterators[i]) 114 except StopIteration: 115 open_iterators[i] = None 116 num_open -= 1 117 break 118 119 def testPythonImplementation(self): 120 input_lists = [[4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6], 121 [4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]] 122 123 # Cycle length 1 acts like `Dataset.flat_map()`. 124 expected_elements = itertools.chain(*input_lists) 125 for expected, produced in zip(expected_elements, 126 self._interleave(input_lists, 1, 1)): 127 self.assertEqual(expected, produced) 128 129 # Cycle length > 1. 130 expected_elements = [ 131 4, 5, 4, 5, 4, 5, 4, 5, 5, 6, 6, 4, 6, 4, 6, 4, 6, 4, 6, 5, 6, 5, 6, 5, 132 6, 5, 6, 5, 6, 6 133 ] 134 for index, (expected, produced) in enumerate( 135 zip_longest(expected_elements, self._interleave(input_lists, 2, 1))): 136 self.assertEqual(expected, produced, "Values differ at %s. %s != %s" % 137 (index, expected, produced)) 138 139 def testPythonImplementationBlockLength(self): 140 input_lists = [[4] * 4, [5] * 5, [6] * 6] * 2 141 expected_elements = [ 142 4, 4, 5, 5, 4, 4, 5, 5, 5, 6, 6, 4, 4, 6, 6, 4, 4, 6, 6, 5, 5, 6, 6, 5, 143 5, 6, 6, 5, 6, 6 144 ] 145 for index, (expected, produced) in enumerate( 146 zip_longest(expected_elements, self._interleave(input_lists, 2, 2))): 147 self.assertEqual(expected, produced, "Values differ at %s. %s != %s" % 148 (index, expected, produced)) 149 150 def testPythonImplementationEmptyLists(self): 151 input_lists = [[4, 4, 4, 4], [], [6, 6, 6, 6, 6, 6], [4, 4, 4, 4], [], 152 [6, 6, 6, 6, 6, 6]] 153 154 expected_elements = [ 155 4, 4, 6, 4, 6, 4, 6, 6, 4, 6, 4, 6, 4, 4, 6, 6, 6, 6, 6, 6 156 ] 157 for index, (expected, produced) in enumerate( 158 zip_longest(expected_elements, self._interleave(input_lists, 2, 1))): 159 self.assertEqual(expected, produced, "Values differ at %s. %s != %s" % 160 (index, expected, produced)) 161 162 def _clear_coordination_events(self): 163 for i in range(4, 7): 164 self.read_coordination_events[i] = threading.Semaphore(0) 165 self.write_coordination_events[i].clear() 166 167 def _allow_all_map_threads(self): 168 for i in range(4, 7): 169 self.write_coordination_events[i].set() 170 171 def _testSingleThreaded(self, sloppy=False, prefetch_input_elements=0): 172 # cycle_length=1,block_length=1 acts like `Dataset.interleave()` and 173 # `Dataset.flat_map()` and is single-threaded. No synchronization required. 174 self._clear_coordination_events() 175 next_element = self.getNext( 176 self.dataset_fn( 177 input_values=np.int64([4, 5, 6]), 178 cycle_length=1, 179 block_length=1, 180 sloppy=sloppy, 181 buffer_output_elements=1, 182 prefetch_input_elements=prefetch_input_elements)) 183 for expected_element in self._interleave( 184 [[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 1, 1): 185 self.write_coordination_events[expected_element].set() 186 self.assertEqual(expected_element * expected_element, 187 self.evaluate(next_element())) 188 with self.assertRaises(errors.OutOfRangeError): 189 self.evaluate(next_element()) 190 191 def testSingleThreaded(self): 192 self._testSingleThreaded() 193 194 def testSingleThreadedSloppy(self): 195 self._testSingleThreaded(sloppy=True) 196 197 def testSingleThreadedPrefetch1Itr(self): 198 self._testSingleThreaded(prefetch_input_elements=1) 199 200 def testSingleThreadedPrefetch1ItrSloppy(self): 201 self._testSingleThreaded(prefetch_input_elements=1, sloppy=True) 202 203 def testSingleThreadedRagged(self): 204 # Tests a sequence with wildly different elements per iterator. 205 self._clear_coordination_events() 206 next_element = self.getNext( 207 self.dataset_fn( 208 input_values=np.int64([3, 7, 4]), 209 cycle_length=2, 210 block_length=1, 211 sloppy=False, 212 buffer_output_elements=1, 213 prefetch_input_elements=1)) 214 215 # Add coordination values for 3 and 7 216 self.read_coordination_events[3] = threading.Semaphore(0) 217 self.write_coordination_events[3] = threading.Event() 218 self.read_coordination_events[7] = threading.Semaphore(0) 219 self.write_coordination_events[7] = threading.Event() 220 221 for expected_element in self._interleave( 222 [[3] * 3, [7] * 7, [4] * 4] * self.repeat_count, 2, 1): 223 self.write_coordination_events[expected_element].set() 224 output = self.evaluate(next_element()) 225 self.assertEqual(expected_element * expected_element, output) 226 with self.assertRaises(errors.OutOfRangeError): 227 self.evaluate(next_element()) 228 229 def _testTwoThreadsNoContention(self, sloppy=False): 230 # num_threads > 1. 231 # Explicit coordination should result in `Dataset.interleave()` behavior 232 self._clear_coordination_events() 233 done_first_event = False 234 next_element = self.getNext( 235 self.dataset_fn( 236 input_values=np.int64([4, 5, 6]), 237 cycle_length=2, 238 block_length=1, 239 sloppy=sloppy, 240 buffer_output_elements=1, 241 prefetch_input_elements=1)) 242 for i, expected_element in enumerate( 243 self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2, 244 1)): 245 self.write_coordination_events[expected_element].set() 246 if done_first_event: # First event starts the worker threads. 247 self.read_coordination_events[expected_element].acquire() 248 actual_element = self.evaluate(next_element()) 249 if not done_first_event: 250 self.read_coordination_events[expected_element].acquire() 251 done_first_event = True 252 self.assertEqual( 253 expected_element * expected_element, actual_element, 254 "At index %s: %s expected, got: %s" % (i, expected_element, 255 actual_element)) 256 with self.assertRaises(errors.OutOfRangeError): 257 self.evaluate(next_element()) 258 259 def testTwoThreadsNoContention(self): 260 self._testTwoThreadsNoContention() 261 262 def testTwoThreadsNoContentionSloppy(self): 263 self._testTwoThreadsNoContention(sloppy=True) 264 265 def _testTwoThreadsNoContentionWithRaces(self, sloppy=False): 266 """Tests where all the workers race in producing elements. 267 268 Note: this is in contrast with the previous test which carefully sequences 269 the execution of the map functions. 270 271 Args: 272 sloppy: Whether to be sloppy or not. 273 """ 274 self._clear_coordination_events() 275 done_first_event = False 276 next_element = self.getNext( 277 self.dataset_fn( 278 input_values=np.int64([4, 5, 6]), 279 cycle_length=2, 280 block_length=1, 281 sloppy=sloppy, 282 buffer_output_elements=1, 283 prefetch_input_elements=1)) 284 for i, expected_element in enumerate( 285 self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2, 286 1)): 287 if done_first_event: # First event starts the worker threads. 288 self._allow_all_map_threads() 289 self.read_coordination_events[expected_element].acquire() 290 else: 291 self.write_coordination_events[expected_element].set() 292 time.sleep(0.5) # Sleep to consistently "avoid" the race condition. 293 actual_element = self.evaluate(next_element()) 294 if not done_first_event: 295 done_first_event = True 296 self.assertTrue( 297 self.read_coordination_events[expected_element].acquire(False)) 298 self.assertEqual( 299 expected_element * expected_element, actual_element, 300 "At index %s: %s expected, got: %s" % (i, expected_element, 301 actual_element)) 302 with self.assertRaises(errors.OutOfRangeError): 303 self.evaluate(next_element()) 304 305 def testTwoThreadsNoContentionWithRaces(self): 306 self._testTwoThreadsNoContentionWithRaces() 307 308 def testTwoThreadsNoContentionWithRacesSloppy(self): 309 self._testTwoThreadsNoContentionWithRaces(sloppy=True) 310 311 def _testTwoThreadsNoContentionBlockLength(self, sloppy=False): 312 # num_threads > 1. 313 # Explicit coordination should result in `Dataset.interleave()` behavior 314 self._clear_coordination_events() 315 done_first_event = False 316 next_element = self.getNext( 317 self.dataset_fn( 318 input_values=np.int64([4, 5, 6]), 319 cycle_length=2, 320 block_length=2, 321 sloppy=sloppy, 322 buffer_output_elements=1, 323 prefetch_input_elements=1)) 324 for i, expected_element in enumerate( 325 self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2, 326 2)): 327 self.write_coordination_events[expected_element].set() 328 if done_first_event: # First event starts the worker threads. 329 self.read_coordination_events[expected_element].acquire() 330 actual_element = self.evaluate(next_element()) 331 if not done_first_event: 332 done_first_event = True 333 self.read_coordination_events[expected_element].acquire() 334 self.assertEqual( 335 expected_element * expected_element, actual_element, 336 "At index %s: %s expected, got: %s" % (i, expected_element, 337 actual_element)) 338 with self.assertRaises(errors.OutOfRangeError): 339 self.evaluate(next_element()) 340 341 def testTwoThreadsNoContentionBlockLength(self): 342 self._testTwoThreadsNoContentionBlockLength() 343 344 def testTwoThreadsNoContentionBlockLengthSloppy(self): 345 self._testTwoThreadsNoContentionBlockLength(sloppy=True) 346 347 def _testTwoThreadsNoContentionWithRacesAndBlocking(self, sloppy=False): 348 """Tests where all the workers race in producing elements. 349 350 Note: this is in contrast with the previous test which carefully sequences 351 the execution of the map functions. 352 353 354 Args: 355 sloppy: Whether to be sloppy or not. 356 """ 357 self._clear_coordination_events() 358 done_first_event = False 359 next_element = self.getNext( 360 self.dataset_fn( 361 input_values=np.int64([4, 5, 6]), 362 cycle_length=2, 363 block_length=2, 364 sloppy=sloppy, 365 buffer_output_elements=1, 366 prefetch_input_elements=1)) 367 for i, expected_element in enumerate( 368 self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2, 369 2)): 370 if done_first_event: # First event starts the worker threads. 371 self._allow_all_map_threads() 372 self.read_coordination_events[expected_element].acquire() 373 else: 374 self.write_coordination_events[expected_element].set() 375 time.sleep(0.5) # Sleep to consistently "avoid" the race condition. 376 actual_element = self.evaluate(next_element()) 377 if not done_first_event: 378 done_first_event = True 379 self.assertTrue( 380 self.read_coordination_events[expected_element].acquire(False)) 381 self.assertEqual( 382 expected_element * expected_element, actual_element, 383 "At index %s: %s expected, got: %s" % (i, expected_element, 384 actual_element)) 385 with self.assertRaises(errors.OutOfRangeError): 386 self.evaluate(next_element()) 387 388 def testTwoThreadsNoContentionWithRacesAndBlocking(self): 389 self._testTwoThreadsNoContentionWithRacesAndBlocking() 390 391 def testTwoThreadsNoContentionWithRacesAndBlockingSloppy(self): 392 self._testTwoThreadsNoContentionWithRacesAndBlocking(sloppy=True) 393 394 def _testEmptyInput(self, sloppy=False): 395 # Empty input. 396 self._clear_coordination_events() 397 next_element = self.getNext( 398 self.dataset_fn( 399 input_values=np.int64([]), 400 cycle_length=2, 401 block_length=3, 402 sloppy=sloppy, 403 buffer_output_elements=1, 404 prefetch_input_elements=0)) 405 with self.assertRaises(errors.OutOfRangeError): 406 self.evaluate(next_element()) 407 408 def testEmptyInput(self): 409 self._testEmptyInput() 410 411 def testEmptyInputSloppy(self): 412 self._testEmptyInput(sloppy=True) 413 414 def _testNonEmptyInputIntoEmptyOutputs(self, sloppy=False): 415 # Non-empty input leading to empty output. 416 self._clear_coordination_events() 417 next_element = self.getNext( 418 self.dataset_fn( 419 input_values=np.int64([0, 0, 0]), 420 cycle_length=2, 421 block_length=3, 422 sloppy=sloppy, 423 buffer_output_elements=1, 424 prefetch_input_elements=0)) 425 with self.assertRaises(errors.OutOfRangeError): 426 self.evaluate(next_element()) 427 428 def testNonEmptyInputIntoEmptyOutputs(self): 429 self._testNonEmptyInputIntoEmptyOutputs() 430 431 def testNonEmptyInputIntoEmptyOutputsSloppy(self): 432 self._testNonEmptyInputIntoEmptyOutputs(sloppy=True) 433 434 def _testPartiallyEmptyOutputs(self, sloppy=False, prefetch_input_elements=1): 435 race_indices = {2, 8, 14} # Sequence points when sloppy mode has race conds 436 # Mixture of non-empty and empty interleaved datasets. 437 self._clear_coordination_events() 438 done_first_event = False 439 next_element = self.getNext( 440 self.dataset_fn( 441 input_values=np.int64([4, 0, 6]), 442 cycle_length=2, 443 block_length=1, 444 sloppy=sloppy, 445 buffer_output_elements=1, 446 prefetch_input_elements=prefetch_input_elements)) 447 for i, expected_element in enumerate( 448 self._interleave([[4] * 4, [], [6] * 6] * self.repeat_count, 2, 1)): 449 self.write_coordination_events[expected_element].set() 450 # First event starts the worker threads. Additionally, when running the 451 # sloppy case with prefetch_input_elements=0, we get stuck if we wait 452 # for the read coordination event for certain event orderings in the 453 # presence of finishing iterators. 454 if done_first_event and not (sloppy and (i in race_indices)): 455 self.read_coordination_events[expected_element].acquire() 456 actual_element = self.evaluate(next_element()) 457 if not done_first_event or (sloppy and (i in race_indices)): 458 done_first_event = True 459 self.read_coordination_events[expected_element].acquire() 460 self.assertEqual( 461 expected_element * expected_element, actual_element, 462 "At index %s: %s expected, got: %s" % (i, expected_element, 463 actual_element)) 464 465 def testPartiallyEmptyOutputs(self): 466 self._testPartiallyEmptyOutputs() 467 468 def testPartiallyEmptyOutputsSloppy(self): 469 self._testPartiallyEmptyOutputs(sloppy=True, prefetch_input_elements=0) 470 471 def testDelayedOutputSloppy(self): 472 # Explicitly control the sequence of events to ensure we correctly avoid 473 # head-of-line blocking. 474 self._clear_coordination_events() 475 next_element = self.getNext( 476 self.dataset_fn( 477 input_values=np.int64([4, 5, 6]), 478 cycle_length=2, 479 block_length=1, 480 sloppy=True, 481 buffer_output_elements=1, 482 prefetch_input_elements=0)) 483 484 mis_ordering = [ 485 4, 4, 5, 4, 5, 5, 4, 5, 6, 6, 6, 5, 4, 4, 6, 6, 4, 4, 6, 5, 6, 6, 6, 6, 486 5, 5, 5, 5, 6, 6 487 ] 488 for element in mis_ordering: 489 self.write_coordination_events[element].set() 490 self.assertEqual(element * element, self.evaluate(next_element())) 491 self.assertTrue(self.read_coordination_events[element].acquire(False)) 492 with self.assertRaises(errors.OutOfRangeError): 493 self.evaluate(next_element()) 494 495 def testBlockLengthWithContentionSloppy(self): 496 self._clear_coordination_events() 497 done_first_event = False 498 next_element = self.getNext( 499 self.dataset_fn( 500 input_values=np.int64([4, 5, 6]), 501 cycle_length=2, 502 block_length=1, 503 sloppy=True, 504 buffer_output_elements=1, 505 prefetch_input_elements=1)) 506 # Test against a generating sequence that differs from the uncontended 507 # case, in order to prove sloppy correctness. 508 for i, expected_element in enumerate( 509 self._interleave( 510 [[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 511 cycle_length=2, 512 block_length=3)): 513 self.write_coordination_events[expected_element].set() 514 if done_first_event: # First event starts the worker threads. 515 self.read_coordination_events[expected_element].acquire() 516 actual_element = self.evaluate(next_element()) 517 if not done_first_event: 518 self.read_coordination_events[expected_element].acquire() 519 done_first_event = True 520 self.assertEqual( 521 expected_element * expected_element, actual_element, 522 "At index %s: %s expected, got: %s" % (i, expected_element, 523 actual_element)) 524 with self.assertRaises(errors.OutOfRangeError): 525 self.evaluate(next_element()) 526 527 def _testEarlyExit(self, sloppy=False): 528 # Exiting without consuming all input should not block 529 self._clear_coordination_events() 530 next_element = self.getNext( 531 self.dataset_fn( 532 input_values=np.int64([4, 5, 6]), 533 cycle_length=3, 534 block_length=2, 535 sloppy=sloppy, 536 buffer_output_elements=1, 537 prefetch_input_elements=0)) 538 for i in range(4, 7): 539 self.write_coordination_events[i].set() 540 elem = self.evaluate(next_element()) # Start all workers 541 # Allow the one successful worker to progress beyond the py_func again. 542 elem = int(math.sqrt(elem)) 543 self.write_coordination_events[elem].set() 544 self.read_coordination_events[elem].acquire() 545 # Allow the prefetch to succeed 546 for i in range(4, 7): 547 self.read_coordination_events[i].acquire() 548 self.write_coordination_events[i].set() 549 550 def testEarlyExit(self): 551 self._testEarlyExit() 552 553 def testEarlyExitSloppy(self): 554 self._testEarlyExit(sloppy=True) 555 556 def _testTooManyReaders(self, sloppy=False): 557 558 def interleave_fn(x): 559 dataset = dataset_ops.Dataset.from_tensors(x) 560 dataset = dataset.repeat(math_ops.cast(x, dtype=dtypes.int64)) 561 return dataset 562 563 dataset = dataset_ops.Dataset.from_tensor_slices([4, 5, 6]) 564 dataset = dataset.repeat(self.repeat_count) 565 dataset = dataset.apply( 566 interleave_ops.parallel_interleave( 567 interleave_fn, cycle_length=16, block_length=2, sloppy=sloppy)) 568 get_next = self.getNext(dataset) 569 output_values = [] 570 for _ in range(30): 571 output_values.append(self.evaluate(get_next())) 572 573 expected_values = self._interleave( 574 [[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 1, 2) 575 self.assertItemsEqual(output_values, expected_values) 576 577 def testTooManyReaders(self): 578 self._testTooManyReaders() 579 580 def testTooManyReadersSloppy(self): 581 self._testTooManyReaders(sloppy=True) 582 583 def testSparse(self): 584 def _map_fn(i): 585 return sparse_tensor.SparseTensor( 586 indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) 587 588 def _interleave_fn(x): 589 return dataset_ops.Dataset.from_tensor_slices( 590 sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values)) 591 592 dataset = dataset_ops.Dataset.range(10).map(_map_fn).apply( 593 interleave_ops.parallel_interleave(_interleave_fn, cycle_length=1)) 594 get_next = self.getNext(dataset) 595 596 for i in range(10): 597 for j in range(2): 598 expected = [i, 0] if j % 2 == 0 else [0, -i] 599 self.assertAllEqual(expected, self.evaluate(get_next())) 600 with self.assertRaises(errors.OutOfRangeError): 601 self.evaluate(get_next()) 602 603 def testErrorsInOutputFn(self): 604 self._clear_coordination_events() 605 next_element = self.getNext( 606 self.dataset_fn( 607 input_values=np.int64([4, 5, 6]), 608 cycle_length=2, 609 block_length=1, 610 sloppy=False, 611 buffer_output_elements=1, 612 prefetch_input_elements=0)) 613 614 except_on_element_indices = set([3]) 615 616 for i, expected_element in enumerate( 617 self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2, 618 1)): 619 if i in except_on_element_indices: 620 self.error = ValueError() 621 self.write_coordination_events[expected_element].set() 622 with self.assertRaises(errors.InvalidArgumentError): 623 self.evaluate(next_element()) 624 else: 625 self.write_coordination_events[expected_element].set() 626 actual_element = self.evaluate(next_element()) 627 self.assertEqual( 628 expected_element * expected_element, actual_element, 629 "At index %s: %s expected, got: %s" % (i, expected_element, 630 actual_element)) 631 with self.assertRaises(errors.OutOfRangeError): 632 self.evaluate(next_element()) 633 634 def testErrorsInInputFn(self): 635 636 def map_py_fn(x): 637 if x == 5: 638 raise ValueError() 639 return x 640 641 def map_fn(x): 642 return script_ops.py_func(map_py_fn, [x], x.dtype) 643 644 def interleave_fn(x): 645 dataset = dataset_ops.Dataset.from_tensors(x) 646 dataset = dataset.repeat(x) 647 return dataset 648 649 def dataset_fn(input_values, cycle_length, block_length, sloppy, 650 buffer_output_elements, prefetch_input_elements): 651 return dataset_ops.Dataset.from_tensor_slices(input_values).map( 652 map_fn).repeat(self.repeat_count).apply( 653 interleave_ops.parallel_interleave( 654 interleave_fn, cycle_length, block_length, sloppy, 655 buffer_output_elements, prefetch_input_elements)) 656 657 next_element = self.getNext( 658 dataset_fn( 659 input_values=np.int64([4, 5, 6]), 660 cycle_length=2, 661 block_length=1, 662 sloppy=False, 663 buffer_output_elements=1, 664 prefetch_input_elements=0)) 665 for i, expected_element in enumerate( 666 self._interleave([[4] * 4, [5], [6] * 6] * self.repeat_count, 2, 1)): 667 if expected_element == 5: 668 with self.assertRaises(errors.InvalidArgumentError): 669 self.evaluate(next_element()) 670 else: 671 actual_element = self.evaluate(next_element()) 672 self.assertEqual( 673 expected_element, actual_element, 674 "At index %s: %s expected, got: %s" % (i, expected_element, 675 actual_element)) 676 with self.assertRaises(errors.OutOfRangeError): 677 self.evaluate(next_element()) 678 679 def testErrorsInInterleaveFn(self): 680 681 def map_py_fn(x): 682 if x == 5: 683 raise ValueError() 684 return x 685 686 def interleave_fn(x): 687 dataset = dataset_ops.Dataset.from_tensors(x) 688 y = script_ops.py_func(map_py_fn, [x], x.dtype) 689 dataset = dataset.repeat(y) 690 return dataset 691 692 def dataset_fn(input_values, cycle_length, block_length, sloppy, 693 buffer_output_elements, prefetch_input_elements): 694 return dataset_ops.Dataset.from_tensor_slices(input_values).repeat( 695 self.repeat_count).apply( 696 interleave_ops.parallel_interleave( 697 interleave_fn, cycle_length, block_length, sloppy, 698 buffer_output_elements, prefetch_input_elements)) 699 700 next_element = self.getNext( 701 dataset_fn( 702 input_values=np.int64([4, 5, 6]), 703 cycle_length=2, 704 block_length=1, 705 sloppy=False, 706 buffer_output_elements=1, 707 prefetch_input_elements=0)) 708 for i, expected_element in enumerate( 709 self._interleave([[4] * 4, [5], [6] * 6] * self.repeat_count, 2, 1)): 710 if expected_element == 5: 711 with self.assertRaises(errors.InvalidArgumentError): 712 self.evaluate(next_element()) 713 else: 714 actual_element = self.evaluate(next_element()) 715 self.assertEqual( 716 expected_element, actual_element, 717 "At index %s: %s expected, got: %s" % (i, expected_element, 718 actual_element)) 719 with self.assertRaises(errors.OutOfRangeError): 720 self.evaluate(next_element()) 721 722 def testShutdownRace(self): 723 dataset = dataset_ops.Dataset.range(20) 724 map_fn = lambda x: dataset_ops.Dataset.range(20 * x, 20 * (x + 1)) 725 dataset = dataset.apply( 726 interleave_ops.parallel_interleave( 727 map_fn, 728 cycle_length=3, 729 sloppy=False, 730 buffer_output_elements=1, 731 prefetch_input_elements=0)) 732 dataset = dataset.batch(32) 733 734 results = [] 735 for _ in range(2): 736 elements = [] 737 next_element = self.getNext(dataset) 738 try: 739 while True: 740 elements.extend(self.evaluate(next_element())) 741 except errors.OutOfRangeError: 742 pass 743 results.append(elements) 744 self.assertAllEqual(results[0], results[1]) 745 746 747 if __name__ == "__main__": 748 test.main() 749