      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for the experimental input pipeline ops."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import time
     21 
     22 import numpy as np
     23 
     24 from tensorflow.core.protobuf import config_pb2
     25 from tensorflow.python.client import session
     26 from tensorflow.python.data.ops import dataset_ops
     27 from tensorflow.python.data.util import nest
     28 from tensorflow.python.framework import dtypes
     29 from tensorflow.python.framework import errors
     30 from tensorflow.python.framework import ops
     31 from tensorflow.python.framework import sparse_tensor
     32 from tensorflow.python.framework import tensor_shape
     33 from tensorflow.python.ops import array_ops
     34 from tensorflow.python.ops import math_ops
     35 from tensorflow.python.ops import resource_variable_ops
     36 from tensorflow.python.platform import test
     37 
     38 
     39 class DatasetConstructorTest(test.TestCase):
     40 
     41   def testFromTensors(self):
     42     """Test a dataset that represents a single tuple of tensors."""
     43     components = (np.array(1), np.array([1, 2, 3]), np.array(37.0))
     44 
     45     iterator = (dataset_ops.Dataset.from_tensors(components)
     46                 .make_initializable_iterator())
     47     init_op = iterator.initializer
     48     get_next = iterator.get_next()
     49 
     50     self.assertEqual([c.shape for c in components],
     51                      [t.shape for t in get_next])
     52 
     53     with self.test_session() as sess:
     54       sess.run(init_op)
     55       results = sess.run(get_next)
     56       for component, result_component in zip(components, results):
     57         self.assertAllEqual(component, result_component)
     58       with self.assertRaises(errors.OutOfRangeError):
     59         sess.run(get_next)
     60 
     61   def assertSparseValuesEqual(self, a, b):
     62     self.assertAllEqual(a.indices, b.indices)
     63     self.assertAllEqual(a.values, b.values)
     64     self.assertAllEqual(a.dense_shape, b.dense_shape)
     65 
     66   def testFromTensorsSparse(self):
     67     """Test a dataset that represents a single tuple of tensors."""
     68     components = (sparse_tensor.SparseTensorValue(
     69         indices=np.array([[0]]),
     70         values=np.array([0]),
     71         dense_shape=np.array([1])),
     72                   sparse_tensor.SparseTensorValue(
     73                       indices=np.array([[0, 0], [1, 1]]),
     74                       values=np.array([-1, 1]),
     75                       dense_shape=np.array([2, 2])))
     76 
     77     iterator = (
     78         dataset_ops.Dataset.from_tensors(components)
     79         .make_initializable_iterator())
     80     init_op = iterator.initializer
     81     get_next = iterator.get_next()
     82 
     83     self.assertEqual(
     84         [tensor_shape.TensorShape(c.dense_shape) for c in components],
     85         [shape for shape in iterator.output_shapes])
     86 
     87     with self.test_session() as sess:
     88       sess.run(init_op)
     89       results = sess.run(get_next)
     90       for component, result_component in zip(components, results):
     91         self.assertSparseValuesEqual(component, result_component)
     92       with self.assertRaises(errors.OutOfRangeError):
     93         sess.run(get_next)
     94 
     95   def testFromTensorsMixed(self):
     96     """Test an dataset that represents a single tuple of tensors."""
     97     components = (np.array(1), np.array([1, 2, 3]), np.array(37.0),
     98                   sparse_tensor.SparseTensorValue(
     99                       indices=np.array([[0]]),
    100                       values=np.array([0]),
    101                       dense_shape=np.array([1])),
    102                   sparse_tensor.SparseTensorValue(
    103                       indices=np.array([[0, 0], [1, 1]]),
    104                       values=np.array([-1, 1]),
    105                       dense_shape=np.array([2, 2])))
    106 
    107     iterator = (
    108         dataset_ops.Dataset.from_tensors(components)
    109         .make_initializable_iterator())
    110     init_op = iterator.initializer
    111     get_next = iterator.get_next()
    112 
    113     self.assertEqual([
    114         tensor_shape.TensorShape(c.dense_shape)
    115         if sparse_tensor.is_sparse(c) else c.shape for c in components
    116     ], [shape for shape in iterator.output_shapes])
    117 
    118     with self.test_session() as sess:
    119       sess.run(init_op)
    120       results = sess.run(get_next)
    121       for component, result_component in zip(components, results):
    122         if sparse_tensor.is_sparse(component):
    123           self.assertSparseValuesEqual(component, result_component)
    124         else:
    125           self.assertAllEqual(component, result_component)
    126       with self.assertRaises(errors.OutOfRangeError):
    127         sess.run(get_next)
    128 
    129   def testFromTensorSlices(self):
    130     """Test a dataset that represents the slices from a tuple of tensors."""
    131     components = (
    132         np.tile(np.array([[1], [2], [3], [4]]), 20), np.tile(
    133             np.array([[12], [13], [14], [15]]), 22),
    134         np.array([37.0, 38.0, 39.0, 40.0])
    135     )
    136 
    137     iterator = (dataset_ops.Dataset.from_tensor_slices(components)
    138                 .make_initializable_iterator())
    139     init_op = iterator.initializer
    140     get_next = iterator.get_next()
    141 
    142     self.assertEqual([c.shape[1:] for c in components],
    143                      [t.shape for t in get_next])
    144 
    145     with self.test_session() as sess:
    146       sess.run(init_op)
    147       for i in range(4):
    148         results = sess.run(get_next)
    149         for component, result_component in zip(components, results):
    150           self.assertAllEqual(component[i], result_component)
    151       with self.assertRaises(errors.OutOfRangeError):
    152         sess.run(get_next)
    153 
    154   def testFromTensorSlicesSparse(self):
    155     """Test a dataset that represents the slices from a tuple of tensors."""
    156     components = (sparse_tensor.SparseTensorValue(
    157         indices=np.array([[0, 0], [1, 0], [2, 0]]),
    158         values=np.array([0, 0, 0]),
    159         dense_shape=np.array([3, 1])),
    160                   sparse_tensor.SparseTensorValue(
    161                       indices=np.array([[0, 0], [1, 1], [2, 2]]),
    162                       values=np.array([1, 2, 3]),
    163                       dense_shape=np.array([3, 3])))
    164 
    165     iterator = (
    166         dataset_ops.Dataset.from_tensor_slices(components)
    167         .make_initializable_iterator())
    168     init_op = iterator.initializer
    169     get_next = iterator.get_next()
    170 
    171     self.assertEqual(
    172         [tensor_shape.TensorShape(c.dense_shape[1:]) for c in components],
    173         [shape for shape in iterator.output_shapes])
    174 
    175     with self.test_session() as sess:
    176       sess.run(init_op)
    177       expected = [
    178           (sparse_tensor.SparseTensorValue(
    179               indices=np.array([[0]]),
    180               values=np.array([0]),
    181               dense_shape=np.array([1])),
    182            sparse_tensor.SparseTensorValue(
    183                indices=np.array([[0]]),
    184                values=np.array([1]),
    185                dense_shape=np.array([3]))),
    186           (sparse_tensor.SparseTensorValue(
    187               indices=np.array([[0]]),
    188               values=np.array([0]),
    189               dense_shape=np.array([1])),
    190            sparse_tensor.SparseTensorValue(
    191                indices=np.array([[1]]),
    192                values=np.array([2]),
    193                dense_shape=np.array([3]))),
    194           (sparse_tensor.SparseTensorValue(
    195               indices=np.array([[0]]),
    196               values=np.array([0]),
    197               dense_shape=np.array([1])),
    198            sparse_tensor.SparseTensorValue(
    199                indices=np.array([[2]]),
    200                values=np.array([3]),
    201                dense_shape=np.array([3]))),
    202       ]
    203       for i in range(3):
    204         results = sess.run(get_next)
    205         for component, result_component in zip(expected[i], results):
    206           self.assertSparseValuesEqual(component, result_component)
    207       with self.assertRaises(errors.OutOfRangeError):
    208         sess.run(get_next)
    209 
    210   def testFromTensorSlicesMixed(self):
    211     """Test a dataset that represents the slices from a tuple of tensors."""
    212     components = (np.tile(np.array([[1], [2], [3]]), 20),
    213                   np.tile(np.array([[12], [13], [14]]), 22),
    214                   np.array([37.0, 38.0, 39.0]),
    215                   sparse_tensor.SparseTensorValue(
    216                       indices=np.array([[0, 0], [1, 0], [2, 0]]),
    217                       values=np.array([0, 0, 0]),
    218                       dense_shape=np.array([3, 1])),
    219                   sparse_tensor.SparseTensorValue(
    220                       indices=np.array([[0, 0], [1, 1], [2, 2]]),
    221                       values=np.array([1, 2, 3]),
    222                       dense_shape=np.array([3, 3])))
    223 
    224     iterator = (
    225         dataset_ops.Dataset.from_tensor_slices(components)
    226         .make_initializable_iterator())
    227     init_op = iterator.initializer
    228     get_next = iterator.get_next()
    229 
    230     self.assertEqual([
    231         tensor_shape.TensorShape(c.dense_shape[1:])
    232         if sparse_tensor.is_sparse(c) else c.shape[1:] for c in components
    233     ], [shape for shape in iterator.output_shapes])
    234 
    235     with self.test_session() as sess:
    236       sess.run(init_op)
    237       expected = [
    238           (sparse_tensor.SparseTensorValue(
    239               indices=np.array([[0]]),
    240               values=np.array([0]),
    241               dense_shape=np.array([1])),
    242            sparse_tensor.SparseTensorValue(
    243                indices=np.array([[0]]),
    244                values=np.array([1]),
    245                dense_shape=np.array([3]))),
    246           (sparse_tensor.SparseTensorValue(
    247               indices=np.array([[0]]),
    248               values=np.array([0]),
    249               dense_shape=np.array([1])),
    250            sparse_tensor.SparseTensorValue(
    251                indices=np.array([[1]]),
    252                values=np.array([2]),
    253                dense_shape=np.array([3]))),
    254           (sparse_tensor.SparseTensorValue(
    255               indices=np.array([[0]]),
    256               values=np.array([0]),
    257               dense_shape=np.array([1])),
    258            sparse_tensor.SparseTensorValue(
    259                indices=np.array([[2]]),
    260                values=np.array([3]),
    261                dense_shape=np.array([3]))),
    262       ]
    263       for i in range(3):
    264         results = sess.run(get_next)
    265         for component, result_component in zip(
    266             (zip(*components[:3])[i] + expected[i]), results):
    267           if sparse_tensor.is_sparse(component):
    268             self.assertSparseValuesEqual(component, result_component)
    269           else:
    270             self.assertAllEqual(component, result_component)
    271       with self.assertRaises(errors.OutOfRangeError):
    272         sess.run(get_next)
    273 
    274   def testFromTensorSlicesWithDict(self):
    275     components = {"foo": [1, 2, 3], "bar": [[4.0], [5.0], [6.0]]}
    276     iterator = (dataset_ops.Dataset.from_tensor_slices(components)
    277                 .make_initializable_iterator())
    278     init_op = iterator.initializer
    279     get_next = iterator.get_next()
    280 
    281     self.assertEqual(dtypes.int32, iterator.output_types["foo"])
    282     self.assertEqual(dtypes.float32, iterator.output_types["bar"])
    283     self.assertEqual((), iterator.output_shapes["foo"])
    284     self.assertEqual((1,), iterator.output_shapes["bar"])
    285 
    286     with self.test_session() as sess:
    287       sess.run(init_op)
    288       for i in range(3):
    289         results = sess.run(get_next)
    290         self.assertEqual(components["foo"][i], results["foo"])
    291         self.assertEqual(components["bar"][i], results["bar"])
    292       with self.assertRaises(errors.OutOfRangeError):
    293         sess.run(get_next)
    294 
    295   def testFromSparseTensorSlices(self):
    296     """Test a dataset based on slices of a `tf.SparseTensor`."""
    297     st = array_ops.sparse_placeholder(dtypes.float64)
    298     iterator = (dataset_ops.Dataset.from_sparse_tensor_slices(st)
    299                 .make_initializable_iterator())
    300     init_op = iterator.initializer
    301     get_next = sparse_tensor.SparseTensor(*iterator.get_next())
    302 
    303     with self.test_session() as sess:
    304       slices = [[1., 2., 3.], [1.], [1.], [1., 2.], [], [1., 2.], [], [], []]
    305 
    306       # Test with sparse tensor in the appropriate order.
    307       indices = np.array(
    308           [[i, j] for i in range(len(slices)) for j in range(len(slices[i]))])
    309       values = np.array([val for s in slices for val in s])
    310       dense_shape = np.array([len(slices), max(len(s) for s in slices) + 1])
    311       sparse_feed = sparse_tensor.SparseTensorValue(indices, values,
    312                                                     dense_shape)
    313       sess.run(init_op, feed_dict={st: sparse_feed})
    314       for i, s in enumerate(slices):
    315         results = sess.run(get_next)
    316         self.assertAllEqual(s, results.values)
    317         expected_indices = np.array(
    318             [[j] for j in range(len(slices[i]))]).reshape([-1, 1])
    319         self.assertAllEqual(expected_indices, results.indices)
    320         self.assertAllEqual(dense_shape[1:], results.dense_shape)
    321       with self.assertRaises(errors.OutOfRangeError):
    322         sess.run(get_next)
    323 
    324       # Test with sparse tensor in the reverse order, which is not
    325       # currently supported.
    326       reverse_order_indices = indices[::-1, :]
    327       reverse_order_values = values[::-1]
    328       sparse_feed = sparse_tensor.SparseTensorValue(
    329           reverse_order_indices, reverse_order_values, dense_shape)
    330       with self.assertRaises(errors.UnimplementedError):
    331         sess.run(init_op, feed_dict={st: sparse_feed})
    332 
    333       # Test with an empty sparse tensor.
    334       empty_indices = np.empty((0, 4), dtype=np.int64)
    335       empty_values = np.empty((0,), dtype=np.float64)
    336       empty_dense_shape = [0, 4, 37, 9]
    337       sparse_feed = sparse_tensor.SparseTensorValue(empty_indices, empty_values,
    338                                                     empty_dense_shape)
    339       sess.run(init_op, feed_dict={st: sparse_feed})
    340       with self.assertRaises(errors.OutOfRangeError):
    341         sess.run(get_next)
    342 
    343   # pylint: disable=g-long-lambda,unnecessary-lambda
    344   def testNestedStructure(self):
    345     components = (np.array([1, 2, 3], dtype=np.int64),
    346                   (np.array([4., 5.]), np.array([6., 7.])),
    347                   np.array([8, 9, 10], dtype=np.int64))
    348 
    349     dataset = dataset_ops.Dataset.from_tensors(components)
    350     self.assertEquals((dtypes.int64, (dtypes.float64, dtypes.float64),
    351                        dtypes.int64), dataset.output_types)
    352     self.assertEquals(([3], ([2], [2]), [3]), dataset.output_shapes)
    353 
    354     dataset = dataset.shuffle(10, 10)
    355     self.assertEquals((dtypes.int64, (dtypes.float64, dtypes.float64),
    356                        dtypes.int64), dataset.output_types)
    357     self.assertEquals(([3], ([2], [2]), [3]), dataset.output_shapes)
    358 
    359     dataset = dataset.repeat(-1)
    360     self.assertEquals((dtypes.int64, (dtypes.float64, dtypes.float64),
    361                        dtypes.int64), dataset.output_types)
    362     self.assertEquals(([3], ([2], [2]), [3]), dataset.output_shapes)
    363 
    364     dataset = dataset.filter(lambda x, y, z: True)
    365     self.assertEquals((dtypes.int64, (dtypes.float64, dtypes.float64),
    366                        dtypes.int64), dataset.output_types)
    367     self.assertEquals(([3], ([2], [2]), [3]), dataset.output_shapes)
    368 
    369     dataset = dataset.take(5)
    370     self.assertEquals((dtypes.int64, (dtypes.float64, dtypes.float64),
    371                        dtypes.int64), dataset.output_types)
    372     self.assertEquals(([3], ([2], [2]), [3]), dataset.output_shapes)
    373 
    374     dataset = dataset.map(lambda x, y, z: ((x, z), (y[0], y[1])))
    375     self.assertEquals(((dtypes.int64, dtypes.int64),
    376                        (dtypes.float64, dtypes.float64)), dataset.output_types)
    377     self.assertEquals((([3], [3]), ([2], [2])), dataset.output_shapes)
    378 
    379     dataset = dataset.flat_map(
    380         lambda x, y: dataset_ops.Dataset.from_tensors(((x[0], x[1]),
    381                                                        (y[0], y[1])))
    382     )
    383     self.assertEquals(((dtypes.int64, dtypes.int64),
    384                        (dtypes.float64, dtypes.float64)), dataset.output_types)
    385     self.assertEquals((([3], [3]), ([2], [2])), dataset.output_shapes)
    386 
    387     dataset = dataset.batch(32)
    388     self.assertEquals(((dtypes.int64, dtypes.int64),
    389                        (dtypes.float64, dtypes.float64)), dataset.output_types)
    390     self.assertEquals((([None, 3], [None, 3]), ([None, 2], [None, 2])),
    391                       nest.pack_sequence_as(dataset.output_shapes, [
    392                           s.as_list()
    393                           for s in nest.flatten(dataset.output_shapes)
    394                       ]))
    395 
    396     iterator = dataset.make_one_shot_iterator()
    397     (w, x), (y, z) = iterator.get_next()
    398     self.assertEquals(dtypes.int64, w.dtype)
    399     self.assertEquals(dtypes.int64, x.dtype)
    400     self.assertEquals(dtypes.float64, y.dtype)
    401     self.assertEquals(dtypes.float64, z.dtype)
    402     self.assertEquals([None, 3], w.shape.as_list())
    403     self.assertEquals([None, 3], x.shape.as_list())
    404     self.assertEquals([None, 2], y.shape.as_list())
    405     self.assertEquals([None, 2], z.shape.as_list())
    406 
    407     iterator = dataset.make_initializable_iterator()
    408     (w, x), (y, z) = iterator.get_next()
    409     self.assertEquals(dtypes.int64, w.dtype)
    410     self.assertEquals(dtypes.int64, x.dtype)
    411     self.assertEquals(dtypes.float64, y.dtype)
    412     self.assertEquals(dtypes.float64, z.dtype)
    413     self.assertEquals([None, 3], w.shape.as_list())
    414     self.assertEquals([None, 3], x.shape.as_list())
    415     self.assertEquals([None, 2], y.shape.as_list())
    416     self.assertEquals([None, 2], z.shape.as_list())
    417 
    418     # Define a separate set of components with matching leading
    419     # dimension for the from-slices constructor.
    420     components_for_slices = (np.array([1, 2, 3], dtype=np.int64),
    421                              (np.array([4., 5., 6.]),
    422                               np.array([7., 8., 9.])),
    423                              np.array([10, 11, 12], dtype=np.int64))
    424 
    425     dataset = dataset_ops.Dataset.from_tensor_slices(components_for_slices)
    426     self.assertEquals((dtypes.int64, (dtypes.float64, dtypes.float64),
    427                        dtypes.int64), dataset.output_types)
    428     self.assertEquals(([], ([], []), []), dataset.output_shapes)
    429 
    430   def testNestedDict(self):
    431     components = {"a": {"aa": 1, "ab": [2.0, 2.0]}, "b": [3, 3, 3]}
    432     dataset = dataset_ops.Dataset.from_tensors(components)
    433     self.assertEquals(dtypes.int32, dataset.output_types["a"]["aa"])
    434     self.assertEquals(dtypes.float32, dataset.output_types["a"]["ab"])
    435     self.assertEquals(dtypes.int32, dataset.output_types["b"])
    436     self.assertEquals([], dataset.output_shapes["a"]["aa"])
    437     self.assertEquals([2], dataset.output_shapes["a"]["ab"])
    438     self.assertEquals([3], dataset.output_shapes["b"])
    439 
    440   def testNonSequenceNestedStructure(self):
    441     components = np.array([1, 2, 3], dtype=np.int64)
    442 
    443     dataset = dataset_ops.Dataset.from_tensors(components)
    444     self.assertEquals(dtypes.int64, dataset.output_types)
    445     self.assertEquals([3], dataset.output_shapes)
    446 
    447     dataset = dataset.filter(
    448         lambda x: math_ops.reduce_all(math_ops.equal(x, components)))
    449     self.assertEquals(dtypes.int64, dataset.output_types)
    450     self.assertEquals([3], dataset.output_shapes)
    451 
    452     dataset = dataset.map(lambda x: array_ops.stack([x, x]))
    453     self.assertEquals(dtypes.int64, dataset.output_types)
    454     self.assertEquals([2, 3], dataset.output_shapes)
    455 
    456     dataset = dataset.flat_map(
    457         lambda x: dataset_ops.Dataset.from_tensor_slices(x))
    458     self.assertEquals(dtypes.int64, dataset.output_types)
    459     self.assertEquals([3], dataset.output_shapes)
    460 
    461     iterator = dataset.make_one_shot_iterator()
    462     get_next = iterator.get_next()
    463     self.assertEquals(dtypes.int64, get_next.dtype)
    464     self.assertEquals([3], get_next.shape)
    465 
    466   def testSplitPipelineFailsWithPlacementError(self):
    467     with session.Session(
    468         target="",
    469         config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
    470 
    471       dataset = dataset_ops.Dataset.from_tensors(0)
    472 
    473       # Define a pipeline that attempts to use variables on two
    474       # different devices.
    475       #
    476       # Initialize the variables before creating to iterator, to avoid the
    477       # placement algorithm overriding the DT_RESOURCE colocation constraints.
    478       with ops.device("/cpu:0"):
    479         var_0 = resource_variable_ops.ResourceVariable(initial_value=0)
    480         dataset = dataset.map(lambda x: x + var_0.read_value())
    481       sess.run(var_0.initializer)
    482 
    483       with ops.device("/cpu:1"):
    484         var_1 = resource_variable_ops.ResourceVariable(initial_value=0)
    485         dataset = dataset.map(lambda x: x + var_1.read_value())
    486       sess.run(var_1.initializer)
    487 
    488       iterator = dataset.make_initializable_iterator()
    489       sess.run(iterator.initializer)
    490 
    491       with self.assertRaisesRegexp(
    492           errors.FailedPreconditionError,
    493           "Error while reading resource variable Variable"):
    494         sess.run(iterator.get_next())
    495 
    496 
    497 class DatasetConstructorBenchmark(test.Benchmark):
    498 
    499   def benchmarkSliceRepeatBatch(self):
    500     input_size = 10000
    501     batch_size = 100
    502     num_epochs = 100
    503 
    504     input_data = np.random.randn(input_size)
    505 
    506     dataset = (
    507         dataset_ops.Dataset.from_tensor_slices(input_data)
    508         .repeat(num_epochs + 1).batch(batch_size))
    509     iterator = dataset.make_initializable_iterator()
    510     next_element = iterator.get_next()
    511 
    512     with session.Session() as sess:
    513       sess.run(iterator.initializer)
    514       # Run one whole epoch to burn in the computation.
    515       for _ in range(input_size // batch_size):
    516         sess.run(next_element)
    517       deltas = []
    518       try:
    519         while True:
    520           start = time.time()
    521           sess.run(next_element)
    522           deltas.append(time.time() - start)
    523       except errors.OutOfRangeError:
    524         pass
    525 
    526     median_wall_time = np.median(deltas)
    527     print("Slice/repeat/batch with sess.run() input size: %d batch size: %d "
    528           "Median wall time per element: %f" % (input_size, batch_size,
    529                                                 median_wall_time))
    530     self.report_benchmark(
    531         iters=len(deltas),
    532         wall_time=median_wall_time,
    533         name="benchmark_slice_repeat_batch_input_%d_batch_%d" % (input_size,
    534                                                                  batch_size))
    535 
    536   def benchmarkSliceRepeatBatchCallable(self):
    537     input_size = 10000
    538     batch_size = 100
    539     num_epochs = 100
    540 
    541     input_data = np.random.randn(input_size)
    542 
    543     dataset = (
    544         dataset_ops.Dataset.from_tensor_slices(input_data)
    545         .repeat(num_epochs + 1).batch(batch_size))
    546     iterator = dataset.make_initializable_iterator()
    547     next_element = iterator.get_next()
    548 
    549     with session.Session() as sess:
    550       sess.run(iterator.initializer)
    551       get_next_element = sess.make_callable(next_element)
    552       # Run one whole epoch to burn in the computation.
    553       for _ in range(input_size // batch_size):
    554         get_next_element()
    555       deltas = []
    556       try:
    557         while True:
    558           start = time.time()
    559           get_next_element()
    560           deltas.append(time.time() - start)
    561       except errors.OutOfRangeError:
    562         pass
    563 
    564     median_wall_time = np.median(deltas)
    565     print(
    566         "Slice/repeat/batch with callable input size: %d batch size: %d Median"
    567         " wall time per element: %f" % (input_size, batch_size,
    568                                         median_wall_time))
    569     self.report_benchmark(
    570         iters=len(deltas),
    571         wall_time=median_wall_time,
    572         name="benchmark_slice_repeat_batch_callable_input_%d_batch_%d" %
    573         (input_size, batch_size))
    574 
    575   def benchmarkReshapeSliceRepeatCallable(self):
    576     input_size = 10000
    577     batch_size = 100
    578     num_epochs = 100
    579 
    580     input_data = np.random.randn(input_size)
    581 
    582     dataset = (
    583         dataset_ops.Dataset.from_tensor_slices(input_data.reshape(100, 100))
    584         .repeat(num_epochs + 1))
    585     iterator = dataset.make_initializable_iterator()
    586     next_element = iterator.get_next()
    587 
    588     with session.Session() as sess:
    589       sess.run(iterator.initializer)
    590       get_next_element = sess.make_callable(next_element)
    591       # Run one whole epoch to burn in the computation.
    592       for _ in range(input_size // batch_size):
    593         get_next_element()
    594       deltas = []
    595       try:
    596         while True:
    597           start = time.time()
    598           get_next_element()
    599           deltas.append(time.time() - start)
    600       except errors.OutOfRangeError:
    601         pass
    602 
    603     median_wall_time = np.median(deltas)
    604     print("Reshape/slice/repeat with callable input size: %d batch size: %d "
    605           "Median wall time per element: %f" % (input_size, batch_size,
    606                                                 median_wall_time))
    607     self.report_benchmark(
    608         iters=len(deltas),
    609         wall_time=median_wall_time,
    610         name="benchmark_reshape_slice_repeat_callable_input_%d_batch_%d" %
    611         (input_size, batch_size))
    612 
    613   def benchmarkSliceBatchCacheRepeatCallable(self):
    614     input_size = 10000
    615     batch_size = 100
    616     num_epochs = 100
    617 
    618     input_data = np.random.randn(input_size)
    619 
    620     dataset = (
    621         dataset_ops.Dataset.from_tensor_slices(input_data).batch(batch_size)
    622         .cache().repeat(num_epochs + 1))
    623     iterator = dataset.make_initializable_iterator()
    624     next_element = iterator.get_next()
    625 
    626     with session.Session() as sess:
    627       sess.run(iterator.initializer)
    628       get_next_element = sess.make_callable(next_element)
    629       # Run one whole epoch to burn in the computation.
    630       for _ in range(input_size // batch_size):
    631         get_next_element()
    632       deltas = []
    633       try:
    634         while True:
    635           start = time.time()
    636           get_next_element()
    637           deltas.append(time.time() - start)
    638       except errors.OutOfRangeError:
    639         pass
    640 
    641     median_wall_time = np.median(deltas)
    642     print(
    643         "Slice/batch/cache/repeat with callable input size: %d batch size: %d "
    644         "Median wall time per element: %f"
    645         % (input_size, batch_size, median_wall_time))
    646     self.report_benchmark(
    647         iters=len(deltas),
    648         wall_time=median_wall_time,
    649         name="benchmark_slice_batch_cache_repeat_callable_input_%d_batch_%d" %
    650         (input_size, batch_size))
    651 
    652 
# Standard TensorFlow test entry point: runs the tests defined above when
# this file is executed directly.
if __name__ == "__main__":
  test.main()
    655