Home | History | Annotate | Download | only in kernel_tests
      1 # -*- coding: utf-8 -*-
      2 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      3 #
      4 # Licensed under the Apache License, Version 2.0 (the "License");
      5 # you may not use this file except in compliance with the License.
      6 # You may obtain a copy of the License at
      7 #
      8 #     http://www.apache.org/licenses/LICENSE-2.0
      9 #
     10 # Unless required by applicable law or agreed to in writing, software
     11 # distributed under the License is distributed on an "AS IS" BASIS,
     12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13 # See the License for the specific language governing permissions and
     14 # limitations under the License.
     15 # ==============================================================================
     16 """Tests for the experimental input pipeline ops."""
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import math
     22 
     23 import numpy as np
     24 
     25 from tensorflow.python.data.ops import dataset_ops
     26 from tensorflow.python.framework import constant_op
     27 from tensorflow.python.framework import dtypes
     28 from tensorflow.python.framework import errors
     29 from tensorflow.python.framework import sparse_tensor
     30 from tensorflow.python.framework import tensor_shape
     31 from tensorflow.python.ops import array_ops
     32 from tensorflow.python.ops import math_ops
     33 from tensorflow.python.ops import string_ops
     34 from tensorflow.python.platform import test
     35 from tensorflow.python.util import compat
     36 
     37 
     38 class BatchDatasetTest(test.TestCase):
     39 
     40   def testBatchDataset(self):
     41     """Test an dataset that maps a TF function across its input elements."""
     42     # The pipeline is TensorSliceDataset -> MapDataset(square_3) ->
     43     # RepeatDataset(count) -> BatchDataset(batch_size).
     44     components = (np.arange(7),
     45                   np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
     46                   np.array(37.0) * np.arange(7))
     47 
     48     count = array_ops.placeholder(dtypes.int64, shape=[])
     49     batch_size = array_ops.placeholder(dtypes.int64, shape=[])
     50 
     51     def _map_fn(x, y, z):
     52       return math_ops.square(x), math_ops.square(y), math_ops.square(z)
     53 
     54     iterator = (
     55         dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn)
     56         .repeat(count).batch(batch_size).make_initializable_iterator())
     57     init_op = iterator.initializer
     58     get_next = iterator.get_next()
     59 
     60     self.assertEqual([[None] + list(c.shape[1:]) for c in components],
     61                      [t.shape.as_list() for t in get_next])
     62 
     63     with self.test_session() as sess:
     64       # Batch of a finite input, where the batch_size divides the
     65       # total number of elements.
     66       sess.run(init_op, feed_dict={count: 28, batch_size: 14})
     67       num_batches = (28 * 7) // 14
     68       for i in range(num_batches):
     69         result = sess.run(get_next)
     70         for component, result_component in zip(components, result):
     71           for j in range(14):
     72             self.assertAllEqual(component[(i * 14 + j) % 7]**2,
     73                                 result_component[j])
     74       with self.assertRaises(errors.OutOfRangeError):
     75         sess.run(get_next)
     76 
     77       # Batch of a finite input, where the batch_size does not
     78       # divide the total number of elements.
     79       sess.run(init_op, feed_dict={count: 14, batch_size: 8})
     80 
     81       # We expect (num_batches - 1) full-sized batches.
     82       num_batches = int(math.ceil((14 * 7) / 8))
     83       for i in range(num_batches - 1):
     84         result = sess.run(get_next)
     85         for component, result_component in zip(components, result):
     86           for j in range(8):
     87             self.assertAllEqual(component[(i * 8 + j) % 7]**2,
     88                                 result_component[j])
     89       result = sess.run(get_next)
     90       for component, result_component in zip(components, result):
     91         for j in range((14 * 7) % 8):
     92           self.assertAllEqual(component[((num_batches - 1) * 8 + j) % 7]**2,
     93                               result_component[j])
     94       with self.assertRaises(errors.OutOfRangeError):
     95         sess.run(get_next)
     96 
     97       # Batch of an empty input should fail straight away.
     98       sess.run(init_op, feed_dict={count: 0, batch_size: 8})
     99       with self.assertRaises(errors.OutOfRangeError):
    100         sess.run(get_next)
    101 
    102       # Empty batch should be an initialization time error.
    103       with self.assertRaises(errors.InvalidArgumentError):
    104         sess.run(init_op, feed_dict={count: 14, batch_size: 0})
    105 
    106   def assertSparseValuesEqual(self, a, b):
    107     self.assertAllEqual(a.indices, b.indices)
    108     self.assertAllEqual(a.values, b.values)
    109     self.assertAllEqual(a.dense_shape, b.dense_shape)
    110 
    111   def testBatchSparse(self):
    112 
    113     def _sparse(i):
    114       return sparse_tensor.SparseTensorValue(
    115           indices=[[0]], values=(i * [1]), dense_shape=[1])
    116 
    117     iterator = dataset_ops.Dataset.range(10).map(_sparse).batch(
    118         5).make_initializable_iterator()
    119     init_op = iterator.initializer
    120     get_next = iterator.get_next()
    121 
    122     with self.test_session() as sess:
    123       sess.run(init_op)
    124       for i in range(2):
    125         actual = sess.run(get_next)
    126         expected = sparse_tensor.SparseTensorValue(
    127             indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
    128             values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4],
    129             dense_shape=[5, 1])
    130         self.assertTrue(sparse_tensor.is_sparse(actual))
    131         self.assertSparseValuesEqual(actual, expected)
    132       with self.assertRaises(errors.OutOfRangeError):
    133         sess.run(get_next)
    134 
    135   def testBatchSparseWithDifferentDenseShapes(self):
    136 
    137     def _sparse(i):
    138       return sparse_tensor.SparseTensorValue(
    139           indices=array_ops.expand_dims(
    140               math_ops.range(i, dtype=dtypes.int64), 1),
    141           values=array_ops.fill([math_ops.to_int32(i)], i),
    142           dense_shape=[i])
    143 
    144     iterator = dataset_ops.Dataset.range(10).map(_sparse).batch(
    145         5).make_initializable_iterator()
    146     init_op = iterator.initializer
    147     get_next = iterator.get_next()
    148 
    149     with self.test_session() as sess:
    150       sess.run(init_op)
    151       for i in range(2):
    152         actual = sess.run(get_next)
    153         expected_indices = []
    154         expected_values = []
    155         for j in range(5):
    156           for k in range(i * 5 + j):
    157             expected_indices.append([j, k])
    158             expected_values.append(i * 5 + j)
    159         expected = sparse_tensor.SparseTensorValue(
    160             indices=expected_indices,
    161             values=expected_values,
    162             dense_shape=[5, (i + 1) * 5 - 1])
    163         self.assertTrue(sparse_tensor.is_sparse(actual))
    164         self.assertSparseValuesEqual(actual, expected)
    165       with self.assertRaises(errors.OutOfRangeError):
    166         sess.run(get_next)
    167 
    168   def testNestedBatchSparse(self):
    169 
    170     def _sparse(i):
    171       return sparse_tensor.SparseTensorValue(
    172           indices=[[0]], values=(i * [1]), dense_shape=[1])
    173 
    174     iterator = dataset_ops.Dataset.range(10).map(_sparse).batch(5).batch(
    175         2).make_initializable_iterator()
    176     init_op = iterator.initializer
    177     get_next = iterator.get_next()
    178 
    179     with self.test_session() as sess:
    180       sess.run(init_op)
    181       actual = sess.run(get_next)
    182       expected = sparse_tensor.SparseTensorValue(
    183           indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [0, 4, 0],
    184                    [1, 0, 0], [1, 1, 0], [1, 2, 0], [1, 3, 0], [1, 4, 0]],
    185           values=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
    186           dense_shape=[2, 5, 1])
    187       self.assertTrue(sparse_tensor.is_sparse(actual))
    188       self.assertSparseValuesEqual(actual, expected)
    189       with self.assertRaises(errors.OutOfRangeError):
    190         sess.run(get_next)
    191 
    192   def testBatchShapeError(self):
    193 
    194     def generator():
    195       yield [1.0, 2.0, 3.0]
    196       yield [4.0, 5.0, 6.0]
    197       yield [7.0, 8.0, 9.0, 10.0]
    198 
    199     iterator = (
    200         dataset_ops.Dataset.from_generator(
    201             generator, dtypes.float32, output_shapes=[None]).batch(3)
    202         .make_initializable_iterator())
    203     next_element = iterator.get_next()
    204 
    205     with self.test_session() as sess:
    206       sess.run(iterator.initializer)
    207       with self.assertRaisesRegexp(
    208           errors.InvalidArgumentError,
    209           r'Cannot batch tensors with different shapes in component 0. '
    210           r'First element had shape \[3\] and element 2 had shape \[4\].'):
    211         sess.run(next_element)
    212 
    213   def testPaddedBatchDataset(self):
    214     seq_lens = array_ops.placeholder(dtypes.int32, shape=[None])
    215     padded_shape = array_ops.placeholder(dtypes.int64, shape=[1])
    216 
    217     iterator = (
    218         dataset_ops.Dataset.from_tensor_slices(seq_lens)
    219         .map(lambda x: array_ops.fill([x], x)).padded_batch(
    220             4, padded_shapes=padded_shape).make_initializable_iterator())
    221 
    222     init_op = iterator.initializer
    223     get_next = iterator.get_next()
    224 
    225     with self.test_session() as sess:
    226       # Test with random sequence lengths, and max padding.
    227       random_seq_lens = np.random.randint(20, size=(32,)).astype(np.int32)
    228       sess.run(
    229           init_op, feed_dict={
    230               padded_shape: [-1],
    231               seq_lens: random_seq_lens
    232           })
    233       for i in range(8):
    234         result = sess.run(get_next)
    235         padded_len = np.max(result)
    236         self.assertEqual((4, padded_len), result.shape)
    237         for j in range(4):
    238           seq_len = random_seq_lens[(i * 4) + j]
    239           self.assertAllEqual(result[j, :seq_len], [seq_len] * seq_len)
    240           self.assertAllEqual(result[j, seq_len:], [0] * (padded_len - seq_len))
    241       with self.assertRaises(errors.OutOfRangeError):
    242         sess.run(get_next)
    243 
    244       # Test with random sequence lengths, and constant padding.
    245       sess.run(
    246           init_op, feed_dict={
    247               padded_shape: [25],
    248               seq_lens: random_seq_lens
    249           })
    250       for i in range(8):
    251         result = sess.run(get_next)
    252         self.assertEqual((4, 25), result.shape)
    253         for j in range(4):
    254           seq_len = random_seq_lens[(i * 4) + j]
    255           self.assertAllEqual(result[j, :seq_len], [seq_len] * seq_len)
    256           self.assertAllEqual(result[j, seq_len:], [0] * (25 - seq_len))
    257       with self.assertRaises(errors.OutOfRangeError):
    258         sess.run(get_next)
    259 
    260       # Test correct handling of empty tensors.
    261       sess.run(init_op, feed_dict={padded_shape: [-1], seq_lens: [0, 0, 0, 0]})
    262       result = sess.run(get_next)
    263       self.assertAllEqual([[], [], [], []], result)
    264       with self.assertRaises(errors.OutOfRangeError):
    265         sess.run(get_next)
    266 
    267       # Test error handling with constant sequence lengths, and
    268       # too-short padding.
    269       sess.run(init_op, feed_dict={padded_shape: [5], seq_lens: [6, 5, 5, 5]})
    270       with self.assertRaises(errors.DataLossError):
    271         result = sess.run(get_next)
    272 
    273   def testPaddedBatchDatasetNonDefaultPadding(self):
    274     seq_lens = array_ops.placeholder(dtypes.int32, shape=[None])
    275     padded_shape = array_ops.placeholder(dtypes.int64, shape=[1])
    276 
    277     def fill_tuple(x):
    278       filled = array_ops.fill([x], x)
    279       return (filled, string_ops.as_string(filled))
    280 
    281     iterator = (
    282         dataset_ops.Dataset.from_tensor_slices(seq_lens).map(fill_tuple)
    283         .padded_batch(
    284             4,
    285             padded_shapes=(padded_shape, padded_shape),
    286             padding_values=(-1, '<end>')).make_initializable_iterator())
    287 
    288     init_op = iterator.initializer
    289     get_next = iterator.get_next()
    290 
    291     with self.test_session() as sess:
    292       # Test with random sequence lengths, and max padding.
    293       random_seq_lens = np.random.randint(20, size=(32,)).astype(np.int32)
    294       sess.run(
    295           init_op, feed_dict={
    296               padded_shape: [-1],
    297               seq_lens: random_seq_lens
    298           })
    299       for i in range(8):
    300         result = sess.run(get_next)
    301         padded_len = np.max(result[0])
    302         self.assertEqual((4, padded_len), result[0].shape)
    303         self.assertEqual((4, padded_len), result[1].shape)
    304         for j in range(4):
    305           seq_len = random_seq_lens[(i * 4) + j]
    306           self.assertAllEqual(result[0][j, :seq_len], [seq_len] * seq_len)
    307           self.assertAllEqual(result[0][j, seq_len:],
    308                               [-1] * (padded_len - seq_len))
    309           self.assertAllEqual(result[1][j, :seq_len],
    310                               [compat.as_bytes(str(seq_len))] * seq_len)
    311           self.assertAllEqual(result[1][j, seq_len:],
    312                               [b'<end>'] * (padded_len - seq_len))
    313       with self.assertRaises(errors.OutOfRangeError):
    314         sess.run(get_next)
    315 
    316   def testPaddedBatchDatasetUnicode(self):
    317     # See GitHub issue 16149
    318     def generator():
    319       data = [[u'', u'', u''],
    320               [u'', u'', u'', u'']]
    321 
    322       for seq in data:
    323         yield seq, [0, 1, 2, 3]
    324 
    325     dataset = dataset_ops.Dataset.from_generator(
    326         generator, (dtypes.string, dtypes.int32),
    327         (tensor_shape.TensorShape([None]), tensor_shape.TensorShape([None])))
    328     padded_dataset = dataset.padded_batch(
    329         2, padded_shapes=([None], [None]), padding_values=('', 0))
    330     with self.test_session() as sess:
    331       next_element = padded_dataset.make_one_shot_iterator().get_next()
    332       sess.run(next_element)
    333 
    334   def testPaddedBatchDatasetShapeSpecifications(self):
    335     int_placeholder = array_ops.placeholder(dtypes.int32)
    336     float_placeholder = array_ops.placeholder(dtypes.float32)
    337     string_placeholder = array_ops.placeholder(dtypes.string)
    338     input_dataset = dataset_ops.Dataset.from_tensors(
    339         (int_placeholder, float_placeholder, string_placeholder))
    340 
    341     # Test different ways of specifying the `padded_shapes` argument.
    342     dynamic_padding_from_tensor_shapes = input_dataset.padded_batch(
    343         32,
    344         padded_shapes=(tensor_shape.TensorShape([None]),
    345                        tensor_shape.TensorShape([None, None]),
    346                        tensor_shape.TensorShape([37])))
    347     dynamic_padding_from_lists = input_dataset.padded_batch(
    348         32, padded_shapes=([None], [None, None], [37]))
    349     dynamic_padding_from_lists_with_minus_one = input_dataset.padded_batch(
    350         32, padded_shapes=([-1], [-1, -1], [37]))
    351     dynamic_padding_from_tensors = input_dataset.padded_batch(
    352         32,
    353         padded_shapes=(constant_op.constant([-1], dtype=dtypes.int64),
    354                        constant_op.constant([-1, -1], dtype=dtypes.int64),
    355                        constant_op.constant([37], dtype=dtypes.int64)))
    356 
    357     for dataset in [
    358         dynamic_padding_from_tensor_shapes, dynamic_padding_from_lists,
    359         dynamic_padding_from_lists_with_minus_one, dynamic_padding_from_tensors
    360     ]:
    361       self.assertEqual([None, None], dataset.output_shapes[0].as_list())
    362       self.assertEqual([None, None, None], dataset.output_shapes[1].as_list())
    363       self.assertEqual([None, 37], dataset.output_shapes[2].as_list())
    364 
    365   def testPaddedBatchSparseError(self):
    366 
    367     def _map_fn(i):
    368       return sparse_tensor.SparseTensorValue(
    369           indices=[[0, 0]], values=(i * [1]), dense_shape=[1, 1]), i
    370 
    371     with self.assertRaises(TypeError):
    372       _ = dataset_ops.Dataset.range(10).map(_map_fn).padded_batch(10)
    373 
    374 
# Invoke the TensorFlow test entry point when this file is run as a script.
if __name__ == '__main__':
  test.main()
    377