# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for training.input."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools
import os

import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test as test_lib
from tensorflow.python.platform import tf_logging
from tensorflow.python.training import coordinator
from tensorflow.python.training import input as inp
from tensorflow.python.training import queue_runner_impl
from tensorflow.python.util import compat


class MatchFilenamesOnceTest(test_lib.TestCase):

  def test(self):
    temp_dir = self.get_temp_dir()
    filenames = [os.path.join(temp_dir, n) for n in os.listdir(temp_dir)]
    additional = [
        os.path.join(self.get_temp_dir(), "match_filenames.%d" % i)
        for i in range(3)
    ]
    for name in additional:
      with open(name, "w") as f:
        f.write("Some contents")
    filenames = list(set(filenames + additional))
    with self.test_session():
      star = inp.match_filenames_once(os.path.join(self.get_temp_dir(), "*"))
      question = inp.match_filenames_once(
          os.path.join(self.get_temp_dir(), "match_filenames.?"))
      one = inp.match_filenames_once(additional[1])
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      self.assertItemsEqual(map(compat.as_bytes, filenames), star.eval())
      self.assertItemsEqual(map(compat.as_bytes, additional), question.eval())
      self.assertItemsEqual([compat.as_bytes(additional[1])], one.eval())


class LimitEpochsTest(test_lib.TestCase):

  def testNoLimit(self):
    with self.test_session():
      seven = constant_op.constant(7)
      seven_forever = inp.limit_epochs(seven)
      variables.local_variables_initializer().run()
      for _ in range(100):
        self.assertEqual(7, seven_forever.eval())

  def testLimit(self):
    with self.test_session():
      love_me = constant_op.constant("Love Me")
      love_me_two_times = inp.limit_epochs(love_me, num_epochs=2)
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      self.assertEqual(b"Love Me", love_me_two_times.eval())
      self.assertEqual(b"Love Me", love_me_two_times.eval())
      with self.assertRaises(errors_impl.OutOfRangeError):
        love_me_two_times.eval()


class InputProducerTest(test_lib.TestCase):

  def testNoShuffle(self):
    with self.test_session():
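      # input_producer slices along dimension 0, so each row of this
      # matrix is one element produced by the queue.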
      input_tensor = [[1, 2, 3, 4],
                      [5, 6, 7, 8],
                      [9, 10, 11, 12]]
      num_epochs = 2
      queue = inp.input_producer(
          input_tensor, num_epochs=num_epochs, shuffle=False)
      dequeue_many = queue.dequeue_many(len(input_tensor) * num_epochs)
      dequeue = queue.dequeue()
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      threads = queue_runner_impl.start_queue_runners()

      # No randomness, so just see repeated copies of the input.
      self.assertAllEqual(input_tensor * num_epochs, dequeue_many.eval())

      # Reached the limit.
      with self.assertRaises(errors_impl.OutOfRangeError):
        dequeue.eval()
      for thread in threads:
        thread.join()

  def testNoShapeInference(self):
    with self.test_session():
      # Disable shape inference for the input.
      input_value = [[1, 2, 3, 4],
                     [5, 6, 7, 8],
                     [9, 10, 11, 12]]
      input_tensor = array_ops.placeholder_with_default(input_value, shape=None)
      num_epochs = 2
      queue = inp.input_producer(
          input_tensor, element_shape=[4], num_epochs=num_epochs, shuffle=False)
      dequeue_many = queue.dequeue_many(len(input_value) * num_epochs)
      dequeue = queue.dequeue()
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      threads = queue_runner_impl.start_queue_runners()

      # No randomness, so just see repeated copies of the input.
      self.assertAllEqual(input_value * num_epochs, dequeue_many.eval())

      # Reached the limit.
      with self.assertRaises(errors_impl.OutOfRangeError):
        dequeue.eval()
      for thread in threads:
        thread.join()

  def testShapeError(self):
    input_tensor = array_ops.placeholder(dtypes.float32, None)
    with self.assertRaisesRegexp(ValueError, "fully defined shape"):
      _ = inp.input_producer(input_tensor)


class StringInputProducerTest(test_lib.TestCase):

  def testNoShuffle(self):
    with self.test_session():
      strings = [b"to", b"be", b"or", b"not", b"to", b"be"]
      num_epochs = 3
      queue = inp.string_input_producer(
          strings, num_epochs=num_epochs, shuffle=False)
      dequeue_many = queue.dequeue_many(len(strings) * num_epochs)
      dequeue = queue.dequeue()
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      threads = queue_runner_impl.start_queue_runners()

      # No randomness, so just see repeated copies of the input.
      output = dequeue_many.eval()
      self.assertAllEqual(strings * num_epochs, output)

      # Reached the limit.
      with self.assertRaises(errors_impl.OutOfRangeError):
        dequeue.eval()
      for thread in threads:
        thread.join()

  def testShuffle(self):
    with self.test_session():
      strings = [b"a", b"b", b"c"]
      num_epochs = 600
      queue = inp.string_input_producer(
          strings, num_epochs=num_epochs, shuffle=True, seed=271828)
      dequeue_many = queue.dequeue_many(len(strings))
      dequeue = queue.dequeue()
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      threads = queue_runner_impl.start_queue_runners()

      # Validate that we only shuffle the strings within an epoch and
      # count how often each possible order appears.
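      # With 3 strings there are 3! = 6 possible orders per epoch.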
      expected = [b"abc", b"acb", b"bac", b"bca", b"cab", b"cba"]
      frequency = {}
      for e in expected:
        frequency[e] = 0
      for _ in range(num_epochs):
        output = dequeue_many.eval()
        key = b"".join(output)
        self.assertIn(key, expected)
        frequency[key] += 1

      # Expect an approximately even distribution over all possible orders.
      expected_frequency = num_epochs / len(expected)
      margin = expected_frequency * 0.4
      tf_logging.info("Observed counts: %s", frequency)
      for key in expected:
        value = frequency[key]
        self.assertGreater(value, expected_frequency - margin)
        self.assertLess(value, expected_frequency + margin)

      # Reached the limit.
      with self.assertRaises(errors_impl.OutOfRangeError):
        dequeue.eval()
      for thread in threads:
        thread.join()

  def testNullStringPython(self):
    # Graph-construction time check for empty string list:
    with self.test_session():
      with self.assertRaises(ValueError):
        _ = inp.string_input_producer([])

  def testNullString(self):
    # Runtime check for empty string list. This is slightly oblique:
    # The queue runner should die with an assertion error on the null
    # input tensor, causing the dequeue to fail with an OutOfRangeError.
    with self.test_session():
      coord = coordinator.Coordinator()
      queue = inp.string_input_producer(
          constant_op.constant(
              [], dtype=dtypes.string))
      dequeue = queue.dequeue()
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      threads = queue_runner_impl.start_queue_runners(coord=coord)
      with self.assertRaises(errors_impl.OutOfRangeError):
        dequeue.eval()
      coord.request_stop()
      for thread in threads:
        thread.join()

  def testSharedName(self):
    with self.test_session():
      strings = [b"to", b"be", b"or", b"not", b"to", b"be"]
      queue = inp.string_input_producer(
          strings, shared_name="SHARED_NAME_XYZ", name="Q")
      self.assertProtoEquals("s: 'SHARED_NAME_XYZ'",
                             queue.queue_ref.op.node_def.attr["shared_name"])

  def testConstructionRace(self):
    with self.test_session() as sess:
      strings = [b"to", b"be", b"or", b"not", b"to", b"be"]
      queue = inp.string_input_producer(strings, shuffle=False)
      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord)
      for _ in range(2):
        for string in strings:
          # NOTE(mrry): This is not the recommended way to write
          # dequeuing code (instead you should create a single dequeue
          # op before starting the queue runners, and run it
          # repeatedly), because it leads to concurrent reading and
          # writing of the `tf.Graph` object. However, many users
          # write code this way, so we include this test to ensure
          # that we can support it.
          self.assertEqual(string, sess.run(queue.dequeue()))
      coord.request_stop()
      coord.join(threads)


class RangeInputProducerTest(test_lib.TestCase):

  def testNoShuffle(self):
    with self.test_session():
      num_epochs = 3
      range_size = 5
      queue = inp.range_input_producer(
          range_size, num_epochs=num_epochs, shuffle=False)
      dequeue_many = queue.dequeue_many(range_size * num_epochs)
      dequeue = queue.dequeue()
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      threads = queue_runner_impl.start_queue_runners()

      # No randomness, so just see repeated copies of the input.
      output = dequeue_many.eval()
      self.assertAllEqual(list(xrange(range_size)) * num_epochs, output)

      # Reached the limit.
      with self.assertRaises(errors_impl.OutOfRangeError):
        dequeue.eval()
      for thread in threads:
        thread.join()

  def testShuffle(self):
    with self.test_session():
      num_epochs = 200
      range_size = 2
      queue = inp.range_input_producer(
          range_size, num_epochs=num_epochs, shuffle=True, seed=314159)
      dequeue_many = queue.dequeue_many(range_size)
      dequeue = queue.dequeue()
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      threads = queue_runner_impl.start_queue_runners()

      # Validate that we only shuffle the integers within an epoch and
      # count how often each possible order appears.
      expected = [12, 21]
      frequency = {}
      for e in expected:
        frequency[e] = 0
      for _ in range(num_epochs):
        output = dequeue_many.eval()
        key = 10 * (output[0] + 1) + (output[1] + 1)
        self.assertIn(key, expected)
        frequency[key] += 1

      # Expect an approximately even distribution over all possible orders.
      expected_frequency = num_epochs / len(expected)
      margin = expected_frequency * 0.4
      tf_logging.info("Observed counts: %s", frequency)
      for key in expected:
        value = frequency[key]
        self.assertGreater(value, expected_frequency - margin)
        self.assertLess(value, expected_frequency + margin)

      # Reached the limit.
      with self.assertRaises(errors_impl.OutOfRangeError):
        dequeue.eval()
      for thread in threads:
        thread.join()

  def testSharedName(self):
    with self.test_session():
      range_size = 5
      queue = inp.range_input_producer(
          range_size, shared_name="SHARED_NAME_XYZ", name="Q")
      self.assertProtoEquals("s: 'SHARED_NAME_XYZ'",
                             queue.queue_ref.op.node_def.attr["shared_name"])


class SliceInputProducerTest(test_lib.TestCase):

  def testNoShuffle(self):
    with self.test_session() as sess:
      num_epochs = 3
      source_strings = [b"Alpha", b"Beta", b"Delta", b"Gamma"]
      source_ints = [2, 3, 5, 7]
      slices = inp.slice_input_producer(
          [source_strings, source_ints], num_epochs=num_epochs, shuffle=False)
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      threads = queue_runner_impl.start_queue_runners()

      # No randomness, so just see repeated copies of the input.
      num_items = len(source_strings) * num_epochs
      output = [sess.run(slices) for _ in range(num_items)]
      out_strings, out_ints = zip(*output)
      self.assertAllEqual(source_strings * num_epochs, out_strings)
      self.assertAllEqual(source_ints * num_epochs, out_ints)

      # Reached the limit.
      with self.assertRaises(errors_impl.OutOfRangeError):
        sess.run(slices)
      for thread in threads:
        thread.join()

  def testShuffle(self):
    with self.test_session() as sess:
      num_epochs = 1200
      source_strings = ["A", "B", "D", "G"]
      source_ints = [7, 3, 5, 2]
      slices = inp.slice_input_producer(
          [source_strings, source_ints],
          num_epochs=num_epochs,
          shuffle=True,
          seed=161803)
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      threads = queue_runner_impl.start_queue_runners()

      # Validate that we only shuffle the slices within an epoch and
      # count how often each possible order appears.
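      # With 4 slices there are 4! = 24 possible orders per epoch.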
      expected = [
          b",".join(x)
          for x in itertools.permutations([b"A7", b"B3", b"D5", b"G2"])
      ]
      frequency = {}
      for e in expected:
        frequency[e] = 0
      for _ in range(num_epochs):
        output = [sess.run(slices) for _ in range(len(source_strings))]
        key = b",".join([s + compat.as_bytes(str(i)) for s, i in output])
        self.assertIn(key, expected)
        frequency[key] += 1

      # Expect an approximately even distribution over all possible orders.
      expected_frequency = num_epochs / len(expected)
      margin = expected_frequency * 0.4
      tf_logging.info("Observed counts: %s", frequency)
      for key in expected:
        value = frequency[key]
        self.assertGreater(value, expected_frequency - margin)
        self.assertLess(value, expected_frequency + margin)

      # Reached the limit.
      with self.assertRaises(errors_impl.OutOfRangeError):
        sess.run(slices)
      for thread in threads:
        thread.join()

  def testSharedName(self):
    with self.test_session():
      source_strings = ["A", "B", "D", "G"]
      source_ints = [7, 3, 5, 2]
      slices = inp.slice_input_producer(
          [source_strings, source_ints],
          shared_name="SHARED_NAME_XYZ",
          name="sip")

      self.assertProtoEquals(
          "s: 'SHARED_NAME_XYZ'",
          slices[0].op.inputs[1].op.inputs[0].op.node_def.attr["shared_name"])


class DictHelperTest(test_lib.TestCase):

  def testListInputs(self):
    l = [1, 2, 3, 11, 22, 33]
    l2 = inp._as_tensor_list(l)
    self.assertEqual(l, l2)
    l3 = inp._as_original_type(l, l2)
    self.assertEqual(l, l3)

  def testDictInputs(self):
    d = {"a": 1, "b": 2, "c": 3, "aa": 11, "bb": 22, "cc": 33}
    l = inp._as_tensor_list(d)
    self.assertEqual([1, 11, 2, 22, 3, 33], l)
    d2 = inp._as_original_type(d, l)
    self.assertEqual(d, d2)

  def testHeterogeneousKeysDictInputs(self):
    d = {"z": 1, 1: 42, ("a", "b"): 100}
    l = inp._as_tensor_list(d)
    self.assertEqual([100, 42, 1], l)
    d2 = inp._as_original_type(d, l)
    self.assertEqual(d, d2)


class BatchTest(test_lib.TestCase):

  def _testOneThreadHelper(self, use_dict):
    with self.test_session() as sess:
      batch_size = 10
      num_batches = 3
      zero64 = constant_op.constant(0, dtype=dtypes.int64)
      examples = variables.Variable(zero64)
      counter = examples.count_up_to(num_batches * batch_size)
      sparse_counter = sparse_tensor.SparseTensor(
          indices=array_ops.reshape(
              array_ops.stack([zero64, zero64 + 1]), [2, 1]),
          values=math_ops.cast(
              array_ops.stack([counter, -counter]), dtypes.float32),
          dense_shape=[2])
      if use_dict:
        batched = inp.batch(
            {
                "c": counter,
                "s": sparse_counter,
                "S": "string"
            },
            batch_size=batch_size)
        batched_fetch = [batched["c"], batched["s"], batched["S"]]
      else:
        batched = inp.batch(
            [counter, sparse_counter, "string"], batch_size=batch_size)
        batched_fetch = batched
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      threads = queue_runner_impl.start_queue_runners()

      for i in range(num_batches):
        results = sess.run(batched_fetch)
        self.assertAllEqual(results[0],
                            np.arange(i * batch_size, (i + 1) * batch_size))
        self.assertAllEqual(
            results[1].indices,
            np.vstack((
                np.arange(2 * batch_size) // 2,  # 0, 0, 1, 1, ...
                [0, 1] * batch_size)).T)
        # [x, -x, x+1, -(x+1), ...]
        expected = np.arange(2 * i * batch_size, 2 * (i + 1) * batch_size) // 2
        expected *= ([1, -1] * batch_size)  # mult by [1, -1, 1, -1, ...]
        self.assertAllEqual(results[1].values, expected)
        self.assertAllEqual(results[1].dense_shape, [batch_size, 2])
        self.assertAllEqual(results[2], [b"string"] * batch_size)

      # Reached the limit.
      with self.assertRaises(errors_impl.OutOfRangeError):
        sess.run(batched_fetch)
      for thread in threads:
        thread.join()

  def testOneThread(self):
    self._testOneThreadHelper(use_dict=False)

  def testOneThreadDict(self):
    self._testOneThreadHelper(use_dict=True)

  def testOneThreadDynamicPad(self):
    with self.test_session() as sess:
      batch_size = 10
      num_batches = 3
      zero64 = constant_op.constant(0, dtype=dtypes.int64)
      examples = variables.Variable(zero64)
      counter = examples.count_up_to(num_batches * batch_size)
      string = array_ops.tile(["string"],
                              math_ops.to_int32(array_ops.stack([counter])))
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      batched = inp.batch(
          [counter, string], batch_size=batch_size, dynamic_pad=True)
      threads = queue_runner_impl.start_queue_runners()

      for i in range(num_batches):
        results = sess.run(batched)
        expected_results = np.arange(i * batch_size, (i + 1) * batch_size)
        max_len = expected_results[-1]
        self.assertAllEqual(results[0], expected_results)
        expected_strings = [[b"string"] * rep + [b""] * (max_len - rep)
                            for rep in expected_results]
        self.assertAllEqual(results[1], expected_strings)

      # Reached the limit.
      with self.assertRaises(errors_impl.OutOfRangeError):
        sess.run(batched)
      for thread in threads:
        thread.join()

  def testOneThreadEnqueueMany(self):
    with self.test_session() as sess:
      batch_size = 10
      num_batches = 3
      zero64 = constant_op.constant(0, dtype=dtypes.int64)
      examples = variables.Variable(zero64)
      counter = examples.count_up_to(num_batches * batch_size)
      sparse_counter = sparse_tensor.SparseTensor(
          indices=array_ops.reshape(zero64, [1, 1]),
          values=array_ops.stack([math_ops.cast(counter, dtypes.float32)]),
          dense_shape=[1])
      pre_batched = inp.batch([counter, sparse_counter, "string"], batch_size=2)
      batched = inp.batch(pre_batched, enqueue_many=True, batch_size=batch_size)
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      threads = queue_runner_impl.start_queue_runners()

      for i in range(num_batches):
        results = sess.run(batched)
        self.assertAllEqual(results[0],
                            np.arange(i * batch_size, (i + 1) * batch_size))
        self.assertAllEqual(
            results[1].indices,
            np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
        self.assertAllEqual(results[1].values,
                            np.arange(i * batch_size, (i + 1) * batch_size))
        self.assertAllEqual(results[1].dense_shape, [batch_size, 1])
        self.assertAllEqual(results[2], [b"string"] * batch_size)

      # Reached the limit.
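      # (count_up_to raises OutOfRangeError once the counter is exhausted,
      # which closes the batching queue.)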
      with self.assertRaises(errors_impl.OutOfRangeError):
        sess.run(batched)
      for thread in threads:
        thread.join()

  def testManyThreads(self):
    with self.test_session() as sess:
      batch_size = 10
      num_batches = 3
      zero64 = constant_op.constant(0, dtype=dtypes.int64)

      examples = variables.Variable(zero64)
      counter = examples.count_up_to(num_batches * batch_size)
      sparse_counter = sparse_tensor.SparseTensor(
          indices=array_ops.reshape(zero64, [1, 1]),
          values=array_ops.stack([math_ops.cast(counter, dtypes.float32)]),
          dense_shape=[1])
      batched = inp.batch(
          [counter, sparse_counter, "string"],
          batch_size=batch_size,
          num_threads=4)
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      threads = queue_runner_impl.start_queue_runners()

      all_counts = []
      for i in range(num_batches):
        results = sess.run(batched)
        tf_logging.info("Batch %d: %s", i, results[0])
        self.assertEqual(len(results[0]), batch_size)
        self.assertAllEqual(results[0], results[1].values)
        self.assertAllEqual(
            results[1].indices,
            np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
        self.assertAllEqual(results[1].dense_shape, [batch_size, 1])
        all_counts.extend(results[0])
        self.assertAllEqual(results[2], [b"string"] * batch_size)
      self.assertItemsEqual(all_counts, range(num_batches * batch_size))

      # Reached the limit.
      with self.assertRaises(errors_impl.OutOfRangeError):
        sess.run(batched)
      for thread in threads:
        thread.join()

  def testOneThreadSmallerBatch(self):
    with self.test_session() as sess:
      batch_size = 10
      num_batches = 3
      extra_elements = 5
      zero64 = constant_op.constant(0, dtype=dtypes.int64)
      examples = variables.Variable(zero64)
      counter = examples.count_up_to(num_batches * batch_size + extra_elements)
      sparse_counter = sparse_tensor.SparseTensor(
          indices=array_ops.reshape(
              array_ops.stack([zero64, zero64 + 1]), [2, 1]),
          values=math_ops.cast(
              array_ops.stack([counter, -counter]), dtypes.float32),
          dense_shape=[2])
      batched = inp.batch(
          [counter, sparse_counter, "string"],
          batch_size=batch_size,
          allow_smaller_final_batch=True)
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      threads = queue_runner_impl.start_queue_runners()

      for i in range(num_batches):
        results = sess.run(batched)
        self.assertAllEqual(results[0],
                            np.arange(i * batch_size, (i + 1) * batch_size))
        self.assertAllEqual(
            results[1].indices,
            np.vstack((
                np.arange(2 * batch_size) // 2,  # 0, 0, 1, 1, ...
                [0, 1] * batch_size)).T)
        # [x, -x, x+1, -(x+1), ...]
        expected = np.arange(2 * i * batch_size, 2 * (i + 1) * batch_size) // 2
        expected *= ([1, -1] * batch_size)  # mult by [1, -1, 1, -1, ...]
        self.assertAllEqual(results[1].values, expected)
        self.assertAllEqual(results[1].dense_shape, [batch_size, 2])
        self.assertAllEqual(results[2], [b"string"] * batch_size)

      # Reached the final batch with extra_elements.
      results = sess.run(batched)
      self.assertAllEqual(results[0],
                          np.arange(num_batches * batch_size,
                                    num_batches * batch_size + extra_elements))
      self.assertAllEqual(
          results[1].indices,
          np.vstack((
              np.arange(2 * extra_elements) // 2,  # 0, 0, 1, 1, ...
              [0, 1] * extra_elements)).T)
      self.assertAllEqual(results[1].dense_shape, [extra_elements, 2])
      self.assertAllEqual(results[2], [b"string"] * extra_elements)

      # Reached the limit.
      with self.assertRaises(errors_impl.OutOfRangeError):
        sess.run(batched)
      for thread in threads:
        thread.join()

  def testManyThreadsSmallerBatch(self):
    with self.test_session() as sess:
      batch_size = 10
      num_batches = 3
      extra_elements = 5
      zero64 = constant_op.constant(0, dtype=dtypes.int64)

      examples = variables.Variable(zero64)
      counter = examples.count_up_to(num_batches * batch_size + extra_elements)
      sparse_counter = sparse_tensor.SparseTensor(
          indices=array_ops.reshape(zero64, [1, 1]),
          values=array_ops.stack([math_ops.cast(counter, dtypes.float32)]),
          dense_shape=[1])
      batched = inp.batch(
          [counter, sparse_counter, "string"],
          batch_size=batch_size,
          num_threads=4,
          allow_smaller_final_batch=True)
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      threads = queue_runner_impl.start_queue_runners()

      all_counts = []
      for i in range(num_batches):
        results = sess.run(batched)
        tf_logging.info("Batch %d: %s", i, results[0])
        self.assertEqual(len(results[0]), batch_size)
        self.assertAllEqual(results[0], results[1].values)
        self.assertAllEqual(
            results[1].indices,
            np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
        self.assertAllEqual(results[1].dense_shape, [batch_size, 1])
        all_counts.extend(results[0])
        self.assertAllEqual(results[2], [b"string"] * batch_size)

      # Reached the final batch with extra_elements.
      results = sess.run(batched)
      tf_logging.info("Last Batch: %s", results[0])
      self.assertEqual(len(results[0]), extra_elements)
      self.assertAllEqual(results[0], results[1].values)
      self.assertAllEqual(
          results[1].indices,
          np.vstack((np.arange(extra_elements), np.zeros(extra_elements))).T)
      self.assertAllEqual(results[1].dense_shape, [extra_elements, 1])
      all_counts.extend(results[0])
      self.assertAllEqual(results[2], [b"string"] * extra_elements)
      self.assertItemsEqual(all_counts,
                            range(num_batches * batch_size + extra_elements))

      # Reached the limit.
      with self.assertRaises(errors_impl.OutOfRangeError):
        sess.run(batched)
      for thread in threads:
        thread.join()

  def testSharedName(self):
    with self.test_session():
      batch_size = 10
      num_batches = 3
      zero64 = constant_op.constant(0, dtype=dtypes.int64)
      examples = variables.Variable(zero64)
      counter = examples.count_up_to(num_batches * batch_size)
      batched = inp.batch(
          [counter, "string"],
          batch_size=batch_size,
          shared_name="SHARED_NAME_XYZ",
          name="Q")

      self.assertProtoEquals(
          "s: 'SHARED_NAME_XYZ'",
          batched[0].op.inputs[0].op.node_def.attr["shared_name"])

  def testCannotInferRankError(self):
    with self.test_session():
      x = array_ops.placeholder(dtype=dtypes.int64)
      with self.assertRaisesRegexp(ValueError, "Cannot infer Tensor's rank"):
        inp.batch([x], batch_size=2)

  def testBatchedSparseTensorInferredShape(self):
    sparse = sparse_tensor.SparseTensor(
        indices=[[0]], values=[1.0], dense_shape=[1])
    self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
    batched = inp.batch([sparse], batch_size=2)
    self.assertAllEqual((2,), batched.dense_shape.get_shape().as_list())

  def testBatchedSparseTensorInferredShapeEnqueueMany(self):
    sparse = sparse_tensor.SparseTensor(
        indices=[[0]], values=[1.0], dense_shape=[1])
    self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
    batched = inp.batch([sparse], batch_size=2, enqueue_many=True)
    self.assertAllEqual((1,), batched.dense_shape.get_shape().as_list())

  def testBatchedSparseTensorInferredShapeUnknownRank(self):
    sparse = sparse_tensor.SparseTensor(
        indices=array_ops.placeholder(dtypes.int64),
        values=array_ops.placeholder(dtypes.float32),
        dense_shape=array_ops.placeholder(dtypes.int64))
    self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
    batched = inp.batch([sparse], batch_size=2)
    self.assertIs(None, batched.dense_shape.get_shape().num_elements())

  def testBatchedSparseTensorInferredShapeUnknownRankEnqueueMany(self):
    sparse = sparse_tensor.SparseTensor(
        indices=array_ops.placeholder(dtypes.int64),
        values=array_ops.placeholder(dtypes.float32),
        dense_shape=array_ops.placeholder(dtypes.int64))
    self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
    batched = inp.batch([sparse], batch_size=2, enqueue_many=True)
    self.assertIs(None, batched.dense_shape.get_shape().num_elements())

  def testSingleElementDict(self):
    x = inp.batch({"c": [12, 12]}, batch_size=8)
    self.assertAllEqual((8, 2), x["c"].get_shape().as_list())

  def _testKeepInputHelper(self, num_threads, enqueue_many,
                           keep_input_vector=False):
    with self.test_session() as sess:
      batch_size = 5
      num_batches = 4
      examples = variables.Variable(0)
      counter = examples.count_up_to(num_batches * batch_size * 2)
      sparse_counter = sparse_tensor.SparseTensor(
          indices=array_ops.zeros(
              [1, 1], dtype=dtypes.int64),
          values=array_ops.stack([math_ops.cast(counter, dtypes.float32)]),
          dense_shape=[1])
      to_batch = [counter, sparse_counter, "string"]
      if enqueue_many:
        to_batch = inp.batch(to_batch, 4 if keep_input_vector else 1)
      keep_input = array_ops.squeeze(
          math_ops.equal(0, math_ops.mod(to_batch[0], 2)))
      batched = inp.maybe_batch(
          to_batch,
          keep_input,
          batch_size,
          num_threads=num_threads,
          enqueue_many=enqueue_many)
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      threads = queue_runner_impl.start_queue_runners()

      for _ in range(num_batches):
        results = sess.run(batched)
        self.assertAllEqual([0] * batch_size, np.mod(results[0], 2))
        self.assertAllEqual([0] * batch_size, np.mod(results[1].values, 2))
        self.assertAllEqual([b"string"] * batch_size, results[2])

      # Reached the limit.
      with self.assertRaises(errors_impl.OutOfRangeError):
        sess.run(batched)
      for thread in threads:
        thread.join()

  def testSingleThreadKeepInput(self):
    self._testKeepInputHelper(1, False)

  def testSingleThreadKeepInputEnqueueMany(self):
    self._testKeepInputHelper(1, True)

  def testMultipleThreadKeepInput(self):
    self._testKeepInputHelper(5, False)

  def testMultipleThreadKeepInputEnqueueMany(self):
    self._testKeepInputHelper(5, True)

  def testMaybeEnqueuePerExample(self):
    self._testKeepInputHelper(1, True, keep_input_vector=True)

  def testMultipleThreadMaybeEnqueuePerExample(self):
    self._testKeepInputHelper(5, True, keep_input_vector=True)

  def testInvalidKeepInputVector(self):
    # Can't have vector `keep_input` with `enqueue_many=False`.
    with self.assertRaisesRegexp(ValueError, "`keep_input` cannot be a vector"):
      inp.maybe_batch([array_ops.zeros(5)],
                      keep_input=constant_op.constant([True, False]),
                      batch_size=1,
                      enqueue_many=False)
    # Can't have `keep_input` with more than one dimension.
    with self.assertRaisesRegexp(ValueError, "must be 0 or 1 dimensions"):
      inp.maybe_batch([array_ops.zeros(5)],
                      keep_input=constant_op.constant([[True], [False]]),
                      batch_size=1,
                      enqueue_many=True)
    # `keep_input` must have dimensions determined at graph construction.
    with self.assertRaisesRegexp(ValueError,
                                 "must be known at graph construction"):
      inp.maybe_batch([array_ops.zeros(5)],
                      keep_input=array_ops.placeholder(dtypes.bool),
                      batch_size=1,
                      enqueue_many=True)

  def testMaybeBatchedSparseTensorInferredShape(self):
    sparse = sparse_tensor.SparseTensor(
        indices=[[0]], values=[1.0], dense_shape=[1])
    self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
    batched = inp.maybe_batch([sparse], keep_input=True, batch_size=2)
    self.assertAllEqual((2,), batched.dense_shape.get_shape().as_list())

  def testMaybeBatchedSparseTensorInferredShapeEnqueueMany(self):
    sparse = sparse_tensor.SparseTensor(
        indices=[[0]], values=[1.0], dense_shape=[1])
    self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
    batched = inp.maybe_batch(
        [sparse], keep_input=True, batch_size=2, enqueue_many=True)
    self.assertAllEqual((1,), batched.dense_shape.get_shape().as_list())

  def testMaybeBatchedSparseTensorInferredShapeEnqueueManyPerExample(self):
    sparse = sparse_tensor.SparseTensor(
        indices=[[0], [0]], values=[1.0, 2.0], dense_shape=[2])
    self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
    batched = inp.maybe_batch(
        [sparse], keep_input=[True, False], batch_size=2, enqueue_many=True)
    self.assertAllEqual((1,), batched.dense_shape.get_shape().as_list())

  def testMaybeBatchedSparseTensorInferredShapeUnknownRank(self):
    sparse = sparse_tensor.SparseTensor(
        indices=array_ops.placeholder(dtypes.int64),
        values=array_ops.placeholder(dtypes.float32),
        dense_shape=array_ops.placeholder(dtypes.int64))
    self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
    batched = inp.maybe_batch([sparse], keep_input=True, batch_size=2)
    self.assertIs(None, batched.dense_shape.get_shape().num_elements())

  def testMaybeBatchedSparseTensorInferredShapeUnknownRankEnqueueMany(self):
    sparse = sparse_tensor.SparseTensor(
        indices=array_ops.placeholder(dtypes.int64),
        values=array_ops.placeholder(dtypes.float32),
        dense_shape=array_ops.placeholder(dtypes.int64))
    self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
    batched = inp.maybe_batch(
        [sparse], keep_input=True, batch_size=2, enqueue_many=True)
    self.assertIs(None, batched.dense_shape.get_shape().num_elements())

  def testMaybeBatchedSparseTensorInferredShapeUnknownRankPerExample(self):
    sparse = sparse_tensor.SparseTensor(
        indices=array_ops.placeholder(dtypes.int64),
        values=array_ops.placeholder(dtypes.float32),
        dense_shape=array_ops.placeholder(dtypes.int64))
    self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
    batched = inp.maybe_batch(
        [sparse], keep_input=[True, False], batch_size=2, enqueue_many=True)
    self.assertIs(None, batched.dense_shape.get_shape().num_elements())

  def testMaybeBatchCorrectValues(self):
    sparse_t = sparse_tensor.SparseTensor(
        indices=[[0, 1], [0, 2], [1, 0], [1, 3]],
        dense_shape=[2, 4],
        values=[5, 4, 7, 2])
    keep = constant_op.constant([True, False])
    batched = inp.maybe_batch(
        [sparse_t], keep_input=keep, batch_size=1, enqueue_many=True)

    with self.test_session():
      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(coord=coord)

      batched_np = batched.eval()

      coord.request_stop()
      for thread in threads:
        thread.join()

    self.assertAllEqual([[0, 1], [0, 2]], batched_np.indices)
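    # Only row 0 is kept, so the batch holds just its values, 5 and 4.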
    self.assertAllEqual([5, 4], batched_np.values)
    self.assertAllEqual([1, 4], batched_np.dense_shape)


class BatchJoinTest(test_lib.TestCase):

  def _testTwoThreadsHelper(self, use_dict):
    with self.test_session() as sess:
      # Two threads, the first generates (0..69, "a").
      num_a = 70
      zero64 = constant_op.constant(0, dtype=dtypes.int64)
      examples = variables.Variable(zero64)
      counter = examples.count_up_to(num_a)
      sparse_counter = sparse_tensor.SparseTensor(
          indices=array_ops.reshape(zero64, [1, 1]),
          values=array_ops.stack([math_ops.cast(counter, dtypes.float32)]),
          dense_shape=[1])

      # The second generates (99, "b") 90 times and then stops.
      num_b = 90
      ninety_nine = inp.limit_epochs(
          constant_op.constant(
              99, dtype=dtypes.int64), num_b)
      sparse_ninety_nine = sparse_tensor.SparseTensor(
          indices=array_ops.reshape(zero64, [1, 1]),
          values=array_ops.stack([math_ops.cast(ninety_nine, dtypes.float32)]),
          dense_shape=[1])

      # These get joined together and grouped into batches of 5.
      batch_size = 5
      if use_dict:
        batched = inp.batch_join(
            [{
                "c": counter,
                "s": sparse_counter,
                "S": "a"
            }, {
                "c": ninety_nine,
                "s": sparse_ninety_nine,
                "S": "b"
            }],
            batch_size=batch_size)
        batched_fetch = [batched["c"], batched["s"], batched["S"]]
      else:
        batched = inp.batch_join(
            [[counter, sparse_counter, "a"],
             [ninety_nine, sparse_ninety_nine, "b"]],
            batch_size=batch_size)
        batched_fetch = batched

      # Shapes.
      self.assertEqual(3, len(batched_fetch))
      self.assertAllEqual((batch_size,), batched_fetch[0].get_shape().as_list())
      self.assertAllEqual((None, 2),
                          batched_fetch[1].indices.get_shape().as_list())
      self.assertAllEqual((None,),
                          batched_fetch[1].values.get_shape().as_list())
      self.assertAllEqual((2,),
                          batched_fetch[1].dense_shape.get_shape().as_list())
      self.assertAllEqual((batch_size,), batched_fetch[2].get_shape().as_list())

      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      threads = queue_runner_impl.start_queue_runners()

      # Should see the "a" and "b" threads mixed together.
      all_a = []
      seen_b = 0
      saw_both = 0
      num_batches = (num_a + num_b) // batch_size
      for i in range(num_batches):
        results = sess.run(batched_fetch)
        self.assertEqual(3, len(results))
        self.assertEqual(batch_size, len(results[0]))
        self.assertEqual(batch_size, len(results[2]))
        self.assertAllEqual(results[0], results[1].values)
        self.assertAllEqual(
            results[1].indices,
            np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
        self.assertAllEqual(results[1].dense_shape, [batch_size, 1])
        which_a = [i for i, s in enumerate(results[2]) if s == b"a"]
        which_b = [i for i, s in enumerate(results[2]) if s == b"b"]
        self.assertEqual(len(which_a) + len(which_b), batch_size)
        if which_a and which_b:
          saw_both += 1
        all_a.extend([results[0][i] for i in which_a])
        seen_b += len(which_b)
        self.assertAllEqual([99] * len(which_b),
                            [results[0][i] for i in which_b])

      # Some minimum level of mixing of the results of both threads.
      self.assertGreater(saw_both, 1)

      # Verify the order of results from "a" were preserved.
      self.assertAllEqual(all_a, np.arange(num_a))
      self.assertEqual(seen_b, num_b)

      # Reached the limit.
      with self.assertRaises(errors_impl.OutOfRangeError):
        sess.run(batched_fetch)
      for thread in threads:
        thread.join()

  def DISABLED_testTwoThreads(self):
    self._testTwoThreadsHelper(use_dict=False)

  def DISABLED_testTwoThreadsDict(self):
    self._testTwoThreadsHelper(use_dict=True)

  def testMismatchedDictKeys(self):
    with self.assertRaisesRegexp(ValueError, "must have the same keys"):
      inp.batch_join(
          [{
              "c": 12,
              "s": 123,
              "S": "a"
          }, {
              "cool": -12,
              "s": 99,
              "S": "b"
          }],
          batch_size=8)

  def DISABLED_testTwoThreadsDynamicPad(self):
    with self.test_session() as sess:
      # Two threads, the first generates (0..69, ["a"] * 1..70).
      num_a = 70
      zero64 = constant_op.constant(0, dtype=dtypes.int64)
      examples = variables.Variable(zero64)
      counter = examples.count_up_to(num_a)

      # The second generates (99, ["b"] * 99) 90 times and then stops.
      num_b = 90
      ninety_nine = inp.limit_epochs(
          constant_op.constant(
              99, dtype=dtypes.int64), num_b)

      # These get joined together and grouped into batches of 5.
      batch_size = 5
      a = array_ops.tile(["a"],
                         math_ops.to_int32(array_ops.stack([counter + 1])))
      b = array_ops.tile(["b"],
                         math_ops.to_int32(array_ops.stack([ninety_nine])))
      batched = inp.batch_join(
          [[counter, a], [ninety_nine, b]],
          batch_size=batch_size,
          dynamic_pad=True)

      # Shapes.
      self.assertEqual(2, len(batched))
      self.assertAllEqual((batch_size,), batched[0].get_shape().as_list())
      self.assertAllEqual((batch_size, None), batched[1].get_shape().as_list())

      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      threads = queue_runner_impl.start_queue_runners()

      # Should see the "a" and "b" threads mixed together.
      all_a = []
      count_string_a = []
      seen_b = 0
      saw_both = 0
      num_batches = (num_a + num_b) // batch_size
      for i in range(num_batches):
        results = sess.run(batched)
        self.assertEqual(2, len(results))
        self.assertEqual(len(results[0]), batch_size)
        self.assertEqual(len(results[1]), batch_size)
        for s in results[1]:
          if s[0] == b"b":
            self.assertAllEqual(s, [b"b"] * 99)
          else:
            count_string_a.append(sum(x == b"a" for x in s))
        which_a = [i for i, s in enumerate(results[1]) if s[0] == b"a"]
        which_b = [i for i, s in enumerate(results[1]) if s[0] == b"b"]
        self.assertEqual(len(which_a) + len(which_b), batch_size)
        if which_a and which_b:
          saw_both += 1
        all_a.extend([results[0][i] for i in which_a])
        seen_b += len(which_b)
        self.assertAllEqual([99] * len(which_b),
                            [results[0][i] for i in which_b])

      # Some minimum level of mixing of the results of both threads.
      self.assertGreater(saw_both, 1)

      # Verify the order of results from "a" were preserved.
      self.assertAllEqual(  # tiled "a" with counter + 1
          count_string_a, np.arange(num_a) + 1)
      self.assertAllEqual(all_a, np.arange(num_a))
      self.assertEqual(seen_b, num_b)

      # Reached the limit.
      with self.assertRaises(errors_impl.OutOfRangeError):
        sess.run(batched)
      for thread in threads:
        thread.join()

  def DISABLED_testTwoThreadsSmallerBatch(self):
    with self.test_session() as sess:
      extra_elements = 2
      # Two threads, the first generates (0..num_a-1, "a").
      num_a = 70 + extra_elements
      zero64 = constant_op.constant(0, dtype=dtypes.int64)
      examples = variables.Variable(zero64)
      counter = examples.count_up_to(num_a)
      sparse_counter = sparse_tensor.SparseTensor(
          indices=array_ops.reshape(zero64, [1, 1]),
          values=array_ops.stack([math_ops.cast(counter, dtypes.float32)]),
          dense_shape=[1])

      # The second generates (99, "b") num_b times and then stops.
      num_b = 90 + extra_elements
      ninety_nine = inp.limit_epochs(
          constant_op.constant(
              99, dtype=dtypes.int64), num_b)
      sparse_ninety_nine = sparse_tensor.SparseTensor(
          indices=array_ops.reshape(zero64, [1, 1]),
          values=array_ops.stack([math_ops.cast(ninety_nine, dtypes.float32)]),
          dense_shape=[1])

      # These get joined together and grouped into batches of 5.
      batch_size = 5
      batched = inp.batch_join(
          [[counter, sparse_counter, "a"],
           [ninety_nine, sparse_ninety_nine, "b"]],
          batch_size=batch_size,
          allow_smaller_final_batch=True)

      # Shapes.
      self.assertEqual(3, len(batched))
      self.assertAllEqual((None,), batched[0].get_shape().as_list())
      self.assertAllEqual((None, 2), batched[1].indices.get_shape().as_list())
      self.assertAllEqual((None,), batched[1].values.get_shape().as_list())
      self.assertAllEqual((2,), batched[1].dense_shape.get_shape().as_list())
      self.assertAllEqual((None,), batched[2].get_shape().as_list())

      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      threads = queue_runner_impl.start_queue_runners()

      # Should see the "a" and "b" threads mixed together.
      all_a = []
      seen_b = 0
      saw_both = 0
      num_batches = (num_a + num_b) // batch_size
      for i in range(num_batches):
        results = sess.run(batched)
        tf_logging.info("Batch %d: %s", i, results[0])
        self.assertEqual(len(results[0]), batch_size)
        self.assertEqual(len(results[2]), batch_size)
        self.assertAllEqual(results[0], results[1].values)
        self.assertAllEqual(
            results[1].indices,
            np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
        self.assertAllEqual(results[1].dense_shape, [batch_size, 1])
        which_a = [i for i, s in enumerate(results[2]) if s == b"a"]
        which_b = [i for i, s in enumerate(results[2]) if s == b"b"]
        self.assertEqual(len(which_a) + len(which_b), batch_size)
        if which_a and which_b:
          saw_both += 1
        all_a.extend([results[0][i] for i in which_a])
        seen_b += len(which_b)
        self.assertAllEqual([99] * len(which_b),
                            [results[0][i] for i in which_b])

      # Reached the final batch with 2 * extra_elements.
      results = sess.run(batched)
      tf_logging.info("Last Batch: %s", results[0])
      self.assertEqual(len(results[0]), 2 * extra_elements)
      self.assertEqual(len(results[2]), 2 * extra_elements)
      self.assertAllEqual(results[0], results[1].values)
      self.assertAllEqual(results[1].indices,
                          np.vstack((np.arange(2 * extra_elements),
                                     np.zeros(2 * extra_elements))).T)
      self.assertAllEqual(results[1].dense_shape, [2 * extra_elements, 1])
      which_a = [i for i, s in enumerate(results[2]) if s == b"a"]
      which_b = [i for i, s in enumerate(results[2]) if s == b"b"]
      self.assertEqual(len(which_a) + len(which_b), 2 * extra_elements)
      if which_a and which_b:
        saw_both += 1
      all_a.extend([results[0][i] for i in which_a])
      seen_b += len(which_b)

      # Some minimum level of mixing of the results of both threads.
      self.assertGreater(saw_both, 1)

      # Verify the order of results from "a" were preserved.
      self.assertAllEqual(all_a, np.arange(num_a))
      self.assertEqual(seen_b, num_b)

      # Reached the limit.
      with self.assertRaises(errors_impl.OutOfRangeError):
        sess.run(batched)
      for thread in threads:
        thread.join()

  def DISABLED_testTwoThreadsDynamicPadSmallerBatch(self):
    with self.test_session() as sess:
      extra_elements = 2
      # Two threads, the first generates (0..num_a-1, ["a"] * 1..num_a).
      num_a = 70 + extra_elements
      zero64 = constant_op.constant(0, dtype=dtypes.int64)
      examples = variables.Variable(zero64)
      counter = examples.count_up_to(num_a)

      # The second generates (99, ["b"] * 99) num_b times and then stops.
      num_b = 90 + extra_elements
      ninety_nine = inp.limit_epochs(
          constant_op.constant(
              99, dtype=dtypes.int64), num_b)

      # These get joined together and grouped into batches of 5.
      batch_size = 5
      a = array_ops.tile(["a"],
                         math_ops.to_int32(array_ops.stack([counter + 1])))
      b = array_ops.tile(["b"],
                         math_ops.to_int32(array_ops.stack([ninety_nine])))
      batched = inp.batch_join(
          [[counter, a], [ninety_nine, b]],
          batch_size=batch_size,
          dynamic_pad=True,
          allow_smaller_final_batch=True)

      # Shapes.
      self.assertEqual(2, len(batched))
      self.assertAllEqual((None,), batched[0].get_shape().as_list())
      self.assertAllEqual((None, None), batched[1].get_shape().as_list())

      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      threads = queue_runner_impl.start_queue_runners()

      # Should see the "a" and "b" threads mixed together.
      all_a = []
      count_string_a = []
      seen_b = 0
      saw_both = 0
      num_batches = (num_a + num_b) // batch_size
      for i in range(num_batches):
        results = sess.run(batched)
        tf_logging.info("Batch %d: %s", i, results[0])
        self.assertEqual(len(results[0]), batch_size)
        self.assertEqual(len(results[1]), batch_size)
        for s in results[1]:
          if s[0] == b"b":
            self.assertAllEqual(s, [b"b"] * 99)
          else:
            count_string_a.append(sum(x == b"a" for x in s))
        which_a = [i for i, s in enumerate(results[1]) if s[0] == b"a"]
        which_b = [i for i, s in enumerate(results[1]) if s[0] == b"b"]
        self.assertEqual(len(which_a) + len(which_b), batch_size)
        if which_a and which_b:
          saw_both += 1
        all_a.extend([results[0][i] for i in which_a])
        seen_b += len(which_b)
        self.assertAllEqual([99] * len(which_b),
                            [results[0][i] for i in which_b])

      # Reached the final batch with 2 * extra_elements.
      results = sess.run(batched)
      tf_logging.info("Last Batch: %s", results[0])
      self.assertEqual(len(results[0]), 2 * extra_elements)
      self.assertEqual(len(results[1]), 2 * extra_elements)
      for s in results[1]:
        if s[0] == b"b":
          self.assertAllEqual(s, [b"b"] * 99)
        else:
          count_string_a.append(sum(x == b"a" for x in s))
      which_a = [i for i, s in enumerate(results[1]) if s[0] == b"a"]
      which_b = [i for i, s in enumerate(results[1]) if s[0] == b"b"]
      self.assertEqual(len(which_a) + len(which_b), 2 * extra_elements)
      if which_a and which_b:
        saw_both += 1
      all_a.extend([results[0][i] for i in which_a])
      seen_b += len(which_b)

      # Some minimum level of mixing of the results of both threads.
      self.assertGreater(saw_both, 1)

      # Verify the order of results from "a" were preserved.
      self.assertAllEqual(  # tiled "a" with counter + 1
          count_string_a, np.arange(num_a) + 1)
      self.assertAllEqual(all_a, np.arange(num_a))
      self.assertEqual(seen_b, num_b)

      # Reached the limit.
      with self.assertRaises(errors_impl.OutOfRangeError):
        sess.run(batched)
      for thread in threads:
        thread.join()

  def testSharedName(self):
    with self.test_session():
      batch_size = 10
      num_batches = 3
      zero64 = constant_op.constant(0, dtype=dtypes.int64)
      examples = variables.Variable(zero64)
      counter = examples.count_up_to(num_batches * batch_size)
      batched = inp.batch_join(
          [[counter, "string"]],
          batch_size=batch_size,
          shared_name="SHARED_NAME_XYZ",
          name="Q")

      # Shapes.
      self.assertEqual(2, len(batched))
      self.assertAllEqual((batch_size,), batched[0].get_shape().as_list())
      self.assertAllEqual((batch_size,), batched[1].get_shape().as_list())

      self.assertProtoEquals(
          "s: 'SHARED_NAME_XYZ'",
          batched[0].op.inputs[0].op.node_def.attr["shared_name"])

  def testCannotInferRankError(self):
    with self.test_session():
      x = array_ops.placeholder(dtype=dtypes.int64)
      with self.assertRaisesRegexp(ValueError, "Cannot infer Tensor's rank"):
        inp.batch_join([[x]], batch_size=2)

  def testSingleElementDict(self):
    x = inp.batch_join([{"c": [12, 12]}], batch_size=8)
    self.assertAllEqual((8, 2), x["c"].get_shape().as_list())

  def _testKeepInputHelper(self, num_threads, enqueue_many,
                           keep_input_vector=False):
    with self.test_session() as sess:
      batch_size = 5
      num_batches = 4
      examples = variables.Variable(0)
      counter = examples.count_up_to(num_batches * batch_size * 2)
      sparse_counter = sparse_tensor.SparseTensor(
          indices=array_ops.zeros(
              [1, 1], dtype=dtypes.int64),
          values=array_ops.stack([math_ops.cast(counter, dtypes.float32)]),
          dense_shape=[1])
      to_batch = [counter, sparse_counter, "string"]
      if enqueue_many:
        to_batch = inp.batch(to_batch, 4 if keep_input_vector else 1)
      keep_input = array_ops.squeeze(
          math_ops.equal(0, math_ops.mod(to_batch[0], 2)))
      batched = inp.maybe_batch_join(
          [to_batch] * num_threads,
          keep_input,
          batch_size,
          enqueue_many=enqueue_many)
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      threads = queue_runner_impl.start_queue_runners()

      for _ in range(num_batches):
        results = sess.run(batched)
        self.assertAllEqual([0] * batch_size, np.mod(results[0], 2))
        self.assertAllEqual([0] * batch_size, np.mod(results[1].values, 2))
        self.assertAllEqual([b"string"] * batch_size, results[2])

      # Reached the limit.
      with self.assertRaises(errors_impl.OutOfRangeError):
        sess.run(batched)
      for thread in threads:
        thread.join()

  def testSingleThreadKeepInput(self):
    self._testKeepInputHelper(1, False)

  def testSingleThreadKeepInputEnqueueMany(self):
    self._testKeepInputHelper(1, True)

  def testMultipleThreadKeepInput(self):
    self._testKeepInputHelper(5, False)

  def testMultipleThreadKeepInputEnqueueMany(self):
    self._testKeepInputHelper(5, True)

  def testSingleThreadKeepInputPerExample(self):
    self._testKeepInputHelper(1, True, keep_input_vector=True)

  def testMultipleThreadKeepInputPerExample(self):
    self._testKeepInputHelper(5, True, keep_input_vector=True)

  def testInvalidKeepInputVector(self):
    # Can't have vector `keep_input` with `enqueue_many=False`.
    with self.assertRaisesRegexp(ValueError, "`keep_input` cannot be a vector"):
      inp.maybe_batch_join([[array_ops.zeros(5)]],
                           keep_input=constant_op.constant([True, False]),
                           batch_size=1,
                           enqueue_many=False)
    # Can't have `keep_input` with more than one dimension.
    with self.assertRaisesRegexp(ValueError, "must be 0 or 1 dimensions"):
      inp.maybe_batch_join([[array_ops.zeros(5)]],
                           keep_input=constant_op.constant([[True], [False]]),
                           batch_size=1,
                           enqueue_many=True)
    # `keep_input` must have dimensions determined at graph construction.
    with self.assertRaisesRegexp(ValueError,
                                 "must be known at graph construction"):
      inp.maybe_batch_join([[array_ops.zeros(5)]],
                           keep_input=array_ops.placeholder(dtypes.bool),
                           batch_size=1,
                           enqueue_many=True)

  def testMaybeBatchedSparseTensorInferredShape(self):
    sparse = sparse_tensor.SparseTensor(
        indices=[[0]], values=[1.0], dense_shape=[1])
    self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
    batched = inp.maybe_batch_join([[sparse]], keep_input=True, batch_size=2)
    self.assertAllEqual((2,), batched.dense_shape.get_shape().as_list())

  def testMaybeBatchedSparseTensorInferredShapeEnqueueMany(self):
    sparse = sparse_tensor.SparseTensor(
        indices=[[0]], values=[1.0], dense_shape=[1])
    self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
    batched = inp.maybe_batch_join(
        [[sparse]], keep_input=True, batch_size=2, enqueue_many=True)
    self.assertAllEqual((1,), batched.dense_shape.get_shape().as_list())

  def testMaybeBatchedSparseTensorInferredShapeEnqueueManyPerExample(self):
    sparse = sparse_tensor.SparseTensor(
        indices=[[0], [0]], values=[1.0, 2.0], dense_shape=[2])
    self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
    batched = inp.maybe_batch_join(
        [[sparse]], keep_input=[True, False], batch_size=2, enqueue_many=True)
    self.assertAllEqual((1,), batched.dense_shape.get_shape().as_list())

  def testMaybeBatchedSparseTensorInferredShapeUnknownRank(self):
    sparse = sparse_tensor.SparseTensor(
        indices=array_ops.placeholder(dtypes.int64),
        values=array_ops.placeholder(dtypes.float32),
        dense_shape=array_ops.placeholder(dtypes.int64))
    self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
    batched = inp.maybe_batch_join([[sparse]], keep_input=True, batch_size=2)
    self.assertIs(None, batched.dense_shape.get_shape().num_elements())

  def testMaybeBatchedSparseTensorInferredShapeUnknownRankEnqueueMany(self):
    sparse = sparse_tensor.SparseTensor(
        indices=array_ops.placeholder(dtypes.int64),
        values=array_ops.placeholder(dtypes.float32),
        dense_shape=array_ops.placeholder(dtypes.int64))
    self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
    batched = inp.maybe_batch_join(
        [[sparse]], keep_input=True, batch_size=2, enqueue_many=True)
    self.assertIs(None, batched.dense_shape.get_shape().num_elements())

  def testMaybeBatchedSparseTensorInferredShapeUnknownRankPerExample(self):
    sparse = sparse_tensor.SparseTensor(
        indices=array_ops.placeholder(dtypes.int64),
        values=array_ops.placeholder(dtypes.float32),
        dense_shape=array_ops.placeholder(dtypes.int64))
    self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
    batched = inp.maybe_batch_join(
        [[sparse]], keep_input=[True, False], batch_size=2, enqueue_many=True)
    self.assertIs(None, batched.dense_shape.get_shape().num_elements())

  def testMaybeBatchCorrectValues(self):
    sparse = sparse_tensor.SparseTensor(
        indices=[[0, 1], [0, 2], [1, 0], [1, 3]],
        dense_shape=[2, 4],
        values=[5, 4, 7, 2])
    keep = constant_op.constant([True, False])
    batched = inp.maybe_batch_join(
        [[sparse]], keep_input=keep, batch_size=1, enqueue_many=True)

    with self.test_session():
      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(coord=coord)

      batched_np = batched.eval()
  def testMaybeBatchedSparseTensorInferredShape(self):
    sparse = sparse_tensor.SparseTensor(
        indices=[[0]], values=[1.0], dense_shape=[1])
    self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
    batched = inp.maybe_batch_join([[sparse]], keep_input=True, batch_size=2)
    self.assertAllEqual((2,), batched.dense_shape.get_shape().as_list())

  def testMaybeBatchedSparseTensorInferredShapeEnqueueMany(self):
    sparse = sparse_tensor.SparseTensor(
        indices=[[0]], values=[1.0], dense_shape=[1])
    self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
    batched = inp.maybe_batch_join(
        [[sparse]], keep_input=True, batch_size=2, enqueue_many=True)
    self.assertAllEqual((1,), batched.dense_shape.get_shape().as_list())

  def testMaybeBatchedSparseTensorInferredShapeEnqueueManyPerExample(self):
    sparse = sparse_tensor.SparseTensor(
        indices=[[0], [0]], values=[1.0, 2.0], dense_shape=[2])
    self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
    batched = inp.maybe_batch_join(
        [[sparse]], keep_input=[True, False], batch_size=2, enqueue_many=True)
    self.assertAllEqual((1,), batched.dense_shape.get_shape().as_list())

  def testMaybeBatchedSparseTensorInferredShapeUnknownRank(self):
    sparse = sparse_tensor.SparseTensor(
        indices=array_ops.placeholder(dtypes.int64),
        values=array_ops.placeholder(dtypes.float32),
        dense_shape=array_ops.placeholder(dtypes.int64))
    self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
    batched = inp.maybe_batch_join([[sparse]], keep_input=True, batch_size=2)
    self.assertIs(None, batched.dense_shape.get_shape().num_elements())

  def testMaybeBatchedSparseTensorInferredShapeUnknownRankEnqueueMany(self):
    sparse = sparse_tensor.SparseTensor(
        indices=array_ops.placeholder(dtypes.int64),
        values=array_ops.placeholder(dtypes.float32),
        dense_shape=array_ops.placeholder(dtypes.int64))
    self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
    batched = inp.maybe_batch_join(
        [[sparse]], keep_input=True, batch_size=2, enqueue_many=True)
    self.assertIs(None, batched.dense_shape.get_shape().num_elements())

  def testMaybeBatchedSparseTensorInferredShapeUnknownRankPerExample(self):
    sparse = sparse_tensor.SparseTensor(
        indices=array_ops.placeholder(dtypes.int64),
        values=array_ops.placeholder(dtypes.float32),
        dense_shape=array_ops.placeholder(dtypes.int64))
    self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
    batched = inp.maybe_batch_join(
        [[sparse]], keep_input=[True, False], batch_size=2, enqueue_many=True)
    self.assertIs(None, batched.dense_shape.get_shape().num_elements())

  def testMaybeBatchCorrectValues(self):
    sparse = sparse_tensor.SparseTensor(
        indices=[[0, 1], [0, 2], [1, 0], [1, 3]],
        dense_shape=[2, 4],
        values=[5, 4, 7, 2])
    keep = constant_op.constant([True, False])
    batched = inp.maybe_batch_join(
        [[sparse]], keep_input=keep, batch_size=1, enqueue_many=True)

    with self.test_session():
      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(coord=coord)

      batched_np = batched.eval()

      coord.request_stop()
      for thread in threads:
        thread.join()

    # Only the first example was kept, so the batch contains just its row.
    self.assertAllEqual([[0, 1], [0, 2]], batched_np.indices)
    self.assertAllEqual([5, 4], batched_np.values)
    self.assertAllEqual([1, 4], batched_np.dense_shape)


class ShuffleBatchTest(test_lib.TestCase):

  def _testOneThreadHelper(self, use_dict):
    with self.test_session() as sess:
      batch_size = 10
      num_batches = 3
      zero64 = constant_op.constant(0, dtype=dtypes.int64)
      examples = variables.Variable(zero64)
      counter = examples.count_up_to(num_batches * batch_size)
      sparse_counter = sparse_tensor.SparseTensor(
          indices=array_ops.reshape(zero64, [1, 1]),
          values=array_ops.stack([math_ops.cast(counter, dtypes.float32)]),
          dense_shape=[1])
      if use_dict:
        batched = inp.shuffle_batch(
            {
                "c": counter,
                "s": sparse_counter,
                "S": "string"
            },
            batch_size=batch_size,
            capacity=32,
            min_after_dequeue=16,
            seed=141421)
        batched_fetch = [batched["c"], batched["s"], batched["S"]]
      else:
        batched = inp.shuffle_batch(
            [counter, sparse_counter, "string"],
            batch_size=batch_size,
            capacity=32,
            min_after_dequeue=16,
            seed=141421)
        batched_fetch = batched
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      threads = queue_runner_impl.start_queue_runners()

      all_counts = []
      for _ in range(num_batches):
        results = sess.run(batched_fetch)
        self.assertEqual(len(results[0]), batch_size)
        all_counts.extend(results[0])
        self.assertAllEqual(
            results[1].indices,
            np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
        self.assertAllEqual(results[0], results[1].values)
        self.assertAllEqual(results[1].dense_shape, [batch_size, 1])
        self.assertAllEqual(results[2], [b"string"] * batch_size)
      # Results scrambled, but include all the expected numbers.
      deltas = [
          all_counts[i + 1] - all_counts[i] for i in range(len(all_counts) - 1)
      ]
      self.assertFalse(all(d == deltas[0] for d in deltas))
      self.assertItemsEqual(all_counts, range(num_batches * batch_size))

      # Reached the limit.
      with self.assertRaises(errors_impl.OutOfRangeError):
        sess.run(batched_fetch)
      for thread in threads:
        thread.join()

  def testOneThread(self):
    self._testOneThreadHelper(use_dict=False)

  def testOneThreadDict(self):
    self._testOneThreadHelper(use_dict=True)

  def testOneThreadSmallerBatch(self):
    with self.test_session() as sess:
      batch_size = 10
      num_batches = 3
      extra_elements = 5
      zero64 = constant_op.constant(0, dtype=dtypes.int64)
      examples = variables.Variable(zero64)
      total_elements = num_batches * batch_size + extra_elements
      counter = examples.count_up_to(total_elements)
      sparse_counter = sparse_tensor.SparseTensor(
          indices=array_ops.reshape(zero64, [1, 1]),
          values=array_ops.stack([math_ops.cast(counter, dtypes.float32)]),
          dense_shape=[1])
      batched = inp.shuffle_batch(
          [counter, sparse_counter, "string"],
          batch_size=batch_size,
          capacity=32,
          min_after_dequeue=16,
          seed=141421,
          allow_smaller_final_batch=True)
      batched_fetch = batched
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      threads = queue_runner_impl.start_queue_runners()

      all_counts = []
      for _ in range(num_batches):
        results = sess.run(batched_fetch)
        self.assertEqual(len(results[0]), batch_size)
        all_counts.extend(results[0])
        self.assertAllEqual(
            results[1].indices,
            np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
        self.assertAllEqual(results[0], results[1].values)
        self.assertAllEqual(results[1].dense_shape, [batch_size, 1])
        self.assertAllEqual(results[2], [b"string"] * batch_size)

      # Reached the final batch with extra elements.
      results = sess.run(batched)
      self.assertAllEqual(results[1].dense_shape, [extra_elements, 1])
      self.assertAllEqual(results[2], [b"string"] * extra_elements)
      all_counts.extend(results[0])

      # Results scrambled, but include all the expected numbers.
      deltas = [
          all_counts[i + 1] - all_counts[i] for i in range(len(all_counts) - 1)
      ]
      self.assertFalse(all(d == deltas[0] for d in deltas))
      self.assertItemsEqual(all_counts, range(total_elements))

      # Reached the limit.
      with self.assertRaises(errors_impl.OutOfRangeError):
        sess.run(batched_fetch)
      for thread in threads:
        thread.join()

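  # With `num_threads` > 1, several threads feed the same shuffling queue, so
  # the interleaving of enqueued elements is nondeterministic; the tests below
  # therefore check only the multiset of delivered values, not their order.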
  def testManyThreads(self):
    with self.test_session() as sess:
      batch_size = 10
      num_batches = 3
      zero64 = constant_op.constant(0, dtype=dtypes.int64)
      examples = variables.Variable(zero64)
      counter = examples.count_up_to(num_batches * batch_size)
      sparse_counter = sparse_tensor.SparseTensor(
          indices=array_ops.reshape(zero64, [1, 1]),
          values=array_ops.stack([math_ops.cast(counter, dtypes.float32)]),
          dense_shape=[1])
      batched = inp.shuffle_batch(
          [counter, sparse_counter, "string"],
          batch_size=batch_size,
          capacity=32,
          min_after_dequeue=16,
          seed=173205,
          num_threads=4)
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      threads = queue_runner_impl.start_queue_runners()

      all_counts = []
      for i in range(num_batches):
        results = sess.run(batched)
        tf_logging.info("Batch %d: %s", i, results[0])
        self.assertEqual(len(results[0]), batch_size)
        all_counts.extend(results[0])
        self.assertAllEqual(
            results[1].indices,
            np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
        self.assertAllEqual(results[0], results[1].values)
        self.assertAllEqual(results[1].dense_shape, [batch_size, 1])
        self.assertAllEqual(results[2], [b"string"] * batch_size)
      # Results scrambled, but include all the expected numbers.
      deltas = [
          all_counts[i + 1] - all_counts[i] for i in range(len(all_counts) - 1)
      ]
      self.assertFalse(all(d == deltas[0] for d in deltas))
      self.assertItemsEqual(all_counts, range(num_batches * batch_size))

      # Reached the limit.
      with self.assertRaises(errors_impl.OutOfRangeError):
        sess.run(batched)
      for thread in threads:
        thread.join()

  def testManyThreadsSmallerBatch(self):
    with self.test_session() as sess:
      batch_size = 10
      num_batches = 3
      extra_elements = 5
      zero64 = constant_op.constant(0, dtype=dtypes.int64)
      examples = variables.Variable(zero64)
      total_elements = num_batches * batch_size + extra_elements
      counter = examples.count_up_to(total_elements)
      sparse_counter = sparse_tensor.SparseTensor(
          indices=array_ops.reshape(zero64, [1, 1]),
          values=array_ops.stack([math_ops.cast(counter, dtypes.float32)]),
          dense_shape=[1])
      batched = inp.shuffle_batch(
          [counter, sparse_counter, "string"],
          batch_size=batch_size,
          capacity=32,
          min_after_dequeue=16,
          seed=173205,
          num_threads=4,
          allow_smaller_final_batch=True)
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      threads = queue_runner_impl.start_queue_runners()

      all_counts = []
      for i in range(num_batches):
        results = sess.run(batched)
        tf_logging.info("Batch %d: %s", i, results[0])
        self.assertEqual(len(results[0]), batch_size)
        all_counts.extend(results[0])
        self.assertAllEqual(
            results[1].indices,
            np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
        self.assertAllEqual(results[0], results[1].values)
        self.assertAllEqual(results[1].dense_shape, [batch_size, 1])
        self.assertAllEqual(results[2], [b"string"] * batch_size)

      # Reached the final batch with extra elements.
      results = sess.run(batched)
      self.assertAllEqual(results[0].shape, [extra_elements])
      self.assertAllEqual(results[1].dense_shape, [extra_elements, 1])
      self.assertAllEqual(results[2], [b"string"] * extra_elements)
      all_counts.extend(results[0])

      # Results scrambled, but include all the expected numbers.
      deltas = [
          all_counts[i + 1] - all_counts[i] for i in range(len(all_counts) - 1)
      ]
      self.assertFalse(all(d == deltas[0] for d in deltas))
      self.assertItemsEqual(all_counts, range(total_elements))

      # Reached the limit.
      with self.assertRaises(errors_impl.OutOfRangeError):
        sess.run(batched)
      for thread in threads:
        thread.join()

  def testSharedName(self):
    with self.test_session():
      batch_size = 10
      num_batches = 3
      zero64 = constant_op.constant(0, dtype=dtypes.int64)
      examples = variables.Variable(zero64)
      counter = examples.count_up_to(num_batches * batch_size)
      batched = inp.shuffle_batch(
          [counter, "string"],
          batch_size=batch_size,
          capacity=32,
          min_after_dequeue=10,
          shared_name="SHARED_NAME_XYZ",
          name="Q")

      self.assertProtoEquals(
          "s: 'SHARED_NAME_XYZ'",
          batched[0].op.inputs[0].op.node_def.attr["shared_name"])

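  # `maybe_shuffle_batch` is called positionally below; a sketch of the
  # keyword-argument equivalent, assuming the argument order used throughout
  # these tests:
  #
  #   inp.maybe_shuffle_batch(to_batch, batch_size=batch_size, capacity=10,
  #                           min_after_dequeue=1, keep_input=keep_input)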
  def _testKeepInputHelper(self, num_threads, enqueue_many,
                           keep_input_vector=False):
    with self.test_session() as sess:
      batch_size = 5
      num_batches = 4
      examples = variables.Variable(0)
      counter = examples.count_up_to(num_batches * batch_size * 2)
      sparse_counter = sparse_tensor.SparseTensor(
          indices=array_ops.zeros([1, 1], dtype=dtypes.int64),
          values=array_ops.stack([math_ops.cast(counter, dtypes.float32)]),
          dense_shape=[1])
      to_batch = [counter, sparse_counter, "string"]
      if enqueue_many:
        to_batch = inp.batch(to_batch, 4 if keep_input_vector else 1)
      # Keep only the even counter values.
      keep_input = array_ops.squeeze(
          math_ops.equal(0, math_ops.mod(to_batch[0], 2)))
      batched = inp.maybe_shuffle_batch(
          to_batch,
          batch_size,
          10,
          1,
          keep_input,
          num_threads=num_threads,
          enqueue_many=enqueue_many)
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      threads = queue_runner_impl.start_queue_runners()

      for _ in range(num_batches):
        results = sess.run(batched)
        self.assertAllEqual([0] * batch_size, np.mod(results[0], 2))
        self.assertAllEqual([0] * batch_size, np.mod(results[1].values, 2))
        self.assertAllEqual([b"string"] * batch_size, results[2])

      # Reached the limit.
      with self.assertRaises(errors_impl.OutOfRangeError):
        sess.run(batched)
      for thread in threads:
        thread.join()

  def testSingleThreadKeepInput(self):
    self._testKeepInputHelper(1, False)

  def testSingleThreadKeepInputEnqueueMany(self):
    self._testKeepInputHelper(1, True)

  def testMultipleThreadKeepInput(self):
    self._testKeepInputHelper(5, False)

  def testMultipleThreadKeepInputEnqueueMany(self):
    self._testKeepInputHelper(5, True)

  def testSingleThreadKeepInputPerExample(self):
    self._testKeepInputHelper(1, True, keep_input_vector=True)

  def testMultipleThreadKeepInputPerExample(self):
    self._testKeepInputHelper(5, True, keep_input_vector=True)

  def testInvalidKeepInputVector(self):
    # Can't have vector `keep_input` with `enqueue_many=False`.
    with self.assertRaisesRegexp(ValueError, "`keep_input` cannot be a vector"):
      inp.maybe_shuffle_batch([array_ops.zeros(5)], 1, 10, 1,
                              keep_input=constant_op.constant([True, False]),
                              enqueue_many=False)
    # Can't have `keep_input` with more than one dimension.
    with self.assertRaisesRegexp(ValueError, "must be 0 or 1 dimensions"):
      inp.maybe_shuffle_batch([array_ops.zeros(5)], 1, 10, 1,
                              keep_input=constant_op.constant([[True]]),
                              enqueue_many=True)
    # `keep_input` must have dimensions determined at graph construction.
    with self.assertRaisesRegexp(ValueError,
                                 "must be known at graph construction"):
      inp.maybe_shuffle_batch([array_ops.zeros(5)], 1, 10, 1,
                              keep_input=array_ops.placeholder(dtypes.bool),
                              enqueue_many=True)

  def testMaybeBatchedSparseTensorInferredShape(self):
    sparse = sparse_tensor.SparseTensor(
        indices=[[0]], values=[1.0], dense_shape=[1])
    self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
    batched = inp.maybe_shuffle_batch([sparse], 2, 10, 1, True)
    self.assertAllEqual((2,), batched.dense_shape.get_shape().as_list())

  def testMaybeBatchedSparseTensorInferredShapeEnqueueMany(self):
    sparse = sparse_tensor.SparseTensor(
        indices=[[0]], values=[1.0], dense_shape=[1])
    self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
    batched = inp.maybe_shuffle_batch(
        [sparse], 2, 10, 1, True, enqueue_many=True)
    self.assertAllEqual((1,), batched.dense_shape.get_shape().as_list())

  def testMaybeBatchedSparseTensorInferredShapeEnqueueManyPerExample(self):
    sparse = sparse_tensor.SparseTensor(
        indices=[[0], [0]], values=[1.0, 2.0], dense_shape=[2])
    self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
    batched = inp.maybe_shuffle_batch(
        [sparse], 2, 10, 1, [True, False], enqueue_many=True)
    self.assertAllEqual((1,), batched.dense_shape.get_shape().as_list())

  def testMaybeBatchedSparseTensorInferredShapeUnknownRank(self):
    sparse = sparse_tensor.SparseTensor(
        indices=array_ops.placeholder(dtypes.int64),
        values=array_ops.placeholder(dtypes.float32),
        dense_shape=array_ops.placeholder(dtypes.int64))
    self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
    batched = inp.maybe_shuffle_batch([sparse], 2, 10, 1, True)
    self.assertIs(None, batched.dense_shape.get_shape().num_elements())

  def testMaybeBatchedSparseTensorInferredShapeUnknownRankEnqueueMany(self):
    sparse = sparse_tensor.SparseTensor(
        indices=array_ops.placeholder(dtypes.int64),
        values=array_ops.placeholder(dtypes.float32),
        dense_shape=array_ops.placeholder(dtypes.int64))
    self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
    batched = inp.maybe_shuffle_batch(
        [sparse], 2, 10, 1, True, enqueue_many=True)
    self.assertIs(None, batched.dense_shape.get_shape().num_elements())

  def testMaybeBatchedSparseTensorInferredShapeUnknownRankPerExample(self):
    sparse = sparse_tensor.SparseTensor(
        indices=array_ops.placeholder(dtypes.int64),
        values=array_ops.placeholder(dtypes.float32),
        dense_shape=array_ops.placeholder(dtypes.int64))
    self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
    batched = inp.maybe_shuffle_batch(
        [sparse], 2, 10, 1, [True, False], enqueue_many=True)
    self.assertIs(None, batched.dense_shape.get_shape().num_elements())


class ShuffleBatchJoinTest(test_lib.TestCase):

  def _testTwoThreadsHelper(self, use_dict):
    with self.test_session() as sess:
      # Two threads, the first generates (0..24, "a").
      num_a = 25
      zero64 = constant_op.constant(0, dtype=dtypes.int64)
      examples = variables.Variable(zero64)
      counter = examples.count_up_to(num_a)
      sparse_counter = sparse_tensor.SparseTensor(
          indices=array_ops.reshape(zero64, [1, 1]),
          values=array_ops.stack([math_ops.cast(counter, dtypes.float32)]),
          dense_shape=[1])

      # The second generates (99, "b") 35 times and then stops.
      num_b = 35
      ninety_nine = inp.limit_epochs(
          constant_op.constant(99, dtype=dtypes.int64), num_b)
      sparse_ninety_nine = sparse_tensor.SparseTensor(
          indices=array_ops.reshape(zero64, [1, 1]),
          values=array_ops.stack([math_ops.cast(ninety_nine, dtypes.float32)]),
          dense_shape=[1])

      # These get joined together and grouped into batches of 5.
      batch_size = 5
      if use_dict:
        batched = inp.shuffle_batch_join(
            [{
                "c": counter,
                "s": sparse_counter,
                "S": "a"
            }, {
                "c": ninety_nine,
                "s": sparse_ninety_nine,
                "S": "b"
            }],
            batch_size=batch_size,
            capacity=32,
            min_after_dequeue=16,
            seed=223607)
        batched_fetch = [batched["c"], batched["s"], batched["S"]]
      else:
        batched = inp.shuffle_batch_join(
            [[counter, sparse_counter, "a"],
             [ninety_nine, sparse_ninety_nine, "b"]],
            batch_size=batch_size,
            capacity=32,
            min_after_dequeue=16,
            seed=223607)
        batched_fetch = batched

      # Shapes.
      self.assertEqual(3, len(batched_fetch))
      self.assertAllEqual((batch_size,),
                          batched_fetch[0].get_shape().as_list())
      self.assertAllEqual((None, 2),
                          batched_fetch[1].indices.get_shape().as_list())
      self.assertAllEqual((None,),
                          batched_fetch[1].values.get_shape().as_list())
      self.assertAllEqual((2,),
                          batched_fetch[1].dense_shape.get_shape().as_list())
      self.assertAllEqual((batch_size,),
                          batched_fetch[2].get_shape().as_list())

      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      threads = queue_runner_impl.start_queue_runners()

      # Should see the "a" and "b" threads mixed together.
      all_a = []
      seen_b = 0
      saw_both = 0
      num_batches = (num_a + num_b) // batch_size
      for _ in range(num_batches):
        results = sess.run(batched_fetch)
        self.assertEqual(3, len(results))
        self.assertEqual(len(results[0]), batch_size)
        self.assertEqual(len(results[2]), batch_size)
        self.assertAllEqual(results[0], results[1].values)
        self.assertAllEqual(
            results[1].indices,
            np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
        self.assertAllEqual(results[1].dense_shape, [batch_size, 1])
        which_a = [i for i, s in enumerate(results[2]) if s == b"a"]
        which_b = [i for i, s in enumerate(results[2]) if s == b"b"]
        self.assertEqual(len(which_a) + len(which_b), batch_size)
        if which_a and which_b:
          saw_both += 1
        all_a.extend([results[0][i] for i in which_a])
        seen_b += len(which_b)
        self.assertAllEqual([99] * len(which_b),
                            [results[0][i] for i in which_b])

      # Some minimum level of mixing of the results of both threads.
      self.assertGreater(saw_both, 1)

      # Saw all the items from "a", but scrambled.
      self.assertItemsEqual(all_a, range(num_a))
      deltas = [all_a[i + 1] - all_a[i] for i in range(len(all_a) - 1)]
      self.assertFalse(all(d == deltas[0] for d in deltas))
      self.assertEqual(seen_b, num_b)

      # Reached the limit.
      with self.assertRaises(errors_impl.OutOfRangeError):
        sess.run(batched_fetch)
      for thread in threads:
        thread.join()

  def testTwoThreads(self):
    self._testTwoThreadsHelper(use_dict=False)

  def testTwoThreadsDict(self):
    self._testTwoThreadsHelper(use_dict=True)

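  # With `allow_smaller_final_batch=True`, the static batch dimension is
  # unknown (`None`) rather than `batch_size`, since the final batch may be
  # short; the `get_shape()` assertions below reflect that.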
  def testTwoThreadsSmallerBatch(self):
    with self.test_session() as sess:
      # Two threads, the first generates (0..26, "a").
      extra_elements = 2
      num_a = 25 + extra_elements
      zero64 = constant_op.constant(0, dtype=dtypes.int64)
      examples = variables.Variable(zero64)
      counter = examples.count_up_to(num_a)
      sparse_counter = sparse_tensor.SparseTensor(
          indices=array_ops.reshape(zero64, [1, 1]),
          values=array_ops.stack([math_ops.cast(counter, dtypes.float32)]),
          dense_shape=[1])

      # The second generates (99, "b") 37 times and then stops.
      num_b = 35 + extra_elements
      ninety_nine = inp.limit_epochs(
          constant_op.constant(99, dtype=dtypes.int64), num_b)
      sparse_ninety_nine = sparse_tensor.SparseTensor(
          indices=array_ops.reshape(zero64, [1, 1]),
          values=array_ops.stack([math_ops.cast(ninety_nine, dtypes.float32)]),
          dense_shape=[1])

      # These get joined together and grouped into batches of 5.
      batch_size = 5
      batched = inp.shuffle_batch_join(
          [[counter, sparse_counter, "a"],
           [ninety_nine, sparse_ninety_nine, "b"]],
          batch_size=batch_size,
          capacity=32,
          min_after_dequeue=16,
          seed=223607,
          allow_smaller_final_batch=True)

      # Shapes.
      self.assertEqual(3, len(batched))
      self.assertAllEqual((None,), batched[0].get_shape().as_list())
      self.assertAllEqual((None, 2), batched[1].indices.get_shape().as_list())
      self.assertAllEqual((None,), batched[1].values.get_shape().as_list())
      self.assertAllEqual((2,), batched[1].dense_shape.get_shape().as_list())
      self.assertAllEqual((None,), batched[2].get_shape().as_list())

      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      threads = queue_runner_impl.start_queue_runners()

      # Should see the "a" and "b" threads mixed together.
      all_a = []
      seen_b = 0
      saw_both = 0
      num_batches = (num_a + num_b) // batch_size
      for i in range(num_batches):
        results = sess.run(batched)
        tf_logging.info("Batch %d: %s", i, results[0])
        self.assertEqual(len(results[0]), batch_size)
        self.assertEqual(len(results[2]), batch_size)
        self.assertAllEqual(results[0], results[1].values)
        self.assertAllEqual(
            results[1].indices,
            np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
        self.assertAllEqual(results[1].dense_shape, [batch_size, 1])
        which_a = [i for i, s in enumerate(results[2]) if s == b"a"]
        which_b = [i for i, s in enumerate(results[2]) if s == b"b"]
        self.assertEqual(len(which_a) + len(which_b), batch_size)
        if which_a and which_b:
          saw_both += 1
        all_a.extend([results[0][i] for i in which_a])
        seen_b += len(which_b)
        self.assertAllEqual([99] * len(which_b),
                            [results[0][i] for i in which_b])

      # Reached end with 2 * extra_elements left.
      results = sess.run(batched)
      self.assertEqual(len(results[0]), 2 * extra_elements)
      self.assertAllEqual(results[1].dense_shape, [2 * extra_elements, 1])
      self.assertEqual(len(results[2]), 2 * extra_elements)
      self.assertAllEqual(results[0], results[1].values)
      self.assertAllEqual(results[1].indices,
                          np.vstack((np.arange(2 * extra_elements),
                                     np.zeros(2 * extra_elements))).T)
      which_a = [i for i, s in enumerate(results[2]) if s == b"a"]
      which_b = [i for i, s in enumerate(results[2]) if s == b"b"]
      self.assertEqual(len(which_a) + len(which_b), 2 * extra_elements)
      if which_a and which_b:
        saw_both += 1
      all_a.extend([results[0][i] for i in which_a])
      seen_b += len(which_b)

      # Some minimum level of mixing of the results of both threads.
      self.assertGreater(saw_both, 1)

      # Saw all the items from "a", but scrambled, including extras.
      self.assertItemsEqual(all_a, range(num_a))
      deltas = [all_a[i + 1] - all_a[i] for i in range(len(all_a) - 1)]
      self.assertFalse(all(d == deltas[0] for d in deltas))
      self.assertEqual(seen_b, num_b)

      # Reached the limit.
      with self.assertRaises(errors_impl.OutOfRangeError):
        sess.run(batched)
      for thread in threads:
        thread.join()

  def testMismatchedDictKeys(self):
    with self.assertRaisesRegexp(ValueError, "must have the same keys"):
      inp.shuffle_batch_join(
          [{
              "c": 12,
              "s": 123,
              "S": "a"
          }, {
              "cool": -12,
              "s": 99,
              "S": "b"
          }],
          batch_size=8,
          capacity=32,
          min_after_dequeue=16,
          seed=223607)

  def testSharedName(self):
    with self.test_session():
      batch_size = 10
      num_batches = 3
      zero64 = constant_op.constant(0, dtype=dtypes.int64)
      examples = variables.Variable(zero64)
      counter = examples.count_up_to(num_batches * batch_size)
      batched = inp.shuffle_batch_join(
          [[counter, "string"]],
          batch_size=batch_size,
          capacity=32,
          min_after_dequeue=10,
          shared_name="SHARED_NAME_XYZ",
          name="Q")

      # Shapes.
      self.assertEqual(2, len(batched))
      self.assertAllEqual((batch_size,), batched[0].get_shape().as_list())
      self.assertAllEqual((batch_size,), batched[1].get_shape().as_list())

      self.assertProtoEquals(
          "s: 'SHARED_NAME_XYZ'",
          batched[0].op.inputs[0].op.node_def.attr["shared_name"])

  def _testKeepInputHelper(self, num_threads, enqueue_many,
                           keep_input_vector=False):
    with self.test_session() as sess:
      batch_size = 5
      num_batches = 4
      examples = variables.Variable(0)
      counter = examples.count_up_to(num_batches * batch_size * 2)
      sparse_counter = sparse_tensor.SparseTensor(
          indices=array_ops.zeros([1, 1], dtype=dtypes.int64),
          values=array_ops.stack([math_ops.cast(counter, dtypes.float32)]),
          dense_shape=[1])
      to_batch = [counter, sparse_counter, "string"]
      if enqueue_many:
        to_batch = inp.batch(to_batch, 4 if keep_input_vector else 1)
      # Keep only the even counter values.
      keep_input = array_ops.squeeze(
          math_ops.equal(0, math_ops.mod(to_batch[0], 2)))
      batched = inp.maybe_shuffle_batch_join(
          [to_batch] * num_threads,
          batch_size,
          10,
          1,
          keep_input,
          enqueue_many=enqueue_many)
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      threads = queue_runner_impl.start_queue_runners()

      for _ in range(num_batches):
        results = sess.run(batched)
        self.assertAllEqual([0] * batch_size, np.mod(results[0], 2))
        self.assertAllEqual([0] * batch_size, np.mod(results[1].values, 2))
        self.assertAllEqual([b"string"] * batch_size, results[2])

      # Reached the limit.
      with self.assertRaises(errors_impl.OutOfRangeError):
        sess.run(batched)
      for thread in threads:
        thread.join()

  def testSingleThreadKeepInput(self):
    self._testKeepInputHelper(1, False)

  def testSingleThreadKeepInputEnqueueMany(self):
    self._testKeepInputHelper(1, True)

  def testMultipleThreadKeepInput(self):
    self._testKeepInputHelper(5, False)

  def testMultipleThreadKeepInputEnqueueMany(self):
    self._testKeepInputHelper(5, True)

  def testSingleThreadKeepInputPerExample(self):
    self._testKeepInputHelper(1, True, keep_input_vector=True)

  def testMultipleThreadKeepInputPerExample(self):
    self._testKeepInputHelper(5, True, keep_input_vector=True)

  def testInvalidKeepInputVector(self):
    # Can't have vector `keep_input` with `enqueue_many=False`.
    with self.assertRaisesRegexp(ValueError, "`keep_input` cannot be a vector"):
      inp.maybe_shuffle_batch_join(
          [[array_ops.zeros(5)]], 1, 10, 1,
          keep_input=constant_op.constant([True, False]),
          enqueue_many=False)
    # Can't have `keep_input` with more than one dimension.
    with self.assertRaisesRegexp(ValueError, "must be 0 or 1 dimensions"):
      inp.maybe_shuffle_batch_join(
          [[array_ops.zeros(5)]], 1, 10, 1,
          keep_input=constant_op.constant([[True]]),
          enqueue_many=True)
    # `keep_input` must have dimensions determined at graph construction.
    with self.assertRaisesRegexp(ValueError,
                                 "must be known at graph construction"):
      inp.maybe_shuffle_batch_join(
          [[array_ops.zeros(5)]], 1, 10, 1,
          keep_input=array_ops.placeholder(dtypes.bool),
          enqueue_many=True)

  def testMaybeBatchedSparseTensorInferredShape(self):
    sparse = sparse_tensor.SparseTensor(
        indices=[[0]], values=[1.0], dense_shape=[1])
    self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
    batched = inp.maybe_shuffle_batch_join([[sparse]], 2, 10, 1, True)
    self.assertAllEqual((2,), batched.dense_shape.get_shape().as_list())

  def testMaybeBatchedSparseTensorInferredShapeEnqueueMany(self):
    sparse = sparse_tensor.SparseTensor(
        indices=[[0]], values=[1.0], dense_shape=[1])
    self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
    batched = inp.maybe_shuffle_batch_join(
        [[sparse]], 2, 10, 1, True, enqueue_many=True)
    self.assertAllEqual((1,), batched.dense_shape.get_shape().as_list())

  def testMaybeBatchedSparseTensorInferredShapeEnqueueManyPerExample(self):
    sparse = sparse_tensor.SparseTensor(
        indices=[[0], [0]], values=[1.0, 2.0], dense_shape=[2])
    self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
    batched = inp.maybe_shuffle_batch_join(
        [[sparse]], 2, 10, 1, [True, False], enqueue_many=True)
    self.assertAllEqual((1,), batched.dense_shape.get_shape().as_list())

  def testMaybeBatchedSparseTensorInferredShapeUnknownRank(self):
    sparse = sparse_tensor.SparseTensor(
        indices=array_ops.placeholder(dtypes.int64),
        values=array_ops.placeholder(dtypes.float32),
        dense_shape=array_ops.placeholder(dtypes.int64))
    self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
    batched = inp.maybe_shuffle_batch_join([[sparse]], 2, 10, 1, True)
    self.assertIs(None, batched.dense_shape.get_shape().num_elements())

  def testMaybeBatchedSparseTensorInferredShapeUnknownRankEnqueueMany(self):
    sparse = sparse_tensor.SparseTensor(
        indices=array_ops.placeholder(dtypes.int64),
        values=array_ops.placeholder(dtypes.float32),
        dense_shape=array_ops.placeholder(dtypes.int64))
    self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
    batched = inp.maybe_shuffle_batch_join(
        [[sparse]], 2, 10, 1, True, enqueue_many=True)
    self.assertIs(None, batched.dense_shape.get_shape().num_elements())

  def testMaybeBatchedSparseTensorInferredShapeUnknownRankPerExample(self):
    sparse = sparse_tensor.SparseTensor(
        indices=array_ops.placeholder(dtypes.int64),
        values=array_ops.placeholder(dtypes.float32),
        dense_shape=array_ops.placeholder(dtypes.int64))
    self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
    batched = inp.maybe_shuffle_batch_join(
        [[sparse]], 2, 10, 1, [True, False], enqueue_many=True)
    self.assertIs(None, batched.dense_shape.get_shape().num_elements())


if __name__ == "__main__":
  test_lib.main()