# -*- coding: utf-8 -*-
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the experimental input pipeline ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

import numpy as np

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test
from tensorflow.python.util import compat


class BatchDatasetTest(test.TestCase):
  """Tests for `Dataset.batch()` and `Dataset.padded_batch()`."""

  def testBatchDataset(self):
    """Tests batching a mapped and repeated dataset of dense tensors.

    The same iterator is re-initialized with different `count` and
    `batch_size` feeds to cover: a batch size that divides the input, one
    that does not, an empty input, and a zero batch size.
    """
    # The pipeline is TensorSliceDataset -> MapDataset(_map_fn) ->
    # RepeatDataset(count) -> BatchDataset(batch_size).
    components = (np.arange(7),
                  np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
                  np.array(37.0) * np.arange(7))

    count = array_ops.placeholder(dtypes.int64, shape=[])
    batch_size = array_ops.placeholder(dtypes.int64, shape=[])

    def _map_fn(x, y, z):
      return math_ops.square(x), math_ops.square(y), math_ops.square(z)

    iterator = (
        dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn)
        .repeat(count).batch(batch_size).make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    # The batch dimension is fed at run time, so it is statically unknown.
    self.assertEqual([[None] + list(c.shape[1:]) for c in components],
                     [t.shape.as_list() for t in get_next])

    with self.test_session() as sess:
      # Batch of a finite input, where the batch_size divides the
      # total number of elements.
      sess.run(init_op, feed_dict={count: 28, batch_size: 14})
      num_batches = (28 * 7) // 14
      for i in range(num_batches):
        result = sess.run(get_next)
        for component, result_component in zip(components, result):
          for j in range(14):
            self.assertAllEqual(component[(i * 14 + j) % 7]**2,
                                result_component[j])
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      # Batch of a finite input, where the batch_size does not
      # divide the total number of elements.
      sess.run(init_op, feed_dict={count: 14, batch_size: 8})

      # We expect (num_batches - 1) full-sized batches.
      num_batches = int(math.ceil((14 * 7) / 8))
      for i in range(num_batches - 1):
        result = sess.run(get_next)
        for component, result_component in zip(components, result):
          for j in range(8):
            self.assertAllEqual(component[(i * 8 + j) % 7]**2,
                                result_component[j])
      # The final batch holds only the remaining (14 * 7) % 8 elements.
      result = sess.run(get_next)
      for component, result_component in zip(components, result):
        for j in range((14 * 7) % 8):
          self.assertAllEqual(component[((num_batches - 1) * 8 + j) % 7]**2,
                              result_component[j])
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      # Batch of an empty input should fail straight away.
      sess.run(init_op, feed_dict={count: 0, batch_size: 8})
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      # Empty batch should be an initialization time error.
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(init_op, feed_dict={count: 14, batch_size: 0})

  def assertSparseValuesEqual(self, a, b):
    """Asserts that two sparse values match field by field.

    Args:
      a: An object with `indices`, `values` and `dense_shape` attributes.
      b: The object to compare against, with the same three attributes.
    """
    self.assertAllEqual(a.indices, b.indices)
    self.assertAllEqual(a.values, b.values)
    self.assertAllEqual(a.dense_shape, b.dense_shape)

  def testBatchSparse(self):
    """Tests batching sparse elements that all have dense shape [1]."""

    def _sparse(i):
      return sparse_tensor.SparseTensorValue(
          indices=[[0]], values=(i * [1]), dense_shape=[1])

    iterator = dataset_ops.Dataset.range(10).map(_sparse).batch(
        5).make_initializable_iterator()
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      for i in range(2):
        actual = sess.run(get_next)
        expected = sparse_tensor.SparseTensorValue(
            indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
            values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4],
            dense_shape=[5, 1])
        self.assertTrue(sparse_tensor.is_sparse(actual))
        self.assertSparseValuesEqual(actual, expected)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testBatchSparseWithDifferentDenseShapes(self):
    """Tests batching sparse elements whose dense shapes differ.

    Element `i` has dense shape [i] with `i` entries, all equal to `i`.
    """

    def _sparse(i):
      return sparse_tensor.SparseTensorValue(
          indices=array_ops.expand_dims(
              math_ops.range(i, dtype=dtypes.int64), 1),
          values=array_ops.fill([math_ops.to_int32(i)], i),
          dense_shape=[i])

    iterator = dataset_ops.Dataset.range(10).map(_sparse).batch(
        5).make_initializable_iterator()
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      for i in range(2):
        actual = sess.run(get_next)
        expected_indices = []
        expected_values = []
        for j in range(5):
          for k in range(i * 5 + j):
            expected_indices.append([j, k])
            expected_values.append(i * 5 + j)
        # The widest element in batch `i` is element (i + 1) * 5 - 1.
        expected = sparse_tensor.SparseTensorValue(
            indices=expected_indices,
            values=expected_values,
            dense_shape=[5, (i + 1) * 5 - 1])
        self.assertTrue(sparse_tensor.is_sparse(actual))
        self.assertSparseValuesEqual(actual, expected)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testNestedBatchSparse(self):
    """Tests applying `batch()` twice to a dataset of sparse elements."""

    def _sparse(i):
      return sparse_tensor.SparseTensorValue(
          indices=[[0]], values=(i * [1]), dense_shape=[1])

    iterator = dataset_ops.Dataset.range(10).map(_sparse).batch(5).batch(
        2).make_initializable_iterator()
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      actual = sess.run(get_next)
      expected = sparse_tensor.SparseTensorValue(
          indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [0, 4, 0],
                   [1, 0, 0], [1, 1, 0], [1, 2, 0], [1, 3, 0], [1, 4, 0]],
          values=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
          dense_shape=[2, 5, 1])
      self.assertTrue(sparse_tensor.is_sparse(actual))
      self.assertSparseValuesEqual(actual, expected)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testBatchShapeError(self):
    """Tests the error raised when batched elements have different shapes."""

    def generator():
      yield [1.0, 2.0, 3.0]
      yield [4.0, 5.0, 6.0]
      yield [7.0, 8.0, 9.0, 10.0]

    iterator = (
        dataset_ops.Dataset.from_generator(
            generator, dtypes.float32, output_shapes=[None]).batch(3)
        .make_initializable_iterator())
    next_element = iterator.get_next()

    with self.test_session() as sess:
      sess.run(iterator.initializer)
      with self.assertRaisesRegexp(
          errors.InvalidArgumentError,
          r'Cannot batch tensors with different shapes in component 0. '
          r'First element had shape \[3\] and element 2 had shape \[4\].'):
        sess.run(next_element)

  def testPaddedBatchDataset(self):
    """Tests `padded_batch()` with dynamic, constant and too-short padding.

    Each input length `x` is mapped to a vector of `x` copies of `x`, so an
    unpadded row is fully determined by its length.
    """
    seq_lens = array_ops.placeholder(dtypes.int32, shape=[None])
    padded_shape = array_ops.placeholder(dtypes.int64, shape=[1])

    iterator = (
        dataset_ops.Dataset.from_tensor_slices(seq_lens)
        .map(lambda x: array_ops.fill([x], x)).padded_batch(
            4, padded_shapes=padded_shape).make_initializable_iterator())

    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      # Test with random sequence lengths, and max padding.
      random_seq_lens = np.random.randint(20, size=(32,)).astype(np.int32)
      sess.run(
          init_op, feed_dict={
              padded_shape: [-1],
              seq_lens: random_seq_lens
          })
      for i in range(8):
        result = sess.run(get_next)
        # Take the expected width from the input lengths, not from
        # np.max(result): if all four lengths in a batch are zero the
        # result is a (4, 0) array and np.max raises ValueError on it.
        padded_len = np.max(random_seq_lens[i * 4:(i + 1) * 4])
        self.assertEqual((4, padded_len), result.shape)
        for j in range(4):
          seq_len = random_seq_lens[(i * 4) + j]
          self.assertAllEqual(result[j, :seq_len], [seq_len] * seq_len)
          self.assertAllEqual(result[j, seq_len:], [0] * (padded_len - seq_len))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      # Test with random sequence lengths, and constant padding.
      sess.run(
          init_op, feed_dict={
              padded_shape: [25],
              seq_lens: random_seq_lens
          })
      for i in range(8):
        result = sess.run(get_next)
        self.assertEqual((4, 25), result.shape)
        for j in range(4):
          seq_len = random_seq_lens[(i * 4) + j]
          self.assertAllEqual(result[j, :seq_len], [seq_len] * seq_len)
          self.assertAllEqual(result[j, seq_len:], [0] * (25 - seq_len))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      # Test correct handling of empty tensors.
      sess.run(init_op, feed_dict={padded_shape: [-1], seq_lens: [0, 0, 0, 0]})
      result = sess.run(get_next)
      self.assertAllEqual([[], [], [], []], result)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      # Test error handling with constant sequence lengths, and
      # too-short padding.
      sess.run(init_op, feed_dict={padded_shape: [5], seq_lens: [6, 5, 5, 5]})
      with self.assertRaises(errors.DataLossError):
        sess.run(get_next)

  def testPaddedBatchDatasetNonDefaultPadding(self):
    """Tests `padded_batch()` with explicit `padding_values`.

    Each element is a pair (int vector, string vector); the components are
    padded with -1 and '<end>' respectively.
    """
    seq_lens = array_ops.placeholder(dtypes.int32, shape=[None])
    padded_shape = array_ops.placeholder(dtypes.int64, shape=[1])

    def fill_tuple(x):
      filled = array_ops.fill([x], x)
      return (filled, string_ops.as_string(filled))

    iterator = (
        dataset_ops.Dataset.from_tensor_slices(seq_lens).map(fill_tuple)
        .padded_batch(
            4,
            padded_shapes=(padded_shape, padded_shape),
            padding_values=(-1, '<end>')).make_initializable_iterator())

    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      # Test with random sequence lengths, and max padding.
      random_seq_lens = np.random.randint(20, size=(32,)).astype(np.int32)
      sess.run(
          init_op, feed_dict={
              padded_shape: [-1],
              seq_lens: random_seq_lens
          })
      for i in range(8):
        result = sess.run(get_next)
        # Take the expected width from the input lengths, not from
        # np.max(result[0]): if all four lengths in a batch are zero the
        # component is a (4, 0) array and np.max raises ValueError on it.
        padded_len = np.max(random_seq_lens[i * 4:(i + 1) * 4])
        self.assertEqual((4, padded_len), result[0].shape)
        self.assertEqual((4, padded_len), result[1].shape)
        for j in range(4):
          seq_len = random_seq_lens[(i * 4) + j]
          self.assertAllEqual(result[0][j, :seq_len], [seq_len] * seq_len)
          self.assertAllEqual(result[0][j, seq_len:],
                              [-1] * (padded_len - seq_len))
          self.assertAllEqual(result[1][j, :seq_len],
                              [compat.as_bytes(str(seq_len))] * seq_len)
          self.assertAllEqual(result[1][j, seq_len:],
                              [b'<end>'] * (padded_len - seq_len))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testPaddedBatchDatasetUnicode(self):
    """Tests that a padded batch of generator-produced strings can be run."""
    # See GitHub issue 16149
    def generator():
      # NOTE(review): every string literal below is empty, which looks like
      # non-ASCII text lost in transit -- confirm against the original test,
      # since a Unicode regression test presumably needs non-ASCII data.
      data = [[u'', u'', u''],
              [u'', u'', u'', u'']]

      for seq in data:
        yield seq, [0, 1, 2, 3]

    dataset = dataset_ops.Dataset.from_generator(
        generator, (dtypes.string, dtypes.int32),
        (tensor_shape.TensorShape([None]), tensor_shape.TensorShape([None])))
    padded_dataset = dataset.padded_batch(
        2, padded_shapes=([None], [None]), padding_values=('', 0))
    with self.test_session() as sess:
      next_element = padded_dataset.make_one_shot_iterator().get_next()
      sess.run(next_element)

  def testPaddedBatchDatasetShapeSpecifications(self):
    """Tests that equivalent `padded_shapes` forms give equal static shapes.

    The same shapes are given as `TensorShape` objects, lists with `None`,
    lists with -1, and int64 constant tensors.
    """
    int_placeholder = array_ops.placeholder(dtypes.int32)
    float_placeholder = array_ops.placeholder(dtypes.float32)
    string_placeholder = array_ops.placeholder(dtypes.string)
    input_dataset = dataset_ops.Dataset.from_tensors(
        (int_placeholder, float_placeholder, string_placeholder))

    # Test different ways of specifying the `padded_shapes` argument.
    dynamic_padding_from_tensor_shapes = input_dataset.padded_batch(
        32,
        padded_shapes=(tensor_shape.TensorShape([None]),
                       tensor_shape.TensorShape([None, None]),
                       tensor_shape.TensorShape([37])))
    dynamic_padding_from_lists = input_dataset.padded_batch(
        32, padded_shapes=([None], [None, None], [37]))
    dynamic_padding_from_lists_with_minus_one = input_dataset.padded_batch(
        32, padded_shapes=([-1], [-1, -1], [37]))
    dynamic_padding_from_tensors = input_dataset.padded_batch(
        32,
        padded_shapes=(constant_op.constant([-1], dtype=dtypes.int64),
                       constant_op.constant([-1, -1], dtype=dtypes.int64),
                       constant_op.constant([37], dtype=dtypes.int64)))

    for dataset in [
        dynamic_padding_from_tensor_shapes, dynamic_padding_from_lists,
        dynamic_padding_from_lists_with_minus_one, dynamic_padding_from_tensors
    ]:
      self.assertEqual([None, None], dataset.output_shapes[0].as_list())
      self.assertEqual([None, None, None], dataset.output_shapes[1].as_list())
      self.assertEqual([None, 37], dataset.output_shapes[2].as_list())

  def testPaddedBatchSparseError(self):
    """Tests that `padded_batch()` on a sparse component raises TypeError."""

    def _map_fn(i):
      return sparse_tensor.SparseTensorValue(
          indices=[[0, 0]], values=(i * [1]), dense_shape=[1, 1]), i

    # NOTE(review): unlike every other call in this file, padded_batch is
    # invoked here without `padded_shapes`; confirm the TypeError comes from
    # the sparse component rather than from the omitted argument.
    with self.assertRaises(TypeError):
      _ = dataset_ops.Dataset.range(10).map(_map_fn).padded_batch(10)


if __name__ == '__main__':
  test.main()