# Home | History | Annotate | Download | only in python
# (Stray code-viewer navigation header; commented out so the file parses.)
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 from __future__ import absolute_import
     16 from __future__ import division
     17 from __future__ import print_function
     18 
     19 import os
     20 
     21 import threading
     22 import time
     23 
     24 import numpy as np
     25 
     26 from tensorflow.contrib import lookup
     27 from tensorflow.contrib.eager.python import datasets
     28 from tensorflow.python.data import Dataset
     29 from tensorflow.python.data.experimental.ops import threadpool
     30 from tensorflow.python.data.experimental.ops import unique
     31 from tensorflow.python.eager import test
     32 from tensorflow.python.framework import constant_op
     33 from tensorflow.python.framework import dtypes
     34 from tensorflow.python.framework import errors
     35 from tensorflow.python.framework import ops
     36 from tensorflow.python.framework import sparse_tensor
     37 from tensorflow.python.ops import math_ops
     38 from tensorflow.python.ops import script_ops
     39 from tensorflow.python.training import checkpoint_management
     40 from tensorflow.python.training.tracking import util as trackable_utils
     41 
     42 
class IteratorTest(test.TestCase):
  """Tests eager-mode iteration over `tf.data` datasets.

  Covers explicit `datasets.Iterator` wrappers, one-shot iterators, implicit
  `for x in dataset` iteration, nested/sparse element structures, py_func and
  lookup-table map functions, device placement, private thread pools, and
  checkpoint save/restore of iterator state.
  """

  def testBasic(self):
    """Iterating a range dataset via datasets.Iterator yields 0..3."""
    got = []
    for t in datasets.Iterator(Dataset.range(4)):
      got.append(t.numpy())
    self.assertAllEqual([0, 1, 2, 3], got)

  def testBasicOneShotIterator(self):
    """A one-shot iterator is directly iterable in eager mode."""
    got = []
    for t in Dataset.range(4).make_one_shot_iterator():
      got.append(t.numpy())
    self.assertAllEqual([0, 1, 2, 3], got)

  def testBasicImplicitIterator(self):
    """A Dataset itself is iterable in eager mode (implicit iterator)."""
    got = []
    for t in Dataset.range(4):
      got.append(t.numpy())
    self.assertAllEqual([0, 1, 2, 3], got)

  def testGetNext(self):
    """get_next() returns elements in order and raises OutOfRangeError at end."""
    iterator = datasets.Iterator(Dataset.range(4))
    self.assertEqual(0, iterator.get_next().numpy())
    self.assertEqual(1, iterator.get_next().numpy())
    self.assertEqual(2, iterator.get_next().numpy())
    self.assertEqual(3, iterator.get_next().numpy())
    with self.assertRaises(errors.OutOfRangeError):
      iterator.get_next()

  def testGetNextOneShotIterator(self):
    """Same get_next()/OutOfRangeError contract for a one-shot iterator."""
    iterator = Dataset.range(4).make_one_shot_iterator()
    self.assertEqual(0, iterator.get_next().numpy())
    self.assertEqual(1, iterator.get_next().numpy())
    self.assertEqual(2, iterator.get_next().numpy())
    self.assertEqual(3, iterator.get_next().numpy())
    with self.assertRaises(errors.OutOfRangeError):
      iterator.get_next()

  def testMultipleIteratorsOnTheSameDataset(self):
    """Two iterators over one dataset each see the full element sequence."""
    ds = Dataset.range(4)
    it1 = datasets.Iterator(ds)
    it2 = datasets.Iterator(ds)
    got = [x.numpy() for x in it1]
    self.assertAllEqual([0, 1, 2, 3], got)

    got = [x.numpy() for x in it2]
    self.assertAllEqual([0, 1, 2, 3], got)

  def testNestedOutputs(self):
    """Zipped datasets yield a matching nested structure of Tensors."""
    ds = Dataset.zip((Dataset.range(4), Dataset.zip((Dataset.range(4),
                                                     Dataset.range(4)))))
    total = 0
    # The Iterator will return a nested structure of Tensor objects.
    # Some funkiness to compare against simple integers.
    for (i, x) in enumerate(datasets.Iterator(ds)):
      want = (i, (i, i))
      got = (x[0].numpy(), (x[1][0].numpy(), x[1][1].numpy()))
      self.assertEqual(got, want)
      total += 1
    self.assertEqual(4, total)

  def testMapAndFilter(self):
    """map(square) then filter(even) over range(8) keeps even squares."""
    def even(x):
      return math_ops.equal(math_ops.mod(x, 2), 0)

    it = datasets.Iterator(Dataset.range(8).map(math_ops.square).filter(even))
    got = [x.numpy() for x in it]
    self.assertAllEqual([0, 4, 16, 36], got)

  def testMapCaptureLookupTable(self):
    """A map function may capture a HashTable resource from outside."""
    default_val = -1
    keys = constant_op.constant(['brain', 'salad', 'surgery'])
    values = constant_op.constant([0, 1, 2], dtypes.int64)
    table = lookup.HashTable(
        lookup.KeyValueTensorInitializer(keys, values), default_val)
    dataset = Dataset.from_tensor_slices(['brain', 'salad', 'surgery'])
    dataset = dataset.map(table.lookup)
    it = datasets.Iterator(dataset)
    got = [x.numpy() for x in it]
    self.assertAllEqual([0, 1, 2], got)

  def testMultipleIteratorsOnADatasetThatUsesFunctions(self):
    """Re-iterating a dataset with a map function produces identical output."""
    ds = Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6]).map(math_ops.square)

    got1 = [x.numpy() for x in datasets.Iterator(ds)]
    self.assertAllEqual([1, 4, 9, 16, 25, 36], got1)
    got2 = [x.numpy() for x in datasets.Iterator(ds)]
    self.assertAllEqual(got1, got2)

  def assertSparseValuesEqual(self, a, b):
    """Asserts two SparseTensorValues have equal indices/values/dense_shape."""
    self.assertAllEqual(a.indices, b.indices)
    self.assertAllEqual(a.values, b.values)
    self.assertAllEqual(a.dense_shape, b.dense_shape)

  def testSparseTensorElements(self):
    """Slicing a tuple of SparseTensorValues yields per-row sparse elements."""
    components = (sparse_tensor.SparseTensorValue(
        indices=np.array([[0, 0], [1, 0], [2, 0]]),
        values=np.array([0, 0, 0]),
        dense_shape=np.array([3, 1])),
                  sparse_tensor.SparseTensorValue(
                      indices=np.array([[0, 0], [1, 1], [2, 2]]),
                      values=np.array([1, 2, 3]),
                      dense_shape=np.array([3, 3])))

    # One expected (sparse, sparse) pair per slice along the first dimension.
    expected = [
        (sparse_tensor.SparseTensorValue(
            indices=np.array([[0]]),
            values=np.array([0]),
            dense_shape=np.array([1])),
         sparse_tensor.SparseTensorValue(
             indices=np.array([[0]]),
             values=np.array([1]),
             dense_shape=np.array([3]))),
        (sparse_tensor.SparseTensorValue(
            indices=np.array([[0]]),
            values=np.array([0]),
            dense_shape=np.array([1])),
         sparse_tensor.SparseTensorValue(
             indices=np.array([[1]]),
             values=np.array([2]),
             dense_shape=np.array([3]))),
        (sparse_tensor.SparseTensorValue(
            indices=np.array([[0]]),
            values=np.array([0]),
            dense_shape=np.array([1])),
         sparse_tensor.SparseTensorValue(
             indices=np.array([[2]]),
             values=np.array([3]),
             dense_shape=np.array([3]))),
    ]

    for i, result in enumerate(
        datasets.Iterator(Dataset.from_tensor_slices(components))):
      self.assertSparseValuesEqual(expected[i][0], result[0])
      self.assertSparseValuesEqual(expected[i][1], result[1])

  def testPyFunc(self):
    """A py_func inside map() runs per-element and returns int64 tensors."""

    def my_map(inp):
      return [[x + 1 for x in inp]]

    ds = Dataset.range(4).map(
        lambda x: script_ops.py_func(my_map, [[x]], dtypes.int64))
    got = [x.numpy() for x in datasets.Iterator(ds)]
    self.assertAllEqual([[1], [2], [3], [4]], got)

  def testTensorsPlacedOnDevice(self):
    """Elements fetched under ops.device can be used in ops on that device."""
    ds = Dataset.from_tensors([0., 1.])
    with ops.device(test.gpu_device_name()):
      x = datasets.Iterator(ds).next()
      x = math_ops.add(x, x)
    self.assertAllEqual([0., 2.], x.numpy())

  def testGpuTensor(self):
    """Implicit dataset iteration also works under a GPU device scope."""
    ds = Dataset.from_tensors([0., 1.])
    with ops.device(test.gpu_device_name()):
      for x in ds:
        y = math_ops.add(x, x)
    self.assertAllEqual([0., 2.], y.numpy())

  def testOverrideThreadPool(self):
    """override_threadpool limits map() work to the private pool's threads."""

    def get_thread_id(_):
      # Python creates a dummy thread object to represent the current
      # thread when called from an "alien" thread (such as a
      # `PrivateThreadPool` thread in this case). It does not include
      # the TensorFlow-given display name, but it has a unique
      # identifier that maps one-to-one with the underlying OS thread.
      return np.array(threading.current_thread().ident).astype(np.int64)

    for num_threads in [1, 2, 4, 8, 16]:

      # Collect the distinct OS thread ids observed by the map function.
      dataset = (
          Dataset.range(1000).map(
              lambda x: script_ops.py_func(get_thread_id, [x], dtypes.int64),
              num_parallel_calls=32).apply(unique.unique()))

      dataset = threadpool.override_threadpool(
          dataset,
          threadpool.PrivateThreadPool(
              num_threads, display_name='private_thread_pool_%d' % num_threads))

      thread_ids = []
      for next_element in datasets.Iterator(dataset):
        thread_ids.append(next_element)
      self.assertEqual(len(thread_ids), len(set(thread_ids)))
      self.assertGreater(len(thread_ids), 0)
      # NOTE(mrry): We don't control the thread pool scheduling, and
      # so cannot guarantee that all of the threads in the pool will
      # perform work.
      self.assertLessEqual(len(thread_ids), num_threads)

  def testSaveRestore(self):
    """Restoring a checkpoint rewinds the iterator to the saved position."""
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
    dataset = Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
    dataset = dataset.map(math_ops.square).batch(2)
    iterator = datasets.Iterator(dataset)
    checkpoint = trackable_utils.Checkpoint(iterator=iterator)
    self.assertAllEqual([1, 4], iterator.get_next().numpy())
    save_path = checkpoint.save(checkpoint_prefix)
    self.assertAllEqual([9, 16], iterator.get_next().numpy())
    self.assertAllEqual([25, 36], iterator.get_next().numpy())
    checkpoint.restore(save_path)
    # After restore, the same two batches are replayed.
    self.assertAllEqual([9, 16], iterator.get_next().numpy())
    self.assertAllEqual([25, 36], iterator.get_next().numpy())

  def testSaveRestoreMultipleIterator(self):
    """One Checkpoint can save/restore several iterators independently."""
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
    dataset = Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
    dataset = dataset.map(math_ops.square).batch(2)
    iterator_1 = datasets.Iterator(dataset)
    iterator_2 = datasets.Iterator(dataset)
    dataset_2 = Dataset.range(10)
    iterator_3 = datasets.Iterator(dataset_2)

    checkpoint = trackable_utils.Checkpoint(
        iterator_1=iterator_1, iterator_2=iterator_2, iterator_3=iterator_3)
    # Advance the iterators to distinct positions before saving.
    self.assertAllEqual([1, 4], iterator_1.get_next().numpy())
    self.assertEqual(0, iterator_3.get_next().numpy())
    self.assertEqual(1, iterator_3.get_next().numpy())
    self.assertEqual(2, iterator_3.get_next().numpy())

    save_path = checkpoint.save(checkpoint_prefix)
    self.assertAllEqual([1, 4], iterator_2.get_next().numpy())
    self.assertAllEqual([9, 16], iterator_2.get_next().numpy())
    self.assertEqual(3, iterator_3.get_next().numpy())
    checkpoint.restore(save_path)
    # Each iterator resumes from its own saved position.
    self.assertAllEqual([9, 16], iterator_1.get_next().numpy())
    self.assertAllEqual([1, 4], iterator_2.get_next().numpy())
    self.assertEqual(3, iterator_3.get_next().numpy())

  def testRestoreExhaustedIterator(self):
    """Restoring before exhaustion replays the remaining element."""
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
    dataset = Dataset.range(3)
    iterator = datasets.Iterator(dataset)

    checkpoint = trackable_utils.Checkpoint(iterator=iterator)
    self.assertEqual(0, iterator.get_next().numpy())
    self.assertEqual(1, iterator.get_next().numpy())
    save_path = checkpoint.save(checkpoint_prefix)
    self.assertEqual(2, iterator.get_next().numpy())
    checkpoint.restore(save_path)
    self.assertEqual(2, iterator.get_next().numpy())

  def testRestoreInReconstructedIterator(self):
    """A freshly constructed iterator can restore a prior iterator's state."""
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
    dataset = Dataset.range(10)
    # Each pass rebuilds the iterator, restores the latest checkpoint (a
    # no-op on the first pass, when no checkpoint exists), consumes two
    # elements, and saves — so progress carries across reconstructions.
    for i in range(5):
      iterator = datasets.Iterator(dataset)
      checkpoint = trackable_utils.Checkpoint(iterator=iterator)
      checkpoint.restore(checkpoint_management.latest_checkpoint(
          checkpoint_directory))
      for j in range(2):
        self.assertEqual(i * 2 + j, iterator.get_next().numpy())
      checkpoint.save(file_prefix=checkpoint_prefix)
    302 
    303 
    304 class DatasetConstructorBenchmark(test.Benchmark):
    305 
    306   def benchmarkSliceRepeatBatchEager(self):
    307     input_size = 10000
    308     batch_size = 100
    309     num_epochs = 100
    310 
    311     input_data = np.random.randn(input_size)
    312 
    313     dataset = (
    314         Dataset.from_tensor_slices(input_data).repeat(num_epochs)
    315         .batch(batch_size))
    316     iterator = datasets.Iterator(dataset)
    317 
    318     ends = [time.time()]
    319     for _ in iterator:
    320       ends.append(time.time())
    321 
    322     deltas = np.ediff1d(ends)
    323     median_wall_time = np.median(deltas)
    324     print(
    325         'Slice/repeat/batch eager input size: %d batch size: %d Median wall '
    326         'time per element: %f'
    327         % (input_size, batch_size, median_wall_time))
    328     self.report_benchmark(
    329         iters=len(deltas),
    330         wall_time=median_wall_time,
    331         name='benchmark_slice_repeat_batch_eager_input_%d_batch_%d' %
    332         (input_size, batch_size))
    333 
    334   def benchmarkSliceBatchCacheRepeatCallable(self):
    335     input_size = 10000
    336     batch_size = 100
    337     num_epochs = 100
    338 
    339     input_data = np.random.randn(input_size)
    340 
    341     dataset = (
    342         Dataset.from_tensor_slices(input_data).batch(batch_size).cache()
    343         .repeat(num_epochs))
    344     iterator = datasets.Iterator(dataset)
    345 
    346     ends = [time.time()]
    347     for _ in iterator:
    348       ends.append(time.time())
    349 
    350     deltas = np.ediff1d(ends)
    351     median_wall_time = np.median(deltas)
    352     print(
    353         'Slice/batch/cache/repeat eager input size: %d batch size: %d Median '
    354         'wall time per element: %f'
    355         % (input_size, batch_size, median_wall_time))
    356     self.report_benchmark(
    357         iters=len(deltas),
    358         wall_time=median_wall_time,
    359         name='benchmark_slice_batch_cache_repeat_eager_input_%d_batch_%d' %
    360         (input_size, batch_size))
    361 
    362 
# Run the eager test/benchmark harness when executed as a script.
if __name__ == '__main__':
  test.main()
    365