# Scraped from a code-browser "annotate" view of
# tensorflow/python/data/experimental/kernel_tests/parallel_interleave_test.py
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for `tf.data.experimental.parallel_interleave()`."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import itertools
     21 import math
     22 import threading
     23 import time
     24 
     25 import numpy as np
     26 from six.moves import zip_longest
     27 
     28 from tensorflow.python.data.experimental.ops import interleave_ops
     29 from tensorflow.python.data.kernel_tests import test_base
     30 from tensorflow.python.data.ops import dataset_ops
     31 from tensorflow.python.framework import dtypes
     32 from tensorflow.python.framework import errors
     33 from tensorflow.python.framework import sparse_tensor
     34 from tensorflow.python.framework import test_util
     35 from tensorflow.python.ops import math_ops
     36 from tensorflow.python.ops import script_ops
     37 from tensorflow.python.ops import sparse_ops
     38 from tensorflow.python.platform import test
     39 
     40 
@test_util.run_all_in_graph_and_eager_modes
class ParallelInterleaveTest(test_base.DatasetTestBase):
  """Tests for `tf.data.experimental.parallel_interleave()`.

  The tests deterministically sequence the interleave worker threads using a
  per-input-value pair of synchronization primitives:

  * `write_coordination_events[x]` (a `threading.Event`): the worker's map
    function blocks on this before producing `x * x`; the test `set()`s it to
    allow exactly one element with value `x` to be produced.
  * `read_coordination_events[x]` (a `threading.Semaphore`): released by the
    map function once it has been unblocked; the test `acquire()`s it to wait
    until the worker has actually produced the element.

  This handshake lets the tests simulate slowdowns and force (or forbid)
  sloppy, out-of-order production.
  """

  def setUp(self):
    """Resets injected-error state and the per-value coordination primitives."""

    # Error to be raised from inside the map function; `None` means no error.
    self.error = None
    # Each input sequence is repeated this many times by `dataset_fn`.
    self.repeat_count = 2

    # Set up threading events used to sequence when items are produced that
    # are subsequently interleaved. These events allow us to deterministically
    # simulate slowdowns and force sloppiness.
    self.read_coordination_events = {}
    self.write_coordination_events = {}
    # input values [4, 5, 6] are the common case for the tests; set defaults
    for i in range(4, 7):
      self.read_coordination_events[i] = threading.Semaphore(0)
      self.write_coordination_events[i] = threading.Event()

  def dataset_fn(self, input_values, cycle_length, block_length, sloppy,
                 buffer_output_elements, prefetch_input_elements):
    """Builds the dataset under test.

    For each value `x` in `input_values` (the whole sequence repeated
    `self.repeat_count` times), the interleaved sub-dataset yields `x * x`
    exactly `x` times. Each produced element is gated on the coordination
    primitives described in the class docstring, and raises `self.error`
    (once) if a test has injected one.
    """

    def map_py_fn(x):
      # Block until the test permits this element, then signal the test that
      # the element has been produced.
      self.write_coordination_events[x].wait()
      self.write_coordination_events[x].clear()
      self.read_coordination_events[x].release()
      if self.error:
        # Raise the injected error exactly once.
        err = self.error
        self.error = None
        raise err  # pylint: disable=raising-bad-type
      return x * x

    def map_fn(x):
      return script_ops.py_func(map_py_fn, [x], x.dtype)

    def interleave_fn(x):
      # `x` repeated `x` times, each mapped through the coordinated py_func.
      dataset = dataset_ops.Dataset.from_tensors(x)
      dataset = dataset.repeat(x)
      return dataset.map(map_fn)

    return dataset_ops.Dataset.from_tensor_slices(input_values).repeat(
        self.repeat_count).apply(
            interleave_ops.parallel_interleave(
                interleave_fn, cycle_length, block_length, sloppy,
                buffer_output_elements, prefetch_input_elements))

  def _interleave(self, lists, cycle_length, block_length):
    """Python implementation of interleave used for testing."""
    num_open = 0

    # `all_iterators` acts as a queue of iterators over each element of `lists`.
    all_iterators = [iter(l) for l in lists]

    # `open_iterators` are the iterators whose elements are currently being
    # interleaved.
    open_iterators = []
    for i in range(cycle_length):
      if all_iterators:
        open_iterators.append(all_iterators.pop(0))
        num_open += 1
      else:
        # Pad empty cycle slots with `None` so indexing below stays valid.
        open_iterators.append(None)

    while num_open or all_iterators:
      for i in range(cycle_length):
        if open_iterators[i] is None:
          if all_iterators:
            # Refill an exhausted cycle slot from the queue.
            open_iterators[i] = all_iterators.pop(0)
            num_open += 1
          else:
            continue
        # Emit up to `block_length` consecutive elements from this iterator.
        for _ in range(block_length):
          try:
            yield next(open_iterators[i])
          except StopIteration:
            open_iterators[i] = None
            num_open -= 1
            break

  def testPythonImplementation(self):
    """Sanity-checks `_interleave` against hand-computed expectations."""
    input_lists = [[4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6],
                   [4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]]

    # Cycle length 1 acts like `Dataset.flat_map()`.
    expected_elements = itertools.chain(*input_lists)
    for expected, produced in zip(expected_elements,
                                  self._interleave(input_lists, 1, 1)):
      self.assertEqual(expected, produced)

    # Cycle length > 1.
    expected_elements = [
        4, 5, 4, 5, 4, 5, 4, 5, 5, 6, 6, 4, 6, 4, 6, 4, 6, 4, 6, 5, 6, 5, 6, 5,
        6, 5, 6, 6
    ]
    for index, (expected, produced) in enumerate(
        zip_longest(expected_elements, self._interleave(input_lists, 2, 1))):
      self.assertEqual(expected, produced, "Values differ at %s. %s != %s" %
                       (index, expected, produced))

  def testPythonImplementationBlockLength(self):
    """Checks `_interleave` with block_length > 1."""
    input_lists = [[4] * 4, [5] * 5, [6] * 6] * 2
    expected_elements = [
        4, 4, 5, 5, 4, 4, 5, 5, 5, 6, 6, 4, 4, 6, 6, 4, 4, 6, 6, 5, 5, 6, 6, 5,
        5, 6, 6, 5, 6, 6
    ]
    for index, (expected, produced) in enumerate(
        zip_longest(expected_elements, self._interleave(input_lists, 2, 2))):
      self.assertEqual(expected, produced, "Values differ at %s. %s != %s" %
                       (index, expected, produced))

  def testPythonImplementationEmptyLists(self):
    """Checks that `_interleave` skips over empty input lists."""
    input_lists = [[4, 4, 4, 4], [], [6, 6, 6, 6, 6, 6], [4, 4, 4, 4], [],
                   [6, 6, 6, 6, 6, 6]]

    expected_elements = [
        4, 4, 6, 4, 6, 4, 6, 6, 4, 6, 4, 6, 4, 4, 6, 6, 6, 6, 6, 6
    ]
    for index, (expected, produced) in enumerate(
        zip_longest(expected_elements, self._interleave(input_lists, 2, 1))):
      self.assertEqual(expected, produced, "Values differ at %s. %s != %s" %
                       (index, expected, produced))

  def _clear_coordination_events(self):
    # Reset primitives for the default input values [4, 5, 6].
    for i in range(4, 7):
      self.read_coordination_events[i] = threading.Semaphore(0)
      self.write_coordination_events[i].clear()

  def _allow_all_map_threads(self):
    # Unblock every map-function worker for the default input values.
    for i in range(4, 7):
      self.write_coordination_events[i].set()

  def _testSingleThreaded(self, sloppy=False, prefetch_input_elements=0):
    # cycle_length=1,block_length=1 acts like `Dataset.interleave()` and
    # `Dataset.flat_map()` and is single-threaded. No synchronization required.
    self._clear_coordination_events()
    next_element = self.getNext(
        self.dataset_fn(
            input_values=np.int64([4, 5, 6]),
            cycle_length=1,
            block_length=1,
            sloppy=sloppy,
            buffer_output_elements=1,
            prefetch_input_elements=prefetch_input_elements))
    for expected_element in self._interleave(
        [[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 1, 1):
      self.write_coordination_events[expected_element].set()
      self.assertEqual(expected_element * expected_element,
                       self.evaluate(next_element()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element())

  def testSingleThreaded(self):
    self._testSingleThreaded()

  def testSingleThreadedSloppy(self):
    self._testSingleThreaded(sloppy=True)

  def testSingleThreadedPrefetch1Itr(self):
    self._testSingleThreaded(prefetch_input_elements=1)

  def testSingleThreadedPrefetch1ItrSloppy(self):
    self._testSingleThreaded(prefetch_input_elements=1, sloppy=True)

  def testSingleThreadedRagged(self):
    # Tests a sequence with wildly different elements per iterator.
    self._clear_coordination_events()
    next_element = self.getNext(
        self.dataset_fn(
            input_values=np.int64([3, 7, 4]),
            cycle_length=2,
            block_length=1,
            sloppy=False,
            buffer_output_elements=1,
            prefetch_input_elements=1))

    # Add coordination values for 3 and 7
    self.read_coordination_events[3] = threading.Semaphore(0)
    self.write_coordination_events[3] = threading.Event()
    self.read_coordination_events[7] = threading.Semaphore(0)
    self.write_coordination_events[7] = threading.Event()

    for expected_element in self._interleave(
        [[3] * 3, [7] * 7, [4] * 4] * self.repeat_count, 2, 1):
      self.write_coordination_events[expected_element].set()
      output = self.evaluate(next_element())
      self.assertEqual(expected_element * expected_element, output)
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element())

  def _testTwoThreadsNoContention(self, sloppy=False):
    # num_threads > 1.
    # Explicit coordination should result in `Dataset.interleave()` behavior
    self._clear_coordination_events()
    done_first_event = False
    next_element = self.getNext(
        self.dataset_fn(
            input_values=np.int64([4, 5, 6]),
            cycle_length=2,
            block_length=1,
            sloppy=sloppy,
            buffer_output_elements=1,
            prefetch_input_elements=1))
    for i, expected_element in enumerate(
        self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
                         1)):
      self.write_coordination_events[expected_element].set()
      if done_first_event:  # First event starts the worker threads.
        self.read_coordination_events[expected_element].acquire()
      actual_element = self.evaluate(next_element())
      if not done_first_event:
        self.read_coordination_events[expected_element].acquire()
        done_first_event = True
      self.assertEqual(
          expected_element * expected_element, actual_element,
          "At index %s: %s expected, got: %s" % (i, expected_element,
                                                 actual_element))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element())

  def testTwoThreadsNoContention(self):
    self._testTwoThreadsNoContention()

  def testTwoThreadsNoContentionSloppy(self):
    self._testTwoThreadsNoContention(sloppy=True)

  def _testTwoThreadsNoContentionWithRaces(self, sloppy=False):
    """Tests where all the workers race in producing elements.

    Note: this is in contrast with the previous test which carefully sequences
    the execution of the map functions.

    Args:
      sloppy: Whether to be sloppy or not.
    """
    self._clear_coordination_events()
    done_first_event = False
    next_element = self.getNext(
        self.dataset_fn(
            input_values=np.int64([4, 5, 6]),
            cycle_length=2,
            block_length=1,
            sloppy=sloppy,
            buffer_output_elements=1,
            prefetch_input_elements=1))
    for i, expected_element in enumerate(
        self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
                         1)):
      if done_first_event:  # First event starts the worker threads.
        self._allow_all_map_threads()
        self.read_coordination_events[expected_element].acquire()
      else:
        self.write_coordination_events[expected_element].set()
      time.sleep(0.5)  # Sleep to consistently "avoid" the race condition.
      actual_element = self.evaluate(next_element())
      if not done_first_event:
        done_first_event = True
        # Non-blocking acquire: the element must already have been produced.
        self.assertTrue(
            self.read_coordination_events[expected_element].acquire(False))
      self.assertEqual(
          expected_element * expected_element, actual_element,
          "At index %s: %s expected, got: %s" % (i, expected_element,
                                                 actual_element))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element())

  def testTwoThreadsNoContentionWithRaces(self):
    self._testTwoThreadsNoContentionWithRaces()

  def testTwoThreadsNoContentionWithRacesSloppy(self):
    self._testTwoThreadsNoContentionWithRaces(sloppy=True)

  def _testTwoThreadsNoContentionBlockLength(self, sloppy=False):
    # num_threads > 1.
    # Explicit coordination should result in `Dataset.interleave()` behavior
    self._clear_coordination_events()
    done_first_event = False
    next_element = self.getNext(
        self.dataset_fn(
            input_values=np.int64([4, 5, 6]),
            cycle_length=2,
            block_length=2,
            sloppy=sloppy,
            buffer_output_elements=1,
            prefetch_input_elements=1))
    for i, expected_element in enumerate(
        self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
                         2)):
      self.write_coordination_events[expected_element].set()
      if done_first_event:  # First event starts the worker threads.
        self.read_coordination_events[expected_element].acquire()
      actual_element = self.evaluate(next_element())
      if not done_first_event:
        done_first_event = True
        self.read_coordination_events[expected_element].acquire()
      self.assertEqual(
          expected_element * expected_element, actual_element,
          "At index %s: %s expected, got: %s" % (i, expected_element,
                                                 actual_element))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element())

  def testTwoThreadsNoContentionBlockLength(self):
    self._testTwoThreadsNoContentionBlockLength()

  def testTwoThreadsNoContentionBlockLengthSloppy(self):
    self._testTwoThreadsNoContentionBlockLength(sloppy=True)

  def _testTwoThreadsNoContentionWithRacesAndBlocking(self, sloppy=False):
    """Tests where all the workers race in producing elements.

    Note: this is in contrast with the previous test which carefully sequences
    the execution of the map functions.


    Args:
      sloppy: Whether to be sloppy or not.
    """
    self._clear_coordination_events()
    done_first_event = False
    next_element = self.getNext(
        self.dataset_fn(
            input_values=np.int64([4, 5, 6]),
            cycle_length=2,
            block_length=2,
            sloppy=sloppy,
            buffer_output_elements=1,
            prefetch_input_elements=1))
    for i, expected_element in enumerate(
        self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
                         2)):
      if done_first_event:  # First event starts the worker threads.
        self._allow_all_map_threads()
        self.read_coordination_events[expected_element].acquire()
      else:
        self.write_coordination_events[expected_element].set()
      time.sleep(0.5)  # Sleep to consistently "avoid" the race condition.
      actual_element = self.evaluate(next_element())
      if not done_first_event:
        done_first_event = True
        # Non-blocking acquire: the element must already have been produced.
        self.assertTrue(
            self.read_coordination_events[expected_element].acquire(False))
      self.assertEqual(
          expected_element * expected_element, actual_element,
          "At index %s: %s expected, got: %s" % (i, expected_element,
                                                 actual_element))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element())

  def testTwoThreadsNoContentionWithRacesAndBlocking(self):
    self._testTwoThreadsNoContentionWithRacesAndBlocking()

  def testTwoThreadsNoContentionWithRacesAndBlockingSloppy(self):
    self._testTwoThreadsNoContentionWithRacesAndBlocking(sloppy=True)

  def _testEmptyInput(self, sloppy=False):
    # Empty input.
    self._clear_coordination_events()
    next_element = self.getNext(
        self.dataset_fn(
            input_values=np.int64([]),
            cycle_length=2,
            block_length=3,
            sloppy=sloppy,
            buffer_output_elements=1,
            prefetch_input_elements=0))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element())

  def testEmptyInput(self):
    self._testEmptyInput()

  def testEmptyInputSloppy(self):
    self._testEmptyInput(sloppy=True)

  def _testNonEmptyInputIntoEmptyOutputs(self, sloppy=False):
    # Non-empty input leading to empty output.
    self._clear_coordination_events()
    next_element = self.getNext(
        self.dataset_fn(
            input_values=np.int64([0, 0, 0]),
            cycle_length=2,
            block_length=3,
            sloppy=sloppy,
            buffer_output_elements=1,
            prefetch_input_elements=0))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element())

  def testNonEmptyInputIntoEmptyOutputs(self):
    self._testNonEmptyInputIntoEmptyOutputs()

  def testNonEmptyInputIntoEmptyOutputsSloppy(self):
    self._testNonEmptyInputIntoEmptyOutputs(sloppy=True)

  def _testPartiallyEmptyOutputs(self, sloppy=False, prefetch_input_elements=1):
    race_indices = {2, 8, 14}  # Sequence points when sloppy mode has race conds
    # Mixture of non-empty and empty interleaved datasets.
    self._clear_coordination_events()
    done_first_event = False
    next_element = self.getNext(
        self.dataset_fn(
            input_values=np.int64([4, 0, 6]),
            cycle_length=2,
            block_length=1,
            sloppy=sloppy,
            buffer_output_elements=1,
            prefetch_input_elements=prefetch_input_elements))
    for i, expected_element in enumerate(
        self._interleave([[4] * 4, [], [6] * 6] * self.repeat_count, 2, 1)):
      self.write_coordination_events[expected_element].set()
      # First event starts the worker threads. Additionally, when running the
      # sloppy case with prefetch_input_elements=0, we get stuck if we wait
      # for the read coordination event for certain event orderings in the
      # presence of finishing iterators.
      if done_first_event and not (sloppy and (i in race_indices)):
        self.read_coordination_events[expected_element].acquire()
      actual_element = self.evaluate(next_element())
      if not done_first_event or (sloppy and (i in race_indices)):
        done_first_event = True
        self.read_coordination_events[expected_element].acquire()
      self.assertEqual(
          expected_element * expected_element, actual_element,
          "At index %s: %s expected, got: %s" % (i, expected_element,
                                                 actual_element))

  def testPartiallyEmptyOutputs(self):
    self._testPartiallyEmptyOutputs()

  def testPartiallyEmptyOutputsSloppy(self):
    self._testPartiallyEmptyOutputs(sloppy=True, prefetch_input_elements=0)

  def testDelayedOutputSloppy(self):
    # Explicitly control the sequence of events to ensure we correctly avoid
    # head-of-line blocking.
    self._clear_coordination_events()
    next_element = self.getNext(
        self.dataset_fn(
            input_values=np.int64([4, 5, 6]),
            cycle_length=2,
            block_length=1,
            sloppy=True,
            buffer_output_elements=1,
            prefetch_input_elements=0))

    # An ordering that differs from strict round-robin interleaving; sloppy
    # mode should deliver elements in exactly the order they are released.
    mis_ordering = [
        4, 4, 5, 4, 5, 5, 4, 5, 6, 6, 6, 5, 4, 4, 6, 6, 4, 4, 6, 5, 6, 6, 6, 6,
        5, 5, 5, 5, 6, 6
    ]
    for element in mis_ordering:
      self.write_coordination_events[element].set()
      self.assertEqual(element * element, self.evaluate(next_element()))
      self.assertTrue(self.read_coordination_events[element].acquire(False))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element())

  def testBlockLengthWithContentionSloppy(self):
    self._clear_coordination_events()
    done_first_event = False
    next_element = self.getNext(
        self.dataset_fn(
            input_values=np.int64([4, 5, 6]),
            cycle_length=2,
            block_length=1,
            sloppy=True,
            buffer_output_elements=1,
            prefetch_input_elements=1))
    # Test against a generating sequence that differs from the uncontended
    # case, in order to prove sloppy correctness.
    for i, expected_element in enumerate(
        self._interleave(
            [[4] * 4, [5] * 5, [6] * 6] * self.repeat_count,
            cycle_length=2,
            block_length=3)):
      self.write_coordination_events[expected_element].set()
      if done_first_event:  # First event starts the worker threads.
        self.read_coordination_events[expected_element].acquire()
      actual_element = self.evaluate(next_element())
      if not done_first_event:
        self.read_coordination_events[expected_element].acquire()
        done_first_event = True
      self.assertEqual(
          expected_element * expected_element, actual_element,
          "At index %s: %s expected, got: %s" % (i, expected_element,
                                                 actual_element))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element())

  def _testEarlyExit(self, sloppy=False):
    # Exiting without consuming all input should not block
    self._clear_coordination_events()
    next_element = self.getNext(
        self.dataset_fn(
            input_values=np.int64([4, 5, 6]),
            cycle_length=3,
            block_length=2,
            sloppy=sloppy,
            buffer_output_elements=1,
            prefetch_input_elements=0))
    for i in range(4, 7):
      self.write_coordination_events[i].set()
    elem = self.evaluate(next_element())  # Start all workers
    # Allow the one successful worker to progress beyond the py_func again.
    elem = int(math.sqrt(elem))
    self.write_coordination_events[elem].set()
    self.read_coordination_events[elem].acquire()
    # Allow the prefetch to succeed
    for i in range(4, 7):
      self.read_coordination_events[i].acquire()
      self.write_coordination_events[i].set()

  def testEarlyExit(self):
    self._testEarlyExit()

  def testEarlyExitSloppy(self):
    self._testEarlyExit(sloppy=True)

  def _testTooManyReaders(self, sloppy=False):
    """Tests cycle_length larger than the number of input elements."""

    def interleave_fn(x):
      dataset = dataset_ops.Dataset.from_tensors(x)
      dataset = dataset.repeat(math_ops.cast(x, dtype=dtypes.int64))
      return dataset

    dataset = dataset_ops.Dataset.from_tensor_slices([4, 5, 6])
    dataset = dataset.repeat(self.repeat_count)
    dataset = dataset.apply(
        interleave_ops.parallel_interleave(
            interleave_fn, cycle_length=16, block_length=2, sloppy=sloppy))
    get_next = self.getNext(dataset)
    output_values = []
    for _ in range(30):
      output_values.append(self.evaluate(get_next()))

    # Order is not asserted, only the multiset of produced values.
    expected_values = self._interleave(
        [[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 1, 2)
    self.assertItemsEqual(output_values, expected_values)

  def testTooManyReaders(self):
    self._testTooManyReaders()

  def testTooManyReadersSloppy(self):
    self._testTooManyReaders(sloppy=True)

  def testSparse(self):
    """Tests interleaving over datasets built from sparse tensors."""
    def _map_fn(i):
      return sparse_tensor.SparseTensor(
          indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2])

    def _interleave_fn(x):
      return dataset_ops.Dataset.from_tensor_slices(
          sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values))

    dataset = dataset_ops.Dataset.range(10).map(_map_fn).apply(
        interleave_ops.parallel_interleave(_interleave_fn, cycle_length=1))
    get_next = self.getNext(dataset)

    for i in range(10):
      for j in range(2):
        # Rows of the densified 2x2 sparse tensor: [i, 0] then [0, -i].
        expected = [i, 0] if j % 2 == 0 else [0, -i]
        self.assertAllEqual(expected, self.evaluate(get_next()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

  def testErrorsInOutputFn(self):
    """Tests that an error injected into the map function surfaces once."""
    self._clear_coordination_events()
    next_element = self.getNext(
        self.dataset_fn(
            input_values=np.int64([4, 5, 6]),
            cycle_length=2,
            block_length=1,
            sloppy=False,
            buffer_output_elements=1,
            prefetch_input_elements=0))

    except_on_element_indices = set([3])

    for i, expected_element in enumerate(
        self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
                         1)):
      if i in except_on_element_indices:
        # Inject the error; the py_func raises it and it surfaces as
        # InvalidArgumentError.
        self.error = ValueError()
        self.write_coordination_events[expected_element].set()
        with self.assertRaises(errors.InvalidArgumentError):
          self.evaluate(next_element())
      else:
        self.write_coordination_events[expected_element].set()
        actual_element = self.evaluate(next_element())
        self.assertEqual(
            expected_element * expected_element, actual_element,
            "At index %s: %s expected, got: %s" % (i, expected_element,
                                                   actual_element))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element())

  def testErrorsInInputFn(self):
    """Tests that an error raised while producing input values surfaces."""

    def map_py_fn(x):
      if x == 5:
        raise ValueError()
      return x

    def map_fn(x):
      return script_ops.py_func(map_py_fn, [x], x.dtype)

    def interleave_fn(x):
      dataset = dataset_ops.Dataset.from_tensors(x)
      dataset = dataset.repeat(x)
      return dataset

    def dataset_fn(input_values, cycle_length, block_length, sloppy,
                   buffer_output_elements, prefetch_input_elements):
      # Note: the map (which raises) runs on the *input* values, before
      # the parallel interleave.
      return dataset_ops.Dataset.from_tensor_slices(input_values).map(
          map_fn).repeat(self.repeat_count).apply(
              interleave_ops.parallel_interleave(
                  interleave_fn, cycle_length, block_length, sloppy,
                  buffer_output_elements, prefetch_input_elements))

    next_element = self.getNext(
        dataset_fn(
            input_values=np.int64([4, 5, 6]),
            cycle_length=2,
            block_length=1,
            sloppy=False,
            buffer_output_elements=1,
            prefetch_input_elements=0))
    for i, expected_element in enumerate(
        self._interleave([[4] * 4, [5], [6] * 6] * self.repeat_count, 2, 1)):
      if expected_element == 5:
        with self.assertRaises(errors.InvalidArgumentError):
          self.evaluate(next_element())
      else:
        actual_element = self.evaluate(next_element())
        self.assertEqual(
            expected_element, actual_element,
            "At index %s: %s expected, got: %s" % (i, expected_element,
                                                   actual_element))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element())

  def testErrorsInInterleaveFn(self):
    """Tests that an error raised inside the interleave function surfaces."""

    def map_py_fn(x):
      if x == 5:
        raise ValueError()
      return x

    def interleave_fn(x):
      dataset = dataset_ops.Dataset.from_tensors(x)
      # The py_func raises while computing the repeat count.
      y = script_ops.py_func(map_py_fn, [x], x.dtype)
      dataset = dataset.repeat(y)
      return dataset

    def dataset_fn(input_values, cycle_length, block_length, sloppy,
                   buffer_output_elements, prefetch_input_elements):
      return dataset_ops.Dataset.from_tensor_slices(input_values).repeat(
          self.repeat_count).apply(
              interleave_ops.parallel_interleave(
                  interleave_fn, cycle_length, block_length, sloppy,
                  buffer_output_elements, prefetch_input_elements))

    next_element = self.getNext(
        dataset_fn(
            input_values=np.int64([4, 5, 6]),
            cycle_length=2,
            block_length=1,
            sloppy=False,
            buffer_output_elements=1,
            prefetch_input_elements=0))
    for i, expected_element in enumerate(
        self._interleave([[4] * 4, [5], [6] * 6] * self.repeat_count, 2, 1)):
      if expected_element == 5:
        with self.assertRaises(errors.InvalidArgumentError):
          self.evaluate(next_element())
      else:
        actual_element = self.evaluate(next_element())
        self.assertEqual(
            expected_element, actual_element,
            "At index %s: %s expected, got: %s" % (i, expected_element,
                                                   actual_element))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element())

  def testShutdownRace(self):
    """Tests that repeated full iterations produce identical results."""
    dataset = dataset_ops.Dataset.range(20)
    map_fn = lambda x: dataset_ops.Dataset.range(20 * x, 20 * (x + 1))
    dataset = dataset.apply(
        interleave_ops.parallel_interleave(
            map_fn,
            cycle_length=3,
            sloppy=False,
            buffer_output_elements=1,
            prefetch_input_elements=0))
    dataset = dataset.batch(32)

    results = []
    for _ in range(2):
      elements = []
      next_element = self.getNext(dataset)
      try:
        while True:
          elements.extend(self.evaluate(next_element()))
      except errors.OutOfRangeError:
        pass
      results.append(elements)
    self.assertAllEqual(results[0], results[1])
    745 
    746 
# Run the test suite when this file is executed directly.
if __name__ == "__main__":
  test.main()
    749