# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
     15 """Tests for the experimental input pipeline ops."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 from collections import namedtuple
     21 import threading
     22 import time
     23 
     24 import numpy as np
     25 
     26 from tensorflow.python.client import session
     27 from tensorflow.python.data.ops import dataset_ops
     28 from tensorflow.python.framework import constant_op
     29 from tensorflow.python.framework import dtypes
     30 from tensorflow.python.framework import errors
     31 from tensorflow.python.framework import ops
     32 from tensorflow.python.framework import sparse_tensor
     33 from tensorflow.python.ops import array_ops
     34 from tensorflow.python.ops import data_flow_ops
     35 from tensorflow.python.ops import functional_ops
     36 from tensorflow.python.ops import lookup_ops
     37 from tensorflow.python.ops import math_ops
     38 from tensorflow.python.ops import random_ops
     39 from tensorflow.python.ops import script_ops
     40 from tensorflow.python.ops import sparse_ops
     41 from tensorflow.python.ops import string_ops
     42 from tensorflow.python.ops import variable_scope
     43 from tensorflow.python.platform import test
     44 
     45 
class MapDatasetTest(test.TestCase):

  def _buildMapDataset(self, components, count):
    def _map_fn(x, y, z):
      return math_ops.square(x), math_ops.square(y), math_ops.square(z)
    return (dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn)
            .repeat(count))

  def testMapDataset(self):
    """Test a dataset that maps a TF function across its input elements."""
    # The pipeline is TensorSliceDataset -> MapDataset(square_3) ->
    # RepeatDataset(count).
    components = (np.arange(7),
                  np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
                  np.array(37.0) * np.arange(7))
    count = array_ops.placeholder(dtypes.int64, shape=[])

    dataset = self._buildMapDataset(components, count)
    iterator = dataset.make_initializable_iterator()
    init_op = iterator.initializer
    get_next = iterator.get_next()

    self.assertEqual([c.shape[1:] for c in components],
                     [t.shape for t in get_next])

    with self.test_session() as sess:
      # Test single-threaded access to the iterator.
      sess.run(init_op, feed_dict={count: 14})
      for _ in range(14):
        for i in range(7):
          result = sess.run(get_next)
          for component, result_component in zip(components, result):
            self.assertAllEqual(component[i]**2, result_component)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      # Test multi-threaded access to the same iterator.
      sess.run(init_op, feed_dict={count: 18})
      results = []
      def iterator_thread():
        while True:
          try:
            results.append(sess.run(get_next))
          except errors.OutOfRangeError:
            return
      threads = [self.checkedThread(target=iterator_thread) for _ in range(8)]
      for t in threads:
        t.start()
      for t in threads:
        t.join()

      # `results` will contain the elements of components**2, each repeated
      # 18 times, but in a non-deterministic order. Sort the results (which
      # groups the 18 copies of each element together), and assert that each
      # element of components**2 is produced 18 times.
      results.sort(key=lambda x: x[0])
      for i in range(7):
        for j in range(18):
          for component, result_component in zip(components,
                                                 results[i * 18 + j]):
            self.assertAllEqual(component[i]**2, result_component)

  def _buildParallelMapDataset(self, components, count, num_parallel_calls,
                               output_buffer_size):
    def _map_fn(x, y, z):
      return math_ops.square(x), math_ops.square(y), math_ops.square(z)
    return (dataset_ops.Dataset.from_tensor_slices(components)
            .map(_map_fn, num_parallel_calls=num_parallel_calls)
            .prefetch(output_buffer_size)
            .repeat(count))

  def testParallelMapDataset(self):
    """Test a dataset that maps a TF function across its input elements."""
    # The pipeline is TensorSliceDataset -> ParallelMapDataset(square_3) ->
    # RepeatDataset(count).
    components = (np.arange(7),
                  np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
                  np.array(37.0) * np.arange(7))
    count = array_ops.placeholder(dtypes.int64, shape=[])
    num_parallel_calls = array_ops.placeholder(dtypes.int32, shape=[])
    output_buffer_size = array_ops.placeholder(dtypes.int64, shape=[])

    dataset = self._buildParallelMapDataset(
        components, count, num_parallel_calls, output_buffer_size)
    iterator = dataset.make_initializable_iterator()
    init_op = iterator.initializer
    get_next = iterator.get_next()

    self.assertEqual([c.shape[1:] for c in components],
                     [t.shape for t in get_next])

    with self.test_session() as sess:
      def do_test(num_parallel_calls_val, output_buffer_size_val):
        # Test single-threaded access to the iterator.
        sess.run(init_op, feed_dict={
            count: 14,
            num_parallel_calls: num_parallel_calls_val,
            output_buffer_size: output_buffer_size_val})
        for _ in range(14):
          for i in range(7):
            result = sess.run(get_next)
            for component, result_component in zip(components, result):
              self.assertAllEqual(component[i]**2, result_component)
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(get_next)

        # Test multi-threaded access to the same iterator.
        sess.run(init_op, feed_dict={
            count: 18,
            num_parallel_calls: num_parallel_calls_val,
            output_buffer_size: output_buffer_size_val})
        results = []
        def iterator_thread():
          while True:
            try:
              results.append(sess.run(get_next))
            except errors.OutOfRangeError:
              return
        threads = [self.checkedThread(target=iterator_thread)
                   for _ in range(64)]
        for t in threads:
          t.start()
        for t in threads:
          t.join()

        # `results` will contain the elements of components**2, each repeated
        # 18 times, but in a non-deterministic order. Sort the results (which
        # groups the 18 copies of each element together), and assert that each
        # element of components**2 is produced 18 times.
        results.sort(key=lambda x: x[0])
        for i in range(7):
          for j in range(18):
            for component, result_component in zip(components,
                                                   results[i * 18 + j]):
              self.assertAllEqual(component[i]**2, result_component)

      for num_parallel_calls_val, output_buffer_size_val in [
          (1, 1), (1, 2), (2, 2), (2, 4), (8, 8), (8, 16)]:
        do_test(num_parallel_calls_val, output_buffer_size_val)

  def testImplicitDisposeParallelMapDataset(self):
    # Tests whether a parallel map dataset will be cleaned up correctly when
    # the pipeline does not run it until exhaustion.
    # The pipeline is TensorSliceDataset -> MapDataset(square_3) ->
    # RepeatDataset(1000).
    components = (np.arange(1000),
                  np.array([[1, 2, 3]]) * np.arange(1000)[:, np.newaxis],
                  np.array(37.0) * np.arange(1000))

    dataset = self._buildParallelMapDataset(components, 1000, 100, 100)
    # NOTE(mrry): Also test that the prefetching thread is cancelled correctly.
    dataset = dataset.prefetch(100)
    iterator = dataset.make_initializable_iterator()
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      for _ in range(3):
        sess.run(get_next)

  def testParallelMapUnspecifiedOutputSize(self):
    components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)

    dataset = (dataset_ops.Dataset.from_tensor_slices(components)
               .map(lambda x: array_ops.check_numerics(x, "message"),
                    num_parallel_calls=2))
    iterator = dataset.make_initializable_iterator()
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      for _ in range(3):
        sess.run(get_next)

  def testParallelMapError(self):
    components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)

    dataset = (dataset_ops.Dataset.from_tensor_slices(components)
               .map(lambda x: array_ops.check_numerics(x, "message"),
                    num_parallel_calls=2))
    iterator = dataset.make_initializable_iterator()
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      for _ in range(3):
        sess.run(get_next)
      # The 4th element is NaN, so `array_ops.check_numerics()` should fail.
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(get_next)
      sess.run(get_next)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testPrefetchError(self):
    components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)

    dataset = (dataset_ops.Dataset.from_tensor_slices(components)
               .map(lambda x: array_ops.check_numerics(x, "message"))
               .prefetch(2))
    iterator = dataset.make_initializable_iterator()
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      for _ in range(3):
        sess.run(get_next)
      # The 4th element is NaN, so `array_ops.check_numerics()` should fail.
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(get_next)
      sess.run(get_next)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testCaptureHashTable(self):
    # NOTE(mrry): We must use the V2 variants of `HashTable`
    # etc. because these produce a `tf.resource`-typed output that is
    # compatible with the in-graph function implementation.
    default_val = -1
    keys = constant_op.constant(["brain", "salad", "surgery"])
    values = constant_op.constant([0, 1, 2], dtypes.int64)
    table = lookup_ops.HashTable(
        lookup_ops.KeyValueTensorInitializer(keys, values), default_val)

    input_sentences = dataset_ops.Dataset.from_tensor_slices(
        ["brain brain tank salad surgery", "surgery brain"])

    iterator = (input_sentences
                .map(lambda x: string_ops.string_split([x]).values)
                .map(table.lookup)
                .make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(table.init)
      sess.run(init_op)
      sess.run(get_next)
      sess.run(get_next)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testCaptureQueue(self):
    elements = np.random.randint(100, size=[200])
    queue = data_flow_ops.FIFOQueue(200, dtypes.int64, shapes=[])
    enqueue_op = queue.enqueue_many(elements)
    close_op = queue.close()
    iterator = (dataset_ops.Dataset.from_tensors(0).repeat(-1)
                .map(lambda _: queue.dequeue()).make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(enqueue_op)
      sess.run(close_op)
      sess.run(init_op)
      for element in elements:
        self.assertEqual(element, sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testCaptureSameResourceMultipleTimes(self):
    elements = np.random.randint(100, size=[200])
    queue = data_flow_ops.FIFOQueue(
        200, dtypes.int64, shapes=[], shared_name="shared_queue")
    queue_2 = data_flow_ops.FIFOQueue(
        200, dtypes.int64, shapes=[], shared_name="shared_queue")
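    # NOTE: `queue` and `queue_2` have the same `shared_name`, so they are two
    # handles to the same underlying queue. Each map invocation below therefore
    # dequeues two elements from that one queue, and the relative order of the
    # two dequeues is non-deterministic (hence the sorted comparison).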

    enqueue_op = queue.enqueue_many(elements)
    close_op = queue.close()

    iterator = (dataset_ops.Dataset.from_tensors(0).repeat(-1)
                .map(lambda _: (queue.dequeue(), queue_2.dequeue()))
                .make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(enqueue_op)
      sess.run(close_op)
      sess.run(init_op)
      for i in range(100):
        self.assertEqual(sorted([elements[i * 2], elements[i * 2 + 1]]),
                         sorted(sess.run(get_next)))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testCaptureVariable(self):
    counter_var = variable_scope.get_variable(
        "counter", (), dtypes.int32, use_resource=True)
    iterator = (dataset_ops.Dataset.from_tensors(0).repeat(10)
                .map(lambda _: counter_var.assign_add(1))
                .make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(counter_var.initializer)
      sess.run(init_op)
      for i in range(10):
        self.assertEqual(i, sess.run(counter_var))
        self.assertEqual(i + 1, sess.run(get_next))
      self.assertEqual(10, sess.run(counter_var))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)
      self.assertEqual(10, sess.run(counter_var))

  def testCaptureUninitializedVariableError(self):
    counter_var = variable_scope.get_variable(
        "counter", (), dtypes.int32, use_resource=True)
    iterator = (dataset_ops.Dataset.from_tensors(0).repeat(10)
                .map(lambda _: counter_var.assign_add(1))
                .make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      with self.assertRaises(errors.NotFoundError):
        sess.run(get_next)

  def testSeededStatefulOperatorIsProperlyStateful(self):
    iterator = (dataset_ops.Dataset.from_tensors(0).repeat(10)
                .map(lambda _: random_ops.random_uniform((), seed=11)).batch(2)
                .make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      random_values = []
      with self.assertRaises(errors.OutOfRangeError):
        while True:
          random_values.extend(sess.run(get_next))
      self.assertEqual(10, len(random_values))
      self.assertGreater(np.abs(np.diff(random_values)).max(), 1e-6)
      sess.run(init_op)
      random_values_2 = []
      with self.assertRaises(errors.OutOfRangeError):
        while True:
          random_values_2.extend(sess.run(get_next))

      # Randomness is repeatable given the same seed.
      self.assertAllClose(random_values, random_values_2)

  def testMapDict(self):
    iterator = (dataset_ops.Dataset.range(10)
                .map(lambda x: {"foo": x * 2, "bar": x ** 2})
                .map(lambda d: d["foo"] + d["bar"])
                .make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      for i in range(10):
        self.assertEqual(i * 2 + i ** 2, sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testMapNamedtuple(self, count=10):
    # Construct a dataset of tuples.
    labels = dataset_ops.Dataset.range(count)
    images = labels.map(lambda l: -l)
    dataset_tuple = dataset_ops.Dataset.zip((labels, images))

    # Convert the dataset of tuples to a dataset of namedtuples.
    example = namedtuple("Example", ["label", "image"])
    dataset_namedtuple = dataset_tuple.map(example)
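    # NOTE: the `Example` constructor is itself a callable taking
    # `(label, image)`, so mapping it packs each tuple element into a
    # namedtuple.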

    def preprocess_tuple(label, image):
      image = 2 * image
      return label, image

    def preprocess_namedtuple(example):
      return example._replace(image=2 * example.image)

    # Preprocess both datasets.
    dataset_tuple = dataset_tuple.map(preprocess_tuple)
    dataset_namedtuple = dataset_namedtuple.map(preprocess_namedtuple)

    next_tuple = dataset_tuple.make_one_shot_iterator().get_next()
    next_namedtuple = dataset_namedtuple.make_one_shot_iterator().get_next()

    # Make sure both datasets contain the same data.
    with self.test_session() as sess:
      for i in range(count):
        tuple_, namedtuple_ = sess.run([next_tuple, next_namedtuple])
        self.assertEqual(tuple_, namedtuple_)
        self.assertEqual(tuple_, (i, -2 * i))

      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_namedtuple)

  def testUseStepContainerInMap(self):
    row = np.arange(6)
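    # NOTE: `functional_ops.map_fn` builds a while_loop inside the dataset's
    # map function, which exercises the per-step resource container that the
    # loop's internal TensorArrays live in (hence the test name).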
    iterator = (
        dataset_ops.Dataset.from_tensors(row)
        .map(lambda elems: functional_ops.map_fn(lambda x: x * x, elems))
        .make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      self.assertAllEqual(row ** 2, sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testPrefetch(self):
    # We will use this event to test that `_map_py_func()` has been invoked a
    # certain number of times (6, to be exact) even though the iterator has
    # consumed fewer elements.
    ev = threading.Event()

    set_event_during_invocation = 5

    def _map_py_func(x):
      if x == set_event_during_invocation:
        ev.set()
      return x * x

    def _map_fn(x):
      return script_ops.py_func(_map_py_func, [x], x.dtype)

    buffer_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
    iterator = (
        dataset_ops.Dataset.range(100)
        .map(_map_fn)
        .prefetch(buffer_size_placeholder)
        .make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      # Simple test that prefetch yields the expected values in the
      # expected order.
      for buffer_size in [1, 10, 100, 1000]:
        sess.run(init_op, feed_dict={buffer_size_placeholder: buffer_size})
        for i in range(100):
          self.assertEqual(i * i, sess.run(get_next))
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(get_next)

      # We can indirectly observe that varying the buffer size has the
      # intended effect by observing when `ev` is set (on the 6th
      # invocation of `_map_py_func()`).
      # NOTE(mrry): We do not test with `buffer_size ==
      # set_event_during_invocation`, because we must consume at least
      # one element to start the prefetching.
      for buffer_size in range(1, set_event_during_invocation):
        event_will_be_set_after_consuming = (
            set_event_during_invocation - buffer_size + 1)
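        # After the consumer has taken `k` elements, up to `k + buffer_size`
        # map invocations may have completed, so the event (set during the
        # invocation with `x == 5`) fires once `k + buffer_size > 5`, i.e.
        # after consuming `5 - buffer_size + 1` elements.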

        ev.clear()
        sess.run(init_op, feed_dict={buffer_size_placeholder: buffer_size})
        for i in range(event_will_be_set_after_consuming):
          self.assertFalse(ev.is_set())
          self.assertEqual(i * i, sess.run(get_next))
        ev.wait()
        for i in range(event_will_be_set_after_consuming, 100):
          self.assertEqual(i * i, sess.run(get_next))
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(get_next)

  def testReturnList(self):
    iterator = (dataset_ops.Dataset.range(10)
                .map(lambda x: [x, constant_op.constant(37.0)])
                .make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      for i in range(10):
        self.assertEqual((i, 37.0), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testMultiOutputPyFunc(self):
    # The `tf.py_func()` op returns a list of tensors for its outputs.
    def _map_fn(x_tensor):
      def _map_py_func(x):
        return x, np.array(37.0, dtype=np.float64)
      return script_ops.py_func(
          _map_py_func, [x_tensor], [dtypes.int64, dtypes.float64])

    iterator = (dataset_ops.Dataset.range(10)
                .map(_map_fn)
                .make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      for i in range(10):
        self.assertEqual((i, 37.0), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def assertSparseValuesEqual(self, a, b):
    self.assertAllEqual(a.indices, b.indices)
    self.assertAllEqual(a.values, b.values)
    self.assertAllEqual(a.dense_shape, b.dense_shape)

  def testSparse(self):

    def _sparse(i):
      return sparse_tensor.SparseTensorValue(
          indices=np.array([[0, 0]]),
          values=(i * np.array([1])),
          dense_shape=np.array([1, 1]))

    iterator = (dataset_ops.Dataset.range(10)
                .map(_sparse)
                .make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      for i in range(10):
        actual = sess.run(get_next)
        self.assertTrue(isinstance(actual, sparse_tensor.SparseTensorValue))
        self.assertSparseValuesEqual(actual, _sparse(i))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testSparseChain(self):

    def _sparse(i):
      return sparse_tensor.SparseTensorValue(
          indices=np.array([[0, 0]]),
          values=(i * np.array([1])),
          dense_shape=np.array([1, 1]))

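    # NOTE: `_check` runs once, when the map function is traced into the
    # graph, so its assertion verifies that `map()` passes a `SparseTensor`
    # (rather than a dense tensor) into the function.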
    def _check(i):
      self.assertTrue(sparse_tensor.is_sparse(i))
      return sparse_ops.sparse_concat(0, [i, i])

    iterator = (
        dataset_ops.Dataset.range(10).map(_sparse).map(_check)
        .make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      for i in range(10):
        actual = sess.run(get_next)
        self.assertTrue(isinstance(actual, sparse_tensor.SparseTensorValue))
        self.assertSparseValuesEqual(actual, _check(_sparse(i)).eval())
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)


class MapDatasetBenchmark(test.Benchmark):

  def benchmarkChainOfMaps(self):
    chain_lengths = [0, 1, 2, 5, 10, 20, 50]
    for chain_length in chain_lengths:
      with ops.Graph().as_default():
        dataset = dataset_ops.Dataset.from_tensors(0).repeat(None)
        for _ in range(chain_length):
          dataset = dataset.map(lambda x: x)
        iterator = dataset.make_one_shot_iterator()
        next_element = iterator.get_next()

        with session.Session() as sess:
          for _ in range(5):
            sess.run(next_element.op)
          deltas = []
          for _ in range(100):
            start = time.time()
            for _ in range(100):
              sess.run(next_element.op)
            end = time.time()
            deltas.append(end - start)

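          # Each delta spans 100 `sess.run()` calls, so divide by 100 to get
          # the per-element wall time.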
          median_wall_time = np.median(deltas) / 100
          print("Map dataset chain length: %d Median wall time: %f"
                % (chain_length, median_wall_time))
          self.report_benchmark(
              iters=1000, wall_time=median_wall_time,
              name="benchmark_map_dataset_chain_latency_%d" % chain_length)

  def benchmarkMapFanOut(self):
    fan_outs = [1, 2, 5, 10, 20, 50, 100]
    for fan_out in fan_outs:
      with ops.Graph().as_default():
        dataset = dataset_ops.Dataset.from_tensors(
            tuple(0 for _ in range(fan_out))).repeat(None).map(lambda *xs: xs)
        iterator = dataset.make_one_shot_iterator()
        next_element = iterator.get_next()

        with session.Session() as sess:
          for _ in range(5):
            sess.run(next_element[0].op)
          deltas = []
          for _ in range(100):
            start = time.time()
            for _ in range(100):
              sess.run(next_element[0].op)
            end = time.time()
            deltas.append(end - start)

          median_wall_time = np.median(deltas) / 100
          print("Map dataset fan out: %d Median wall time: %f"
                % (fan_out, median_wall_time))
          self.report_benchmark(
              iters=1000, wall_time=median_wall_time,
              name="benchmark_map_dataset_fan_out_%d" % fan_out)


if __name__ == "__main__":
  test.main()