# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TPU Estimator Signalling Tests."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
from tensorflow.python.tpu import tpu_estimator


def make_input_fn(num_samples):
  a = np.linspace(0, 100.0, num=num_samples)
  b = np.reshape(np.array(a, dtype=np.float32), (len(a), 1))

  def input_fn(params):
    batch_size = params['batch_size']
    da1 = dataset_ops.Dataset.from_tensor_slices(a)
    da2 = dataset_ops.Dataset.from_tensor_slices(b)

    dataset = dataset_ops.Dataset.zip((da1, da2))
    dataset = dataset.map(lambda fa, fb: {'a': fa, 'b': fb})
    dataset = dataset.batch(batch_size)
    return dataset
  return input_fn, (a, b)


def make_input_fn_with_labels(num_samples):
  a = np.linspace(0, 100.0, num=num_samples)
  b = np.reshape(np.array(a, dtype=np.float32), (len(a), 1))

  def input_fn(params):
    batch_size = params['batch_size']
    da1 = dataset_ops.Dataset.from_tensor_slices(a)
    da2 = dataset_ops.Dataset.from_tensor_slices(b)

    dataset = dataset_ops.Dataset.zip((da1, da2))
    dataset = dataset.map(lambda fa, fb: ({'a': fa}, fb))
    dataset = dataset.batch(batch_size)
    return dataset
  return input_fn, (a, b)


class TPUEstimatorStoppingSignalsTest(test.TestCase):

  def test_normal_output_without_signals(self):
    num_samples = 4
    batch_size = 2

    params = {'batch_size': batch_size}
    input_fn, (a, b) = make_input_fn(num_samples=num_samples)

    with ops.Graph().as_default():
      dataset = input_fn(params)
      features = dataset_ops.make_one_shot_iterator(dataset).get_next()

      # With tf.data.Dataset.batch, the batch is None, i.e., dynamic shape.
      self.assertIsNone(features['a'].shape.as_list()[0])

      with session.Session() as sess:
        result = sess.run(features)
        self.assertAllEqual(a[:batch_size], result['a'])
        self.assertAllEqual(b[:batch_size], result['b'])

        # This run should work as num_samples / batch_size = 2.
        result = sess.run(features)
        self.assertAllEqual(a[batch_size:num_samples], result['a'])
        self.assertAllEqual(b[batch_size:num_samples], result['b'])

        with self.assertRaises(errors.OutOfRangeError):
          # Given num_samples and batch_size, this run should fail.
          sess.run(features)

  def test_output_with_stopping_signals(self):
    num_samples = 4
    batch_size = 2

    params = {'batch_size': batch_size}
    input_fn, (a, b) = make_input_fn(num_samples=num_samples)

    with ops.Graph().as_default():
      dataset = input_fn(params)
      inputs = tpu_estimator._InputsWithStoppingSignals(dataset, batch_size)
      dataset_initializer = inputs.dataset_initializer()
      features, _ = inputs.features_and_labels()
      signals = inputs.signals()

      # With tf.data.Dataset.batch, the batch is None, i.e., dynamic shape.
      self.assertIsNone(features['a'].shape.as_list()[0])

      with session.Session() as sess:
        sess.run(dataset_initializer)

        result, evaluated_signals = sess.run([features, signals])
        self.assertAllEqual(a[:batch_size], result['a'])
        self.assertAllEqual(b[:batch_size], result['b'])
        self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])

        # This run should work as num_samples / batch_size = 2.
        result, evaluated_signals = sess.run([features, signals])
        self.assertAllEqual(a[batch_size:num_samples], result['a'])
        self.assertAllEqual(b[batch_size:num_samples], result['b'])
        self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])

        # This run should work, *but* see STOP ('1') as signals.
        _, evaluated_signals = sess.run([features, signals])
        self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping'])

        with self.assertRaises(errors.OutOfRangeError):
          sess.run(features)


class TPUEstimatorStoppingSignalsWithPaddingTest(test.TestCase):

  def test_num_samples_divisible_by_batch_size(self):
    num_samples = 4
    batch_size = 2

    params = {'batch_size': batch_size}
    input_fn, (a, b) = make_input_fn(num_samples=num_samples)

    with ops.Graph().as_default():
      dataset = input_fn(params)
      inputs = tpu_estimator._InputsWithStoppingSignals(dataset, batch_size,
                                                        add_padding=True)
      dataset_initializer = inputs.dataset_initializer()
      features, _ = inputs.features_and_labels()
      signals = inputs.signals()

      # With padding, all shapes are static now.
      self.assertEqual(batch_size, features['a'].shape.as_list()[0])

      with session.Session() as sess:
        sess.run(dataset_initializer)

        result, evaluated_signals = sess.run([features, signals])
        self.assertAllEqual(a[:batch_size], result['a'])
        self.assertAllEqual(b[:batch_size], result['b'])
        self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
        self.assertAllEqual([0.] * batch_size,
                            evaluated_signals['padding_mask'])

        # This run should work as num_samples / batch_size = 2.
        result, evaluated_signals = sess.run([features, signals])
        self.assertAllEqual(a[batch_size:num_samples], result['a'])
        self.assertAllEqual(b[batch_size:num_samples], result['b'])
        self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
        self.assertAllEqual([0.] * batch_size,
                            evaluated_signals['padding_mask'])

        # This run should work, *but* see STOP ('1') as signals.
        _, evaluated_signals = sess.run([features, signals])
        self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping'])

        with self.assertRaises(errors.OutOfRangeError):
          sess.run(features)

  def test_num_samples_not_divisible_by_batch_size(self):
    num_samples = 5
    batch_size = 2

    params = {'batch_size': batch_size}
    input_fn, (a, b) = make_input_fn_with_labels(num_samples=num_samples)

    with ops.Graph().as_default():
      dataset = input_fn(params)
      inputs = tpu_estimator._InputsWithStoppingSignals(dataset, batch_size,
                                                        add_padding=True)
      dataset_initializer = inputs.dataset_initializer()
      features, labels = inputs.features_and_labels()
      signals = inputs.signals()

      # With padding, all shapes are static.
      self.assertEqual(batch_size, features['a'].shape.as_list()[0])

      with session.Session() as sess:
        sess.run(dataset_initializer)

        evaluated_features, evaluated_labels, evaluated_signals = (
            sess.run([features, labels, signals]))
        self.assertAllEqual(a[:batch_size], evaluated_features['a'])
        self.assertAllEqual(b[:batch_size], evaluated_labels)
        self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
        self.assertAllEqual([0.] * batch_size,
                            evaluated_signals['padding_mask'])

        # This run should work as num_samples / batch_size >= 2.
        evaluated_features, evaluated_labels, evaluated_signals = (
            sess.run([features, labels, signals]))
        self.assertAllEqual(a[batch_size:2*batch_size], evaluated_features['a'])
        self.assertAllEqual(b[batch_size:2*batch_size], evaluated_labels)
        self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
        self.assertAllEqual([0.] * batch_size,
                            evaluated_signals['padding_mask'])

        # This is the final partial batch.
        evaluated_features, evaluated_labels, evaluated_signals = (
            sess.run([features, labels, signals]))
        real_batch_size = num_samples % batch_size

        # Assert the real part.
        self.assertAllEqual(a[2*batch_size:num_samples],
                            evaluated_features['a'][:real_batch_size])
        self.assertAllEqual(b[2*batch_size:num_samples],
                            evaluated_labels[:real_batch_size])
        # Assert the padded part.
        self.assertAllEqual([0.0] * (batch_size - real_batch_size),
                            evaluated_features['a'][real_batch_size:])
        self.assertAllEqual([[0.0]] * (batch_size - real_batch_size),
                            evaluated_labels[real_batch_size:])

        self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])

        padding = ([.0] * real_batch_size
                   + [1.] * (batch_size - real_batch_size))
        self.assertAllEqual(padding, evaluated_signals['padding_mask'])

        # This run should work, *but* see STOP ('1') as signals.
        _, evaluated_signals = sess.run([features, signals])
        self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping'])

        with self.assertRaises(errors.OutOfRangeError):
          sess.run(features)

  def test_slice(self):
    num_samples = 3
    batch_size = 2

    params = {'batch_size': batch_size}
    input_fn, (a, b) = make_input_fn(num_samples=num_samples)

    with ops.Graph().as_default():
      dataset = input_fn(params)
      inputs = tpu_estimator._InputsWithStoppingSignals(dataset, batch_size,
                                                        add_padding=True)
      dataset_initializer = inputs.dataset_initializer()
      features, _ = inputs.features_and_labels()
      signals = inputs.signals()

      sliced_features = (
          tpu_estimator._PaddingSignals.slice_tensor_or_dict(
              features, signals))

      with session.Session() as sess:
        sess.run(dataset_initializer)

        result, evaluated_signals = sess.run([sliced_features, signals])
        self.assertAllEqual(a[:batch_size], result['a'])
        self.assertAllEqual(b[:batch_size], result['b'])
        self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])

        # This is the final partial batch.
        result, evaluated_signals = sess.run([sliced_features, signals])
        self.assertEqual(1, len(result['a']))
        self.assertAllEqual(a[batch_size:num_samples], result['a'])
        self.assertAllEqual(b[batch_size:num_samples], result['b'])
        self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])

        # This run should work, *but* see STOP ('1') as signals.
        _, evaluated_signals = sess.run([sliced_features, signals])
        self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping'])

        with self.assertRaises(errors.OutOfRangeError):
          sess.run(sliced_features)

  def test_slice_with_multi_invocations_per_step(self):
    num_samples = 3
    batch_size = 2

    params = {'batch_size': batch_size}
    input_fn, (a, b) = make_input_fn(num_samples=num_samples)

    with ops.Graph().as_default():
      dataset = input_fn(params)
      inputs = tpu_estimator._InputsWithStoppingSignals(
          dataset, batch_size, add_padding=True, num_invocations_per_step=2)
      dataset_initializer = inputs.dataset_initializer()
      features, _ = inputs.features_and_labels()
      signals = inputs.signals()

      sliced_features = (
          tpu_estimator._PaddingSignals.slice_tensor_or_dict(features, signals))

      with session.Session() as sess:
        sess.run(dataset_initializer)

        result, evaluated_signals = sess.run([sliced_features, signals])
        self.assertAllEqual(a[:batch_size], result['a'])
        self.assertAllEqual(b[:batch_size], result['b'])
        self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])

        # This is the final partial batch.
        result, evaluated_signals = sess.run([sliced_features, signals])
        self.assertEqual(1, len(result['a']))
        self.assertAllEqual(a[batch_size:num_samples], result['a'])
        self.assertAllEqual(b[batch_size:num_samples], result['b'])
        self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])

        # We should see 3 continuous batches with STOP ('1') as signals and
        # all of them have mask 1.
        _, evaluated_signals = sess.run([sliced_features, signals])
        self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping'])
        self.assertAllEqual([1.] * batch_size,
                            evaluated_signals['padding_mask'])

        _, evaluated_signals = sess.run([sliced_features, signals])
        self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping'])
        self.assertAllEqual([1.] * batch_size,
                            evaluated_signals['padding_mask'])

        _, evaluated_signals = sess.run([sliced_features, signals])
        self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping'])
        self.assertAllEqual([1.] * batch_size,
                            evaluated_signals['padding_mask'])

        with self.assertRaises(errors.OutOfRangeError):
          sess.run(sliced_features)


if __name__ == '__main__':
  test.main()