# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TPU Estimator Signalling Tests."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
from tensorflow.python.tpu import tpu_estimator


def make_input_fn(num_samples):
  a = np.linspace(0, 100.0, num=num_samples)
  b = np.reshape(np.array(a, dtype=np.float32), (len(a), 1))

  def input_fn(params):
    batch_size = params['batch_size']
    da1 = dataset_ops.Dataset.from_tensor_slices(a)
    da2 = dataset_ops.Dataset.from_tensor_slices(b)

    dataset = dataset_ops.Dataset.zip((da1, da2))
    dataset = dataset.map(lambda fa, fb: {'a': fa, 'b': fb})
    dataset = dataset.batch(batch_size)
    return dataset
  return input_fn, (a, b)


def make_input_fn_with_labels(num_samples):
  a = np.linspace(0, 100.0, num=num_samples)
  b = np.reshape(np.array(a, dtype=np.float32), (len(a), 1))

  def input_fn(params):
    batch_size = params['batch_size']
    da1 = dataset_ops.Dataset.from_tensor_slices(a)
    da2 = dataset_ops.Dataset.from_tensor_slices(b)

    dataset = dataset_ops.Dataset.zip((da1, da2))
    dataset = dataset.map(lambda fa, fb: ({'a': fa}, fb))
    dataset = dataset.batch(batch_size)
    return dataset
  return input_fn, (a, b)

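# Note on the two helpers above: `make_input_fn` yields a plain features dict
# ({'a': ..., 'b': ...}) per element, while `make_input_fn_with_labels` yields
# an Estimator-style (features, labels) tuple, ({'a': ...}, label). In both
# cases `.batch(batch_size)` leaves the leading dimension dynamic (None); the
# padding wrapper exercised below is what makes the shapes static.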

class TPUEstimatorStoppingSignalsTest(test.TestCase):

  def test_normal_output_without_signals(self):
    num_samples = 4
    batch_size = 2

    params = {'batch_size': batch_size}
    input_fn, (a, b) = make_input_fn(num_samples=num_samples)

    with ops.Graph().as_default():
      dataset = input_fn(params)
      features = dataset_ops.make_one_shot_iterator(dataset).get_next()

      # With tf.data.Dataset.batch, the batch dimension is None, i.e., the
      # shape is dynamic.
      self.assertIsNone(features['a'].shape.as_list()[0])

      with session.Session() as sess:
        result = sess.run(features)
        self.assertAllEqual(a[:batch_size], result['a'])
        self.assertAllEqual(b[:batch_size], result['b'])

        # This second run should also work, as num_samples / batch_size = 2.
        result = sess.run(features)
        self.assertAllEqual(a[batch_size:num_samples], result['a'])
        self.assertAllEqual(b[batch_size:num_samples], result['b'])

        with self.assertRaises(errors.OutOfRangeError):
          # Given num_samples and batch_size, a third run should fail.
          sess.run(features)

  def test_output_with_stopping_signals(self):
    num_samples = 4
    batch_size = 2

    params = {'batch_size': batch_size}
    input_fn, (a, b) = make_input_fn(num_samples=num_samples)

    with ops.Graph().as_default():
      dataset = input_fn(params)
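      # _InputsWithStoppingSignals wraps the dataset: it exposes an explicit
      # initializer op and the features/labels, and adds a per-batch
      # 'stopping' signal that stays 0. while real data is flowing and flips
      # to 1. once the dataset is exhausted, as asserted below.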
      inputs = tpu_estimator._InputsWithStoppingSignals(dataset, batch_size)
      dataset_initializer = inputs.dataset_initializer()
      features, _ = inputs.features_and_labels()
      signals = inputs.signals()

      # With tf.data.Dataset.batch, the batch dimension is None, i.e., the
      # shape is dynamic.
      self.assertIsNone(features['a'].shape.as_list()[0])

      with session.Session() as sess:
        sess.run(dataset_initializer)

        result, evaluated_signals = sess.run([features, signals])
        self.assertAllEqual(a[:batch_size], result['a'])
        self.assertAllEqual(b[:batch_size], result['b'])
        self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])

        # This second run should also work, as num_samples / batch_size = 2.
        result, evaluated_signals = sess.run([features, signals])
        self.assertAllEqual(a[batch_size:num_samples], result['a'])
        self.assertAllEqual(b[batch_size:num_samples], result['b'])
        self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])

        # This run should still work, but should see STOP ('1') as the
        # stopping signal.
        _, evaluated_signals = sess.run([features, signals])
        self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping'])

        with self.assertRaises(errors.OutOfRangeError):
          sess.run(features)


class TPUEstimatorStoppingSignalsWithPaddingTest(test.TestCase):

  def test_num_samples_divisible_by_batch_size(self):
    num_samples = 4
    batch_size = 2

    params = {'batch_size': batch_size}
    input_fn, (a, b) = make_input_fn(num_samples=num_samples)

    with ops.Graph().as_default():
      dataset = input_fn(params)
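      # With add_padding=True, every emitted batch has the full, static
      # batch_size; any final partial batch is padded, and the signals gain a
      # per-example 'padding_mask' that is 0. for real examples and 1. for
      # padded ones.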
      inputs = tpu_estimator._InputsWithStoppingSignals(
          dataset, batch_size, add_padding=True)
      dataset_initializer = inputs.dataset_initializer()
      features, _ = inputs.features_and_labels()
      signals = inputs.signals()

      # With padding, all shapes are now static.
      self.assertEqual(batch_size, features['a'].shape.as_list()[0])

      with session.Session() as sess:
        sess.run(dataset_initializer)

        result, evaluated_signals = sess.run([features, signals])
        self.assertAllEqual(a[:batch_size], result['a'])
        self.assertAllEqual(b[:batch_size], result['b'])
        self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
        self.assertAllEqual([0.] * batch_size,
                            evaluated_signals['padding_mask'])

        # This second run should also work, as num_samples / batch_size = 2.
        result, evaluated_signals = sess.run([features, signals])
        self.assertAllEqual(a[batch_size:num_samples], result['a'])
        self.assertAllEqual(b[batch_size:num_samples], result['b'])
        self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
        self.assertAllEqual([0.] * batch_size,
                            evaluated_signals['padding_mask'])

        # This run should still work, but should see STOP ('1') as the
        # stopping signal.
        _, evaluated_signals = sess.run([features, signals])
        self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping'])

        with self.assertRaises(errors.OutOfRangeError):
          sess.run(features)

  def test_num_samples_not_divisible_by_batch_size(self):
    num_samples = 5
    batch_size = 2

    params = {'batch_size': batch_size}
    input_fn, (a, b) = make_input_fn_with_labels(num_samples=num_samples)
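    # This input_fn yields ({'a': ...}, label) tuples, so
    # features_and_labels() below also returns a labels tensor (the b column).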

    with ops.Graph().as_default():
      dataset = input_fn(params)
      inputs = tpu_estimator._InputsWithStoppingSignals(
          dataset, batch_size, add_padding=True)
      dataset_initializer = inputs.dataset_initializer()
      features, labels = inputs.features_and_labels()
      signals = inputs.signals()

      # With padding, all shapes are static.
      self.assertEqual(batch_size, features['a'].shape.as_list()[0])

      with session.Session() as sess:
        sess.run(dataset_initializer)

        evaluated_features, evaluated_labels, evaluated_signals = (
            sess.run([features, labels, signals]))
        self.assertAllEqual(a[:batch_size], evaluated_features['a'])
        self.assertAllEqual(b[:batch_size], evaluated_labels)
        self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
        self.assertAllEqual([0.] * batch_size,
                            evaluated_signals['padding_mask'])

        # This second run should also work, as num_samples / batch_size >= 2.
        evaluated_features, evaluated_labels, evaluated_signals = (
            sess.run([features, labels, signals]))
        self.assertAllEqual(a[batch_size:2*batch_size],
                            evaluated_features['a'])
        self.assertAllEqual(b[batch_size:2*batch_size], evaluated_labels)
        self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
        self.assertAllEqual([0.] * batch_size,
                            evaluated_signals['padding_mask'])

        # This is the final partial batch.
        evaluated_features, evaluated_labels, evaluated_signals = (
            sess.run([features, labels, signals]))
        real_batch_size = num_samples % batch_size
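        # With num_samples=5 and batch_size=2, the final batch holds
        # 5 % 2 = 1 real example plus 1 zero-padded example.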

        # Assert the real part.
        self.assertAllEqual(a[2*batch_size:num_samples],
                            evaluated_features['a'][:real_batch_size])
        self.assertAllEqual(b[2*batch_size:num_samples],
                            evaluated_labels[:real_batch_size])
        # Assert the padded part.
        self.assertAllEqual([0.0] * (batch_size - real_batch_size),
                            evaluated_features['a'][real_batch_size:])
        self.assertAllEqual([[0.0]] * (batch_size - real_batch_size),
                            evaluated_labels[real_batch_size:])

        self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])

        padding = ([0.] * real_batch_size
                   + [1.] * (batch_size - real_batch_size))
        self.assertAllEqual(padding, evaluated_signals['padding_mask'])

        # This run should still work, but should see STOP ('1') as the
        # stopping signal.
        _, evaluated_signals = sess.run([features, signals])
        self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping'])

        with self.assertRaises(errors.OutOfRangeError):
          sess.run(features)

  def test_slice(self):
    num_samples = 3
    batch_size = 2

    params = {'batch_size': batch_size}
    input_fn, (a, b) = make_input_fn(num_samples=num_samples)

    with ops.Graph().as_default():
      dataset = input_fn(params)
      inputs = tpu_estimator._InputsWithStoppingSignals(
          dataset, batch_size, add_padding=True)
      dataset_initializer = inputs.dataset_initializer()
      features, _ = inputs.features_and_labels()
      signals = inputs.signals()

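      # slice_tensor_or_dict uses the padding mask carried in `signals` to
      # strip the padded rows again, so the final partial batch comes back
      # with its true size (1 example here) rather than the padded batch_size.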
      sliced_features = (
          tpu_estimator._PaddingSignals.slice_tensor_or_dict(
              features, signals))

      with session.Session() as sess:
        sess.run(dataset_initializer)

        result, evaluated_signals = sess.run([sliced_features, signals])
        self.assertAllEqual(a[:batch_size], result['a'])
        self.assertAllEqual(b[:batch_size], result['b'])
        self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])

        # This is the final partial batch.
        result, evaluated_signals = sess.run([sliced_features, signals])
        self.assertEqual(1, len(result['a']))
        self.assertAllEqual(a[batch_size:num_samples], result['a'])
        self.assertAllEqual(b[batch_size:num_samples], result['b'])
        self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])

        # This run should still work, but should see STOP ('1') as the
        # stopping signal.
        _, evaluated_signals = sess.run([sliced_features, signals])
        self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping'])

        with self.assertRaises(errors.OutOfRangeError):
          sess.run(sliced_features)

  def test_slice_with_multi_invocations_per_step(self):
    num_samples = 3
    batch_size = 2

    params = {'batch_size': batch_size}
    input_fn, (a, b) = make_input_fn(num_samples=num_samples)

    with ops.Graph().as_default():
      dataset = input_fn(params)
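      # With num_invocations_per_step=2, all-padding batches (stopping signal
      # 1, padding mask 1) keep coming after the real data runs out; the test
      # expects three of them before the iterator raises OutOfRangeError.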
      inputs = tpu_estimator._InputsWithStoppingSignals(
          dataset, batch_size, add_padding=True, num_invocations_per_step=2)
      dataset_initializer = inputs.dataset_initializer()
      features, _ = inputs.features_and_labels()
      signals = inputs.signals()

      sliced_features = (
          tpu_estimator._PaddingSignals.slice_tensor_or_dict(features, signals))

      with session.Session() as sess:
        sess.run(dataset_initializer)

        result, evaluated_signals = sess.run([sliced_features, signals])
        self.assertAllEqual(a[:batch_size], result['a'])
        self.assertAllEqual(b[:batch_size], result['b'])
        self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])

        # This is the final partial batch.
        result, evaluated_signals = sess.run([sliced_features, signals])
        self.assertEqual(1, len(result['a']))
        self.assertAllEqual(a[batch_size:num_samples], result['a'])
        self.assertAllEqual(b[batch_size:num_samples], result['b'])
        self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])

        # We should see 3 consecutive batches with STOP ('1') as the stopping
        # signal, all of which have padding mask 1.
        _, evaluated_signals = sess.run([sliced_features, signals])
        self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping'])
        self.assertAllEqual([1.] * batch_size,
                            evaluated_signals['padding_mask'])

        _, evaluated_signals = sess.run([sliced_features, signals])
        self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping'])
        self.assertAllEqual([1.] * batch_size,
                            evaluated_signals['padding_mask'])

        _, evaluated_signals = sess.run([sliced_features, signals])
        self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping'])
        self.assertAllEqual([1.] * batch_size,
                            evaluated_signals['padding_mask'])

        with self.assertRaises(errors.OutOfRangeError):
          sess.run(sliced_features)


if __name__ == '__main__':
  test.main()