# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for training.input."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools
import os

import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test as test_lib
from tensorflow.python.platform import tf_logging
from tensorflow.python.training import coordinator
from tensorflow.python.training import input as inp
from tensorflow.python.training import queue_runner_impl
from tensorflow.python.util import compat


class MatchFilenamesOnceTest(test_lib.TestCase):

  def test(self):
    temp_dir = self.get_temp_dir()
    filenames = [os.path.join(temp_dir, n) for n in os.listdir(temp_dir)]
    additional = [
        os.path.join(temp_dir, "match_filenames.%d" % i)
        for i in range(3)
    ]
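    # Create a few files so the glob patterns below have something to match.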
    for name in additional:
      with open(name, "w") as f:
        f.write("Some contents")
    filenames = list(set(filenames + additional))
    with self.test_session():
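      # Three patterns: "*" matches every file in the directory,
      # "match_filenames.?" matches just the files created above ("?" matches
      # exactly one character), and the last pattern matches a single file.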
      star = inp.match_filenames_once(os.path.join(temp_dir, "*"))
      question = inp.match_filenames_once(
          os.path.join(temp_dir, "match_filenames.?"))
      one = inp.match_filenames_once(additional[1])
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      self.assertItemsEqual(map(compat.as_bytes, filenames), star.eval())
      self.assertItemsEqual(map(compat.as_bytes, additional), question.eval())
      self.assertItemsEqual([compat.as_bytes(additional[1])], one.eval())


class LimitEpochsTest(test_lib.TestCase):

  def testNoLimit(self):
    with self.test_session():
      seven = constant_op.constant(7)
      seven_forever = inp.limit_epochs(seven)
      variables.local_variables_initializer().run()
      for _ in range(100):
        self.assertEqual(7, seven_forever.eval())

  def testLimit(self):
    with self.test_session():
      love_me = constant_op.constant("Love Me")
      love_me_two_times = inp.limit_epochs(love_me, num_epochs=2)
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      self.assertEqual(b"Love Me", love_me_two_times.eval())
      self.assertEqual(b"Love Me", love_me_two_times.eval())
      with self.assertRaises(errors_impl.OutOfRangeError):
        love_me_two_times.eval()


class InputProducerTest(test_lib.TestCase):

  def testNoShuffle(self):
    with self.test_session():
      input_tensor = [[1, 2, 3, 4],
                      [5, 6, 7, 8],
                      [9, 10, 11, 12]]
      num_epochs = 2
      queue = inp.input_producer(
          input_tensor, num_epochs=num_epochs, shuffle=False)
      dequeue_many = queue.dequeue_many(len(input_tensor) * num_epochs)
      dequeue = queue.dequeue()
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      threads = queue_runner_impl.start_queue_runners()

      # No randomness, so just see repeated copies of the input.
      self.assertAllEqual(input_tensor * num_epochs, dequeue_many.eval())

      # Reached the limit.
      with self.assertRaises(errors_impl.OutOfRangeError):
        dequeue.eval()
      for thread in threads:
        thread.join()

  def testNoShapeInference(self):
    with self.test_session():
      # Disable shape inference for the input.
      input_value = [[1, 2, 3, 4],
                     [5, 6, 7, 8],
                     [9, 10, 11, 12]]
      input_tensor = array_ops.placeholder_with_default(input_value, shape=None)
      num_epochs = 2
      queue = inp.input_producer(
          input_tensor, element_shape=[4], num_epochs=num_epochs, shuffle=False)
      dequeue_many = queue.dequeue_many(len(input_value) * num_epochs)
      dequeue = queue.dequeue()
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      threads = queue_runner_impl.start_queue_runners()

      # No randomness, so just see repeated copies of the input.
      self.assertAllEqual(input_value * num_epochs, dequeue_many.eval())

      # Reached the limit.
      with self.assertRaises(errors_impl.OutOfRangeError):
        dequeue.eval()
      for thread in threads:
        thread.join()

  def testShapeError(self):
    input_tensor = array_ops.placeholder(dtypes.float32, None)
    with self.assertRaisesRegexp(ValueError, "fully defined shape"):
      _ = inp.input_producer(input_tensor)


class StringInputProducerTest(test_lib.TestCase):

  def testNoShuffle(self):
    with self.test_session():
      strings = [b"to", b"be", b"or", b"not", b"to", b"be"]
      num_epochs = 3
      queue = inp.string_input_producer(
          strings, num_epochs=num_epochs, shuffle=False)
      dequeue_many = queue.dequeue_many(len(strings) * num_epochs)
      dequeue = queue.dequeue()
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      threads = queue_runner_impl.start_queue_runners()

      # No randomness, so just see repeated copies of the input.
      output = dequeue_many.eval()
      self.assertAllEqual(strings * num_epochs, output)

      # Reached the limit.
      with self.assertRaises(errors_impl.OutOfRangeError):
        dequeue.eval()
      for thread in threads:
        thread.join()

  def testShuffle(self):
    with self.test_session():
      strings = [b"a", b"b", b"c"]
      num_epochs = 600
      queue = inp.string_input_producer(
          strings, num_epochs=num_epochs, shuffle=True, seed=271828)
      dequeue_many = queue.dequeue_many(len(strings))
      dequeue = queue.dequeue()
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      threads = queue_runner_impl.start_queue_runners()

      # Validate that we only shuffle the strings within an epoch and
      # count how often each possible order appears.
      expected = [b"abc", b"acb", b"bac", b"bca", b"cab", b"cba"]
      frequency = {}
      for e in expected:
        frequency[e] = 0
      for _ in range(num_epochs):
        output = dequeue_many.eval()
        key = b"".join(output)
        self.assertIn(key, expected)
        frequency[key] += 1

      # Expect an approximately even distribution over all possible orders.
      expected_frequency = num_epochs / len(expected)
      margin = expected_frequency * 0.4
      tf_logging.info("Observed counts: %s", frequency)
      for key in expected:
        value = frequency[key]
        self.assertGreater(value, expected_frequency - margin)
        self.assertLess(value, expected_frequency + margin)

      # Reached the limit.
      with self.assertRaises(errors_impl.OutOfRangeError):
        dequeue.eval()
      for thread in threads:
        thread.join()

  def testNullStringPython(self):
    # Graph-construction time check for empty string list:
    with self.test_session():
      with self.assertRaises(ValueError):
        _ = inp.string_input_producer([])

  def testNullString(self):
    # Runtime check for empty string list.  This is slightly oblique:
    # The queue runner should die with an assertion error on the null
    # input tensor, causing the dequeue to fail with an OutOfRangeError.
    with self.test_session():
      coord = coordinator.Coordinator()
      queue = inp.string_input_producer(
          constant_op.constant(
              [], dtype=dtypes.string))
      dequeue = queue.dequeue()
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      threads = queue_runner_impl.start_queue_runners(coord=coord)
      with self.assertRaises(errors_impl.OutOfRangeError):
        dequeue.eval()
      coord.request_stop()
      for thread in threads:
        thread.join()

  def testSharedName(self):
    with self.test_session():
      strings = [b"to", b"be", b"or", b"not", b"to", b"be"]
      queue = inp.string_input_producer(
          strings, shared_name="SHARED_NAME_XYZ", name="Q")
      self.assertProtoEquals("s: 'SHARED_NAME_XYZ'",
                             queue.queue_ref.op.node_def.attr["shared_name"])

  def testConstructionRace(self):
    with self.test_session() as sess:
      strings = [b"to", b"be", b"or", b"not", b"to", b"be"]
      queue = inp.string_input_producer(strings, shuffle=False)
      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord)
      for _ in range(2):
        for string in strings:
          # NOTE(mrry): This is not the recommended way to write
          # dequeuing code (instead you should create a single dequeue
          # op before starting the queue runners, and run it
          # repeatedly), because it leads to concurrent reading and
          # writing of the `tf.Graph` object. However, many users
          # write code this way, so we include this test to ensure
          # that we can support it.
          self.assertEqual(string, sess.run(queue.dequeue()))
      coord.request_stop()
      coord.join(threads)


class RangeInputProducerTest(test_lib.TestCase):

  def testNoShuffle(self):
    with self.test_session():
      num_epochs = 3
      range_size = 5
      queue = inp.range_input_producer(
          range_size, num_epochs=num_epochs, shuffle=False)
      dequeue_many = queue.dequeue_many(range_size * num_epochs)
      dequeue = queue.dequeue()
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      threads = queue_runner_impl.start_queue_runners()

      # No randomness, so just see repeated copies of the input.
      output = dequeue_many.eval()
      self.assertAllEqual(list(xrange(range_size)) * num_epochs, output)

      # Reached the limit.
      with self.assertRaises(errors_impl.OutOfRangeError):
        dequeue.eval()
      for thread in threads:
        thread.join()

  def testShuffle(self):
    with self.test_session():
      num_epochs = 200
      range_size = 2
      queue = inp.range_input_producer(
          range_size, num_epochs=num_epochs, shuffle=True, seed=314159)
      dequeue_many = queue.dequeue_many(range_size)
      dequeue = queue.dequeue()
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      threads = queue_runner_impl.start_queue_runners()

      # Validate that we only shuffle the integers within an epoch and
      # count how often each possible order appears.
      expected = [12, 21]
      frequency = {}
      for e in expected:
        frequency[e] = 0
      for _ in range(num_epochs):
        output = dequeue_many.eval()
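        # Encode the order seen this epoch as a two-digit key:
        # [0, 1] -> 12 and [1, 0] -> 21.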
        key = 10 * (output[0] + 1) + (output[1] + 1)
        self.assertIn(key, expected)
        frequency[key] += 1

      # Expect an approximately even distribution over all possible orders.
      expected_frequency = num_epochs / len(expected)
      margin = expected_frequency * 0.4
      tf_logging.info("Observed counts: %s", frequency)
      for key in expected:
        value = frequency[key]
        self.assertGreater(value, expected_frequency - margin)
        self.assertLess(value, expected_frequency + margin)

      # Reached the limit.
      with self.assertRaises(errors_impl.OutOfRangeError):
        dequeue.eval()
      for thread in threads:
        thread.join()

  def testSharedName(self):
    with self.test_session():
      range_size = 5
      queue = inp.range_input_producer(
          range_size, shared_name="SHARED_NAME_XYZ", name="Q")
      self.assertProtoEquals("s: 'SHARED_NAME_XYZ'",
                             queue.queue_ref.op.node_def.attr["shared_name"])


class SliceInputProducerTest(test_lib.TestCase):

  def testNoShuffle(self):
    with self.test_session() as sess:
      num_epochs = 3
      source_strings = [b"Alpha", b"Beta", b"Delta", b"Gamma"]
      source_ints = [2, 3, 5, 7]
      slices = inp.slice_input_producer(
          [source_strings, source_ints], num_epochs=num_epochs, shuffle=False)
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      threads = queue_runner_impl.start_queue_runners()

      # No randomness, so just see repeated copies of the input.
      num_items = len(source_strings) * num_epochs
      output = [sess.run(slices) for _ in range(num_items)]
      out_strings, out_ints = zip(*output)
      self.assertAllEqual(source_strings * num_epochs, out_strings)
      self.assertAllEqual(source_ints * num_epochs, out_ints)

      # Reached the limit.
      with self.assertRaises(errors_impl.OutOfRangeError):
        sess.run(slices)
      for thread in threads:
        thread.join()

  def testShuffle(self):
    with self.test_session() as sess:
      num_epochs = 1200
      source_strings = ["A", "B", "D", "G"]
      source_ints = [7, 3, 5, 2]
      slices = inp.slice_input_producer(
          [source_strings, source_ints],
          num_epochs=num_epochs,
          shuffle=True,
          seed=161803)
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      threads = queue_runner_impl.start_queue_runners()

      # Validate that we only shuffle the slices within an epoch and
      # count how often each possible order appears.
      expected = [
          b",".join(x)
          for x in itertools.permutations([b"A7", b"B3", b"D5", b"G2"])
      ]
      frequency = {}
      for e in expected:
        frequency[e] = 0
      for _ in range(num_epochs):
        output = [sess.run(slices) for _ in range(len(source_strings))]
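        # Build a key like b"A7,B3,D5,G2" from the (string, int) slices seen
        # this epoch.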
        key = b",".join([s + compat.as_bytes(str(i)) for s, i in output])
        self.assertIn(key, expected)
        frequency[key] += 1

      # Expect an approximately even distribution over all possible orders.
      expected_frequency = num_epochs / len(expected)
      margin = expected_frequency * 0.4
      tf_logging.info("Observed counts: %s", frequency)
      for key in expected:
        value = frequency[key]
        self.assertGreater(value, expected_frequency - margin)
        self.assertLess(value, expected_frequency + margin)

      # Reached the limit.
      with self.assertRaises(errors_impl.OutOfRangeError):
        sess.run(slices)
      for thread in threads:
        thread.join()

  def testSharedName(self):
    with self.test_session():
      source_strings = ["A", "B", "D", "G"]
      source_ints = [7, 3, 5, 2]
      slices = inp.slice_input_producer(
          [source_strings, source_ints],
          shared_name="SHARED_NAME_XYZ",
          name="sip")

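      # Navigate from the slice output back through the graph to the
      # underlying queue op and check its shared_name attribute.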
      self.assertProtoEquals(
          "s: 'SHARED_NAME_XYZ'",
          slices[0].op.inputs[1].op.inputs[0].op.node_def.attr["shared_name"])


class DictHelperTest(test_lib.TestCase):

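  # _as_tensor_list flattens list or dict inputs into a flat list of tensors
  # (dict values in a deterministic key order, as the assertions below show),
  # and _as_original_type inverts the flattening.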
  def testListInputs(self):
    l = [1, 2, 3, 11, 22, 33]
    l2 = inp._as_tensor_list(l)
    self.assertEqual(l, l2)
    l3 = inp._as_original_type(l, l2)
    self.assertEqual(l, l3)

  def testDictInputs(self):
    d = {"a": 1, "b": 2, "c": 3, "aa": 11, "bb": 22, "cc": 33}
    l = inp._as_tensor_list(d)
    self.assertEqual([1, 11, 2, 22, 3, 33], l)
    d2 = inp._as_original_type(d, l)
    self.assertEqual(d, d2)

  def testHeterogeneousKeysDictInputs(self):
    d = {"z": 1, 1: 42, ("a", "b"): 100}
    l = inp._as_tensor_list(d)
    self.assertEqual([100, 42, 1], l)
    d2 = inp._as_original_type(d, l)
    self.assertEqual(d, d2)


class BatchTest(test_lib.TestCase):

  def _testOneThreadHelper(self, use_dict):
    with self.test_session() as sess:
      batch_size = 10
      num_batches = 3
      zero64 = constant_op.constant(0, dtype=dtypes.int64)
      examples = variables.Variable(zero64)
      counter = examples.count_up_to(num_batches * batch_size)
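      # Each example carries a length-2 sparse vector [counter, -counter].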
      sparse_counter = sparse_tensor.SparseTensor(
          indices=array_ops.reshape(
              array_ops.stack([zero64, zero64 + 1]), [2, 1]),
          values=math_ops.cast(
              array_ops.stack([counter, -counter]), dtypes.float32),
          dense_shape=[2])
      if use_dict:
        batched = inp.batch(
            {
                "c": counter,
                "s": sparse_counter,
                "S": "string"
            },
            batch_size=batch_size)
        batched_fetch = [batched["c"], batched["s"], batched["S"]]
      else:
        batched = inp.batch(
            [counter, sparse_counter, "string"], batch_size=batch_size)
        batched_fetch = batched
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      threads = queue_runner_impl.start_queue_runners()

      for i in range(num_batches):
        results = sess.run(batched_fetch)
        self.assertAllEqual(results[0],
                            np.arange(i * batch_size, (i + 1) * batch_size))
        self.assertAllEqual(
            results[1].indices,
            np.vstack((
                np.arange(2 * batch_size) // 2,  # 0, 0, 1, 1, ...
                [0, 1] * batch_size)).T)
        #  [x, -x, x+1, -(x+1), ...]
        expected = np.arange(2 * i * batch_size, 2 * (i + 1) * batch_size) // 2
        expected *= ([1, -1] * batch_size)  # mult by [1, -1, 1, -1, ...]
        self.assertAllEqual(results[1].values, expected)
        self.assertAllEqual(results[1].dense_shape, [batch_size, 2])
        self.assertAllEqual(results[2], [b"string"] * batch_size)

      # Reached the limit.
      with self.assertRaises(errors_impl.OutOfRangeError):
        sess.run(batched_fetch)
      for thread in threads:
        thread.join()

  def testOneThread(self):
    self._testOneThreadHelper(use_dict=False)

  def testOneThreadDict(self):
    self._testOneThreadHelper(use_dict=True)

  def testOneThreadDynamicPad(self):
    with self.test_session() as sess:
      batch_size = 10
      num_batches = 3
      zero64 = constant_op.constant(0, dtype=dtypes.int64)
      examples = variables.Variable(zero64)
      counter = examples.count_up_to(num_batches * batch_size)
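      # Each example repeats "string" `counter` times, so the string examples
      # have varying lengths and dynamic_pad must pad each batch out to its
      # longest element.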
      string = array_ops.tile(["string"],
                              math_ops.to_int32(array_ops.stack([counter])))
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      batched = inp.batch(
          [counter, string], batch_size=batch_size, dynamic_pad=True)
      threads = queue_runner_impl.start_queue_runners()

      for i in range(num_batches):
        results = sess.run(batched)
        expected_results = np.arange(i * batch_size, (i + 1) * batch_size)
        max_len = expected_results[-1]
        self.assertAllEqual(results[0], expected_results)
        expected_strings = [[b"string"] * rep + [b""] * (max_len - rep)
                            for rep in expected_results]
        self.assertAllEqual(results[1], expected_strings)

      # Reached the limit.
      with self.assertRaises(errors_impl.OutOfRangeError):
        sess.run(batched)
      for thread in threads:
        thread.join()

  def testOneThreadEnqueueMany(self):
    with self.test_session() as sess:
      batch_size = 10
      num_batches = 3
      zero64 = constant_op.constant(0, dtype=dtypes.int64)
      examples = variables.Variable(zero64)
      counter = examples.count_up_to(num_batches * batch_size)
      sparse_counter = sparse_tensor.SparseTensor(
          indices=array_ops.reshape(zero64, [1, 1]),
          values=array_ops.stack([math_ops.cast(counter, dtypes.float32)]),
          dense_shape=[1])
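      # Batch into pairs first, then re-batch the pre-batched stream with
      # enqueue_many=True; the result should look the same as batching the
      # original stream directly.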
      pre_batched = inp.batch([counter, sparse_counter, "string"], batch_size=2)
      batched = inp.batch(pre_batched, enqueue_many=True, batch_size=batch_size)
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      threads = queue_runner_impl.start_queue_runners()

      for i in range(num_batches):
        results = sess.run(batched)
        self.assertAllEqual(results[0],
                            np.arange(i * batch_size, (i + 1) * batch_size))
        self.assertAllEqual(
            results[1].indices,
            np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
        self.assertAllEqual(results[1].values,
                            np.arange(i * batch_size, (i + 1) * batch_size))
        self.assertAllEqual(results[1].dense_shape, [batch_size, 1])
        self.assertAllEqual(results[2], [b"string"] * batch_size)

      # Reached the limit.
      with self.assertRaises(errors_impl.OutOfRangeError):
        sess.run(batched)
      for thread in threads:
        thread.join()

  def testManyThreads(self):
    with self.test_session() as sess:
      batch_size = 10
      num_batches = 3
      zero64 = constant_op.constant(0, dtype=dtypes.int64)

      examples = variables.Variable(zero64)
      counter = examples.count_up_to(num_batches * batch_size)
      sparse_counter = sparse_tensor.SparseTensor(
          indices=array_ops.reshape(zero64, [1, 1]),
          values=array_ops.stack([math_ops.cast(counter, dtypes.float32)]),
          dense_shape=[1])
      batched = inp.batch(
          [counter, sparse_counter, "string"],
          batch_size=batch_size,
          num_threads=4)
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      threads = queue_runner_impl.start_queue_runners()

      all_counts = []
      for i in range(num_batches):
        results = sess.run(batched)
        tf_logging.info("Batch %d: %s", i, results[0])
        self.assertEqual(len(results[0]), batch_size)
        self.assertAllEqual(results[0], results[1].values)
        self.assertAllEqual(
            results[1].indices,
            np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
        self.assertAllEqual(results[1].dense_shape, [batch_size, 1])
        all_counts.extend(results[0])
        self.assertAllEqual(results[2], [b"string"] * batch_size)
      self.assertItemsEqual(all_counts, range(num_batches * batch_size))

      # Reached the limit.
      with self.assertRaises(errors_impl.OutOfRangeError):
        sess.run(batched)
      for thread in threads:
        thread.join()

  def testOneThreadSmallerBatch(self):
    with self.test_session() as sess:
      batch_size = 10
      num_batches = 3
      extra_elements = 5
      zero64 = constant_op.constant(0, dtype=dtypes.int64)
      examples = variables.Variable(zero64)
      counter = examples.count_up_to(num_batches * batch_size + extra_elements)
      sparse_counter = sparse_tensor.SparseTensor(
          indices=array_ops.reshape(
              array_ops.stack([zero64, zero64 + 1]), [2, 1]),
          values=math_ops.cast(
              array_ops.stack([counter, -counter]), dtypes.float32),
          dense_shape=[2])
      batched = inp.batch(
          [counter, sparse_counter, "string"],
          batch_size=batch_size,
          allow_smaller_final_batch=True)
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      threads = queue_runner_impl.start_queue_runners()

      for i in range(num_batches):
        results = sess.run(batched)
        self.assertAllEqual(results[0],
                            np.arange(i * batch_size, (i + 1) * batch_size))
        self.assertAllEqual(
            results[1].indices,
            np.vstack((
                np.arange(2 * batch_size) // 2,  # 0, 0, 1, 1, ...
                [0, 1] * batch_size)).T)
        #  [x, -x, x+1, -(x+1), ...]
        expected = np.arange(2 * i * batch_size, 2 * (i + 1) * batch_size) // 2
        expected *= ([1, -1] * batch_size)  # mult by [1, -1, 1, -1, ...]
        self.assertAllEqual(results[1].values, expected)
        self.assertAllEqual(results[1].dense_shape, [batch_size, 2])
        self.assertAllEqual(results[2], [b"string"] * batch_size)

      # Reached the final batch with extra_elements.
      results = sess.run(batched)
      self.assertAllEqual(results[0],
                          np.arange(num_batches * batch_size,
                                    num_batches * batch_size + extra_elements))
      self.assertAllEqual(
          results[1].indices,
          np.vstack((
              np.arange(2 * extra_elements) // 2,  # 0, 0, 1, 1, ...
              [0, 1] * extra_elements)).T)
      self.assertAllEqual(results[1].dense_shape, [extra_elements, 2])
      self.assertAllEqual(results[2], [b"string"] * extra_elements)

      # Reached the limit.
      with self.assertRaises(errors_impl.OutOfRangeError):
        sess.run(batched)
      for thread in threads:
        thread.join()

  def testManyThreadsSmallerBatch(self):
    with self.test_session() as sess:
      batch_size = 10
      num_batches = 3
      extra_elements = 5
      zero64 = constant_op.constant(0, dtype=dtypes.int64)

      examples = variables.Variable(zero64)
      counter = examples.count_up_to(num_batches * batch_size + extra_elements)
      sparse_counter = sparse_tensor.SparseTensor(
          indices=array_ops.reshape(zero64, [1, 1]),
          values=array_ops.stack([math_ops.cast(counter, dtypes.float32)]),
          dense_shape=[1])
      batched = inp.batch(
          [counter, sparse_counter, "string"],
          batch_size=batch_size,
          num_threads=4,
          allow_smaller_final_batch=True)
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      threads = queue_runner_impl.start_queue_runners()

      all_counts = []
      for i in range(num_batches):
        results = sess.run(batched)
        tf_logging.info("Batch %d: %s", i, results[0])
        self.assertEqual(len(results[0]), batch_size)
        self.assertAllEqual(results[0], results[1].values)
        self.assertAllEqual(
            results[1].indices,
            np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
        self.assertAllEqual(results[1].dense_shape, [batch_size, 1])
        all_counts.extend(results[0])
        self.assertAllEqual(results[2], [b"string"] * batch_size)

      # Reached the final batch with extra_elements.
      results = sess.run(batched)
      tf_logging.info("Last Batch: %s", results[0])
      self.assertEqual(len(results[0]), extra_elements)
      self.assertAllEqual(results[0], results[1].values)
      self.assertAllEqual(
          results[1].indices,
          np.vstack((np.arange(extra_elements), np.zeros(extra_elements))).T)
      self.assertAllEqual(results[1].dense_shape, [extra_elements, 1])
      all_counts.extend(results[0])
      self.assertAllEqual(results[2], [b"string"] * extra_elements)
      self.assertItemsEqual(all_counts,
                            range(num_batches * batch_size + extra_elements))

      # Reached the limit.
      with self.assertRaises(errors_impl.OutOfRangeError):
        sess.run(batched)
      for thread in threads:
        thread.join()

  def testSharedName(self):
    with self.test_session():
      batch_size = 10
      num_batches = 3
      zero64 = constant_op.constant(0, dtype=dtypes.int64)
      examples = variables.Variable(zero64)
      counter = examples.count_up_to(num_batches * batch_size)
      batched = inp.batch(
          [counter, "string"],
          batch_size=batch_size,
          shared_name="SHARED_NAME_XYZ",
          name="Q")

      self.assertProtoEquals(
          "s: 'SHARED_NAME_XYZ'",
          batched[0].op.inputs[0].op.node_def.attr["shared_name"])

  def testCannotInferRankError(self):
    with self.test_session():
      x = array_ops.placeholder(dtype=dtypes.int64)
      with self.assertRaisesRegexp(ValueError, "Cannot infer Tensor's rank"):
        inp.batch([x], batch_size=2)

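  # The following shape-inference tests check that batching a rank-1
  # SparseTensor adds a leading batch dimension (its dense_shape becomes a
  # length-2 vector), while enqueue_many=True treats the input as already
  # batched and leaves the rank unchanged.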
  def testBatchedSparseTensorInferredShape(self):
    sparse = sparse_tensor.SparseTensor(
        indices=[[0]], values=[1.0], dense_shape=[1])
    self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
    batched = inp.batch([sparse], batch_size=2)
    self.assertAllEqual((2,), batched.dense_shape.get_shape().as_list())

  def testBatchedSparseTensorInferredShapeEnqueueMany(self):
    sparse = sparse_tensor.SparseTensor(
        indices=[[0]], values=[1.0], dense_shape=[1])
    self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
    batched = inp.batch([sparse], batch_size=2, enqueue_many=True)
    self.assertAllEqual((1,), batched.dense_shape.get_shape().as_list())

  def testBatchedSparseTensorInferredShapeUnknownRank(self):
    sparse = sparse_tensor.SparseTensor(
        indices=array_ops.placeholder(dtypes.int64),
        values=array_ops.placeholder(dtypes.float32),
        dense_shape=array_ops.placeholder(dtypes.int64))
    self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
    batched = inp.batch([sparse], batch_size=2)
    self.assertIs(None, batched.dense_shape.get_shape().num_elements())

  def testBatchedSparseTensorInferredShapeUnknownRankEnqueueMany(self):
    sparse = sparse_tensor.SparseTensor(
        indices=array_ops.placeholder(dtypes.int64),
        values=array_ops.placeholder(dtypes.float32),
        dense_shape=array_ops.placeholder(dtypes.int64))
    self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
    batched = inp.batch([sparse], batch_size=2, enqueue_many=True)
    self.assertIs(None, batched.dense_shape.get_shape().num_elements())

  def testSingleElementDict(self):
    x = inp.batch({"c": [12, 12]}, batch_size=8)
    self.assertAllEqual((8, 2), x["c"].get_shape().as_list())

  def _testKeepInputHelper(self, num_threads, enqueue_many,
                           keep_input_vector=False):
    with self.test_session() as sess:
      batch_size = 5
      num_batches = 4
      examples = variables.Variable(0)
      counter = examples.count_up_to(num_batches * batch_size * 2)
      sparse_counter = sparse_tensor.SparseTensor(
          indices=array_ops.zeros(
              [1, 1], dtype=dtypes.int64),
          values=array_ops.stack([math_ops.cast(counter, dtypes.float32)]),
          dense_shape=[1])
      to_batch = [counter, sparse_counter, "string"]
      if enqueue_many:
        to_batch = inp.batch(to_batch, 4 if keep_input_vector else 1)
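      # Keep only the examples whose counter value is even; every batch below
      # should therefore contain only even counters.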
      keep_input = array_ops.squeeze(
          math_ops.equal(0, math_ops.mod(to_batch[0], 2)))
      batched = inp.maybe_batch(
          to_batch,
          keep_input,
          batch_size,
          num_threads=num_threads,
          enqueue_many=enqueue_many)
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      threads = queue_runner_impl.start_queue_runners()

      for _ in range(num_batches):
        results = sess.run(batched)
        self.assertAllEqual([0] * batch_size, np.mod(results[0], 2))
        self.assertAllEqual([0] * batch_size, np.mod(results[1].values, 2))
        self.assertAllEqual([b"string"] * batch_size, results[2])

      # Reached the limit.
      with self.assertRaises(errors_impl.OutOfRangeError):
        sess.run(batched)
      for thread in threads:
        thread.join()

  def testSingleThreadKeepInput(self):
    self._testKeepInputHelper(1, False)

  def testSingleThreadKeepInputEnqueueMany(self):
    self._testKeepInputHelper(1, True)

  def testMultipleThreadKeepInput(self):
    self._testKeepInputHelper(5, False)

  def testMultipleThreadKeepInputEnqueueMany(self):
    self._testKeepInputHelper(5, True)

  def testMaybeEnqueuePerExample(self):
    self._testKeepInputHelper(1, True, keep_input_vector=True)

  def testMultipleThreadMaybeEnqueuePerExample(self):
    self._testKeepInputHelper(5, True, keep_input_vector=True)

  def testInvalidKeepInputVector(self):
    # Can't have vector `keep_input` with `enqueue_many=False`.
    with self.assertRaisesRegexp(ValueError, "`keep_input` cannot be a vector"):
      inp.maybe_batch([array_ops.zeros(5)],
                      keep_input=constant_op.constant([True, False]),
                      batch_size=1,
                      enqueue_many=False)
    # Can't have `keep_input` with more than one dimension.
    with self.assertRaisesRegexp(ValueError, "must be 0 or 1 dimensions"):
      inp.maybe_batch([array_ops.zeros(5)],
                      keep_input=constant_op.constant([[True], [False]]),
                      batch_size=1,
                      enqueue_many=True)
    # `keep_input` must have dimensions determined at graph construction.
    with self.assertRaisesRegexp(ValueError,
                                 "must be known at graph construction"):
      inp.maybe_batch([array_ops.zeros(5)],
                      keep_input=array_ops.placeholder(dtypes.bool),
                      batch_size=1,
                      enqueue_many=True)

  def testMaybeBatchedSparseTensorInferredShape(self):
    sparse = sparse_tensor.SparseTensor(
        indices=[[0]], values=[1.0], dense_shape=[1])
    self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
    batched = inp.maybe_batch([sparse], keep_input=True, batch_size=2)
    self.assertAllEqual((2,), batched.dense_shape.get_shape().as_list())

  def testMaybeBatchedSparseTensorInferredShapeEnqueueMany(self):
    sparse = sparse_tensor.SparseTensor(
        indices=[[0]], values=[1.0], dense_shape=[1])
    self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
    batched = inp.maybe_batch(
        [sparse], keep_input=True, batch_size=2, enqueue_many=True)
    self.assertAllEqual((1,), batched.dense_shape.get_shape().as_list())

  def testMaybeBatchedSparseTensorInferredShapeEnqueueManyPerExample(self):
    sparse = sparse_tensor.SparseTensor(
        indices=[[0], [0]], values=[1.0, 2.0], dense_shape=[2])
    self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
    batched = inp.maybe_batch(
        [sparse], keep_input=[True, False], batch_size=2, enqueue_many=True)
    self.assertAllEqual((1,), batched.dense_shape.get_shape().as_list())

  def testMaybeBatchedSparseTensorInferredShapeUnknownRank(self):
    sparse = sparse_tensor.SparseTensor(
        indices=array_ops.placeholder(dtypes.int64),
        values=array_ops.placeholder(dtypes.float32),
        dense_shape=array_ops.placeholder(dtypes.int64))
    self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
    batched = inp.maybe_batch([sparse], keep_input=True, batch_size=2)
    self.assertIs(None, batched.dense_shape.get_shape().num_elements())

  def testMaybeBatchedSparseTensorInferredShapeUnknownRankEnqueueMany(self):
    sparse = sparse_tensor.SparseTensor(
        indices=array_ops.placeholder(dtypes.int64),
        values=array_ops.placeholder(dtypes.float32),
        dense_shape=array_ops.placeholder(dtypes.int64))
    self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
    batched = inp.maybe_batch(
        [sparse], keep_input=True, batch_size=2, enqueue_many=True)
    self.assertIs(None, batched.dense_shape.get_shape().num_elements())

  def testMaybeBatchedSparseTensorInferredShapeUnknownRankPerExample(self):
    sparse = sparse_tensor.SparseTensor(
        indices=array_ops.placeholder(dtypes.int64),
        values=array_ops.placeholder(dtypes.float32),
        dense_shape=array_ops.placeholder(dtypes.int64))
    self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
    batched = inp.maybe_batch(
        [sparse], keep_input=[True, False], batch_size=2, enqueue_many=True)
    self.assertIs(None, batched.dense_shape.get_shape().num_elements())

  def testMaybeBatchCorrectValues(self):
    sparse_t = sparse_tensor.SparseTensor(
        indices=[[0, 1], [0, 2], [1, 0], [1, 3]],
        dense_shape=[2, 4],
        values=[5, 4, 7, 2])
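    # keep_input=[True, False] keeps only the first row of the rank-2 sparse
    # input, as the assertions at the end verify.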
    911     keep = constant_op.constant([True, False])
    912     batched = inp.maybe_batch(
    913         [sparse_t], keep_input=keep, batch_size=1, enqueue_many=True)
    914 
    915     with self.test_session():
    916       coord = coordinator.Coordinator()
    917       threads = queue_runner_impl.start_queue_runners(coord=coord)
    918 
    919       batched_np = batched.eval()
    920 
    921       coord.request_stop()
    922       for thread in threads:
    923         thread.join()
    924 
    925     self.assertAllEqual([[0, 1], [0, 2]], batched_np.indices)
    926     self.assertAllEqual([5, 4], batched_np.values)
    927     self.assertAllEqual([1, 4], batched_np.dense_shape)
    928 
    929 
    930 class BatchJoinTest(test_lib.TestCase):
    931 
    932   def _testTwoThreadsHelper(self, use_dict):
    933     with self.test_session() as sess:
    934       # Two threads, the first generates (0..69, "a").
    935       num_a = 70
    936       zero64 = constant_op.constant(0, dtype=dtypes.int64)
    937       examples = variables.Variable(zero64)
    938       counter = examples.count_up_to(num_a)
    939       sparse_counter = sparse_tensor.SparseTensor(
    940           indices=array_ops.reshape(zero64, [1, 1]),
    941           values=array_ops.stack([math_ops.cast(counter, dtypes.float32)]),
    942           dense_shape=[1])
    943 
    944       # The second generates (99, "b") 90 times and then stops.
    945       num_b = 90
    946       ninety_nine = inp.limit_epochs(
    947           constant_op.constant(
    948               99, dtype=dtypes.int64), num_b)
    949       sparse_ninety_nine = sparse_tensor.SparseTensor(
    950           indices=array_ops.reshape(zero64, [1, 1]),
    951           values=array_ops.stack([math_ops.cast(ninety_nine, dtypes.float32)]),
    952           dense_shape=[1])
    953 
    954       # These get joined together and grouped into batches of 5.
    955       batch_size = 5
    956       if use_dict:
    957         batched = inp.batch_join(
    958             [{
    959                 "c": counter,
    960                 "s": sparse_counter,
    961                 "S": "a"
    962             }, {
    963                 "c": ninety_nine,
    964                 "s": sparse_ninety_nine,
    965                 "S": "b"
    966             }],
    967             batch_size=batch_size)
    968         batched_fetch = [batched["c"], batched["s"], batched["S"]]
    969       else:
    970         batched = inp.batch_join(
    971             [[counter, sparse_counter, "a"],
    972              [ninety_nine, sparse_ninety_nine, "b"]],
    973             batch_size=batch_size)
    974         batched_fetch = batched
    975 
    976       # Shapes.
    977       self.assertEqual(3, len(batched_fetch))
    978       self.assertAllEqual((batch_size,), batched_fetch[0].get_shape().as_list())
    979       self.assertAllEqual((None, 2),
    980                           batched_fetch[1].indices.get_shape().as_list())
    981       self.assertAllEqual((None,),
    982                           batched_fetch[1].values.get_shape().as_list())
    983       self.assertAllEqual((2,),
    984                           batched_fetch[1].dense_shape.get_shape().as_list())
    985       self.assertAllEqual((batch_size,), batched_fetch[2].get_shape().as_list())
    986 
    987       variables.global_variables_initializer().run()
    988       variables.local_variables_initializer().run()
    989       threads = queue_runner_impl.start_queue_runners()
    990 
    991       # Should see the "a" and "b" threads mixed together.
    992       all_a = []
    993       seen_b = 0
    994       saw_both = 0
    995       num_batches = (num_a + num_b) // batch_size
    996       for i in range(num_batches):
    997         results = sess.run(batched_fetch)
    998         self.assertEqual(3, len(results))
    999         self.assertEqual(batch_size, len(results[0]))
   1000         self.assertEqual(batch_size, len(results[2]))
   1001         self.assertAllEqual(results[0], results[1].values)
   1002         self.assertAllEqual(
   1003             results[1].indices,
   1004             np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
   1005         self.assertAllEqual(results[1].dense_shape, [batch_size, 1])
   1006         which_a = [i for i, s in enumerate(results[2]) if s == b"a"]
   1007         which_b = [i for i, s in enumerate(results[2]) if s == b"b"]
   1008         self.assertEqual(len(which_a) + len(which_b), batch_size)
   1009         if which_a and which_b:
   1010           saw_both += 1
   1011         all_a.extend([results[0][i] for i in which_a])
   1012         seen_b += len(which_b)
   1013         self.assertAllEqual([99] * len(which_b),
   1014                             [results[0][i] for i in which_b])
   1015 
   1016       # Some minimum level of mixing of the results of both threads.
   1017       self.assertGreater(saw_both, 1)
   1018 
   1019       # Verify the order of results from "a" were preserved.
   1020       self.assertAllEqual(all_a, np.arange(num_a))
   1021       self.assertEqual(seen_b, num_b)
   1022 
   1023       # Reached the limit.
   1024       with self.assertRaises(errors_impl.OutOfRangeError):
   1025         sess.run(batched_fetch)
   1026       for thread in threads:
   1027         thread.join()
   1028 
   1029   def DISABLED_testTwoThreads(self):
   1030     self._testTwoThreadsHelper(use_dict=False)
   1031 
   1032   def DISABLED_testTwoThreadsDict(self):
   1033     self._testTwoThreadsHelper(use_dict=True)
   1034 
   1035   def testMismatchedDictKeys(self):
   1036     with self.assertRaisesRegexp(ValueError, "must have the same keys"):
   1037       inp.batch_join(
   1038           [{
   1039               "c": 12,
   1040               "s": 123,
   1041               "S": "a"
   1042           }, {
   1043               "cool": -12,
   1044               "s": 99,
   1045               "S": "b"
   1046           }],
   1047           batch_size=8)
   1048 
   1049   def DISABLED_testTwoThreadsDynamicPad(self):
   1050     with self.test_session() as sess:
   1051       # Two threads, the first generates (0..69, ["a"] * 1..70).
   1052       num_a = 70
   1053       zero64 = constant_op.constant(0, dtype=dtypes.int64)
   1054       examples = variables.Variable(zero64)
   1055       counter = examples.count_up_to(num_a)
   1056 
   1057       # The second generates (99, ["b"] * 99) 90 times and then stops.
   1058       num_b = 90
   1059       ninety_nine = inp.limit_epochs(
   1060           constant_op.constant(
   1061               99, dtype=dtypes.int64), num_b)
   1062 
   1063       # These get joined together and grouped into batches of 5.
   1064       batch_size = 5
   1065       a = array_ops.tile(["a"],
   1066                          math_ops.to_int32(array_ops.stack([counter + 1])))
   1067       b = array_ops.tile(["b"],
   1068                          math_ops.to_int32(array_ops.stack([ninety_nine])))
   1069       batched = inp.batch_join(
   1070           [[counter, a], [ninety_nine, b]],
   1071           batch_size=batch_size,
   1072           dynamic_pad=True)
   1073 
   1074       # Shapes.
   1075       self.assertEqual(2, len(batched))
   1076       self.assertAllEqual((batch_size,), batched[0].get_shape().as_list())
   1077       self.assertAllEqual((batch_size, None), batched[1].get_shape().as_list())
   1078 
   1079       variables.global_variables_initializer().run()
   1080       variables.local_variables_initializer().run()
   1081       threads = queue_runner_impl.start_queue_runners()
   1082 
   1083       # Should see the "a" and "b" threads mixed together.
   1084       all_a = []
   1085       count_string_a = []
   1086       seen_b = 0
   1087       saw_both = 0
   1088       num_batches = (num_a + num_b) // batch_size
   1089       for i in range(num_batches):
   1090         results = sess.run(batched)
   1091         self.assertEqual(2, len(results))
   1092         self.assertEqual(len(results[0]), batch_size)
   1093         self.assertEqual(len(results[1]), batch_size)
   1094         for s in results[1]:
   1095           if s[0] == b"b":
   1096             self.assertAllEqual(s, [b"b"] * 99)
   1097           else:
   1098             count_string_a.append(sum(x == b"a" for x in s))
   1099         which_a = [i for i, s in enumerate(results[1]) if s[0] == b"a"]
   1100         which_b = [i for i, s in enumerate(results[1]) if s[0] == b"b"]
   1101         self.assertEqual(len(which_a) + len(which_b), batch_size)
   1102         if which_a and which_b:
   1103           saw_both += 1
   1104         all_a.extend([results[0][i] for i in which_a])
   1105         seen_b += len(which_b)
   1106         self.assertAllEqual([99] * len(which_b),
   1107                             [results[0][i] for i in which_b])
   1108 
   1109       # Some minimum level of mixing of the results of both threads.
   1110       self.assertGreater(saw_both, 1)
   1111 
   1112       # Verify the order of results from "a" were preserved.
   1113       self.assertAllEqual(  # tiled "a" with counter + 1
   1114           count_string_a, np.arange(num_a) + 1)
   1115       self.assertAllEqual(all_a, np.arange(num_a))
   1116       self.assertEqual(seen_b, num_b)
   1117 
   1118       # Reached the limit.
   1119       with self.assertRaises(errors_impl.OutOfRangeError):
   1120         sess.run(batched)
   1121       for thread in threads:
   1122         thread.join()
   1123 
   1124   def DISABLED_testTwoThreadsSmallerBatch(self):
   1125     with self.test_session() as sess:
   1126       extra_elements = 2
   1127       # Two threads, the first generates (0..69, "a").
   1128       num_a = 70 + extra_elements
   1129       zero64 = constant_op.constant(0, dtype=dtypes.int64)
   1130       examples = variables.Variable(zero64)
   1131       counter = examples.count_up_to(num_a)
   1132       sparse_counter = sparse_tensor.SparseTensor(
   1133           indices=array_ops.reshape(zero64, [1, 1]),
   1134           values=array_ops.stack([math_ops.cast(counter, dtypes.float32)]),
   1135           dense_shape=[1])
   1136 
   1137       # The second generates (99, "b") 90 times and then stops.
   1138       num_b = 90 + extra_elements
   1139       ninety_nine = inp.limit_epochs(
   1140           constant_op.constant(
   1141               99, dtype=dtypes.int64), num_b)
   1142       sparse_ninety_nine = sparse_tensor.SparseTensor(
   1143           indices=array_ops.reshape(zero64, [1, 1]),
   1144           values=array_ops.stack([math_ops.cast(ninety_nine, dtypes.float32)]),
   1145           dense_shape=[1])
   1146 
   1147       # These get joined together and grouped into batches of 5.
   1148       batch_size = 5
   1149       batched = inp.batch_join(
   1150           [[counter, sparse_counter, "a"],
   1151            [ninety_nine, sparse_ninety_nine, "b"]],
   1152           batch_size=batch_size,
   1153           allow_smaller_final_batch=True)
   1154 
   1155       # Shapes.
   1156       self.assertEqual(3, len(batched))
   1157       self.assertAllEqual((None,), batched[0].get_shape().as_list())
   1158       self.assertAllEqual((None, 2), batched[1].indices.get_shape().as_list())
   1159       self.assertAllEqual((None,), batched[1].values.get_shape().as_list())
   1160       self.assertAllEqual((2,), batched[1].dense_shape.get_shape().as_list())
   1161       self.assertAllEqual((None,), batched[2].get_shape().as_list())
   1162 
   1163       variables.global_variables_initializer().run()
   1164       variables.local_variables_initializer().run()
   1165       threads = queue_runner_impl.start_queue_runners()
   1166 
   1167       # Should see the "a" and "b" threads mixed together.
   1168       all_a = []
   1169       seen_b = 0
   1170       saw_both = 0
   1171       num_batches = (num_a + num_b) // batch_size
   1172       for i in range(num_batches):
   1173         results = sess.run(batched)
   1174         tf_logging.info("Batch %d: %s", i, results[0])
   1175         self.assertEqual(len(results[0]), batch_size)
   1176         self.assertEqual(len(results[2]), batch_size)
   1177         self.assertAllEqual(results[0], results[1].values)
   1178         self.assertAllEqual(
   1179             results[1].indices,
   1180             np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
   1181         self.assertAllEqual(results[1].dense_shape, [batch_size, 1])
   1182         which_a = [i for i, s in enumerate(results[2]) if s == b"a"]
   1183         which_b = [i for i, s in enumerate(results[2]) if s == b"b"]
   1184         self.assertEqual(len(which_a) + len(which_b), batch_size)
   1185         if which_a and which_b:
   1186           saw_both += 1
   1187         all_a.extend([results[0][i] for i in which_a])
   1188         seen_b += len(which_b)
   1189         self.assertAllEqual([99] * len(which_b),
   1190                             [results[0][i] for i in which_b])
   1191 
   1192       # Reached the final batch with 2 * extra_elements.
   1193       results = sess.run(batched)
   1194       tf_logging.info("Last Batch: %s", results[0])
   1195       self.assertEqual(len(results[0]), 2 * extra_elements)
   1196       self.assertEqual(len(results[2]), 2 * extra_elements)
   1197       self.assertAllEqual(results[0], results[1].values)
   1198       self.assertAllEqual(results[1].indices,
   1199                           np.vstack((np.arange(2 * extra_elements),
   1200                                      np.zeros(2 * extra_elements))).T)
   1201       self.assertAllEqual(results[1].dense_shape, [2 * extra_elements, 1])
   1202       which_a = [i for i, s in enumerate(results[2]) if s == b"a"]
   1203       which_b = [i for i, s in enumerate(results[2]) if s == b"b"]
   1204       self.assertEqual(len(which_a) + len(which_b), 2 * extra_elements)
   1205       if which_a and which_b:
   1206         saw_both += 1
   1207       all_a.extend([results[0][i] for i in which_a])
   1208       seen_b += len(which_b)
   1209 
   1210       # Some minimum level of mixing of the results of both threads.
   1211       self.assertGreater(saw_both, 1)
   1212 
   1213       # Verify the order of results from "a" were preserved.
   1214       self.assertAllEqual(all_a, np.arange(num_a))
   1215       self.assertEqual(seen_b, num_b)
   1216 
   1217       # Reached the limit.
   1218       with self.assertRaises(errors_impl.OutOfRangeError):
   1219         sess.run(batched)
   1220       for thread in threads:
   1221         thread.join()
   1222 
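  # A minimal sketch (not part of the original suite; the helper name is
  # hypothetical) of the index layout the sparse assertions above expect:
  # batching n rank-1 SparseTensors of dense_shape [1] yields one value per
  # row at column 0, i.e. indices [[0, 0], [1, 0], ..., [n - 1, 0]].
  def _expected_sparse_indices(self, n):
    # Equivalent to the np.vstack((np.arange(n), np.zeros(n))).T pattern
    # used in the assertions throughout this file.
    return np.vstack((np.arange(n), np.zeros(n))).T
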
   1223   def DISABLED_testTwoThreadsDynamicPadSmallerBatch(self):
   1224     with self.test_session() as sess:
   1225       extra_elements = 2
    1226       # Two threads, the first generates (0..71, ["a"] * 1..72).
   1227       num_a = 70 + extra_elements
   1228       zero64 = constant_op.constant(0, dtype=dtypes.int64)
   1229       examples = variables.Variable(zero64)
   1230       counter = examples.count_up_to(num_a)
   1231 
    1232       # The second generates (99, ["b"] * 99) 92 times and then stops.
   1233       num_b = 90 + extra_elements
   1234       ninety_nine = inp.limit_epochs(
   1235           constant_op.constant(
   1236               99, dtype=dtypes.int64), num_b)
   1237 
   1238       # These get joined together and grouped into batches of 5.
   1239       batch_size = 5
   1240       a = array_ops.tile(["a"],
   1241                          math_ops.to_int32(array_ops.stack([counter + 1])))
   1242       b = array_ops.tile(["b"],
   1243                          math_ops.to_int32(array_ops.stack([ninety_nine])))
   1244       batched = inp.batch_join(
   1245           [[counter, a], [ninety_nine, b]],
   1246           batch_size=batch_size,
   1247           dynamic_pad=True,
   1248           allow_smaller_final_batch=True)
   1249 
   1250       # Shapes.
   1251       self.assertEqual(2, len(batched))
   1252       self.assertAllEqual((None,), batched[0].get_shape().as_list())
   1253       self.assertAllEqual((None, None), batched[1].get_shape().as_list())
   1254 
   1255       variables.global_variables_initializer().run()
   1256       variables.local_variables_initializer().run()
   1257       threads = queue_runner_impl.start_queue_runners()
   1258 
   1259       # Should see the "a" and "b" threads mixed together.
   1260       all_a = []
   1261       count_string_a = []
   1262       seen_b = 0
   1263       saw_both = 0
   1264       num_batches = (num_a + num_b) // batch_size
   1265       for i in range(num_batches):
   1266         results = sess.run(batched)
   1267         tf_logging.info("Batch %d: %s", i, results[0])
   1268         self.assertEqual(len(results[0]), batch_size)
   1269         self.assertEqual(len(results[1]), batch_size)
   1270         for s in results[1]:
   1271           if s[0] == b"b":
   1272             self.assertAllEqual(s, [b"b"] * 99)
   1273           else:
   1274             count_string_a.append(sum(x == b"a" for x in s))
   1275         which_a = [i for i, s in enumerate(results[1]) if s[0] == b"a"]
   1276         which_b = [i for i, s in enumerate(results[1]) if s[0] == b"b"]
   1277         self.assertEqual(len(which_a) + len(which_b), batch_size)
   1278         if which_a and which_b:
   1279           saw_both += 1
   1280         all_a.extend([results[0][i] for i in which_a])
   1281         seen_b += len(which_b)
   1282         self.assertAllEqual([99] * len(which_b),
   1283                             [results[0][i] for i in which_b])
   1284 
   1285       # Reached the final batch with 2 * extra_elements.
   1286       results = sess.run(batched)
   1287       tf_logging.info("Last Batch: %s", results[0])
   1288       self.assertEqual(len(results[0]), 2 * extra_elements)
   1289       self.assertEqual(len(results[1]), 2 * extra_elements)
   1290       for s in results[1]:
   1291         if s[0] == b"b":
   1292           self.assertAllEqual(s, [b"b"] * 99)
   1293         else:
   1294           count_string_a.append(sum(x == b"a" for x in s))
   1295       which_a = [i for i, s in enumerate(results[1]) if s[0] == b"a"]
   1296       which_b = [i for i, s in enumerate(results[1]) if s[0] == b"b"]
   1297       self.assertEqual(len(which_a) + len(which_b), 2 * extra_elements)
   1298       if which_a and which_b:
   1299         saw_both += 1
   1300       all_a.extend([results[0][i] for i in which_a])
   1301       seen_b += len(which_b)
   1302 
   1303       # Some minimum level of mixing of the results of both threads.
   1304       self.assertGreater(saw_both, 1)
   1305 
   1306       # Verify the order of results from "a" were preserved.
   1307       self.assertAllEqual(  # tiled "a" with counter + 1
   1308           count_string_a, np.arange(num_a) + 1)
   1309       self.assertAllEqual(all_a, np.arange(num_a))
   1310       self.assertEqual(seen_b, num_b)
   1311 
   1312       # Reached the limit.
   1313       with self.assertRaises(errors_impl.OutOfRangeError):
   1314         sess.run(batched)
   1315       for thread in threads:
   1316         thread.join()
   1317 
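  # A minimal sketch (not part of the original suite; the method and tensor
  # names are hypothetical) of the behavior the dynamic-pad test above relies
  # on: with dynamic_pad=True, variable-length vectors are padded with empty
  # strings up to the longest vector dequeued in that particular batch.
  def DISABLED_exampleDynamicPad(self):
    with self.test_session() as sess:
      counter = variables.Variable(0).count_up_to(10)
      # Each example is a string vector of length counter + 1.
      example = array_ops.tile(
          ["x"], math_ops.to_int32(array_ops.stack([counter + 1])))
      batched = inp.batch_join(
          [[counter, example]], batch_size=5, dynamic_pad=True)
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      threads = queue_runner_impl.start_queue_runners()
      results = sess.run(batched)
      # Every row is padded to the longest row in this batch.
      self.assertEqual(max(results[0]) + 1, results[1].shape[1])
      for thread in threads:
        thread.join()
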
   1318   def testSharedName(self):
   1319     with self.test_session():
   1320       batch_size = 10
   1321       num_batches = 3
   1322       zero64 = constant_op.constant(0, dtype=dtypes.int64)
   1323       examples = variables.Variable(zero64)
   1324       counter = examples.count_up_to(num_batches * batch_size)
   1325       batched = inp.batch_join(
   1326           [[counter, "string"]],
   1327           batch_size=batch_size,
   1328           shared_name="SHARED_NAME_XYZ",
   1329           name="Q")
   1330 
   1331       # Shapes.
   1332       self.assertEqual(2, len(batched))
   1333       self.assertAllEqual((batch_size,), batched[0].get_shape().as_list())
   1334       self.assertAllEqual((batch_size,), batched[1].get_shape().as_list())
   1335 
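      # batched[0] is produced by the dequeue op; its first input is the
      # queue op, whose shared_name attr should carry the value set above.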
   1336       self.assertProtoEquals(
   1337           "s: 'SHARED_NAME_XYZ'",
   1338           batched[0].op.inputs[0].op.node_def.attr["shared_name"])
   1339 
   1340   def testCannotInferRankError(self):
   1341     with self.test_session():
   1342       x = array_ops.placeholder(dtype=dtypes.int64)
   1343       with self.assertRaisesRegexp(ValueError, "Cannot infer Tensor's rank"):
   1344         inp.batch_join([[x]], batch_size=2)
   1345 
   1346   def testSingleElementDict(self):
   1347     x = inp.batch_join([{"c": [12, 12]}], batch_size=8)
   1348     self.assertAllEqual((8, 2), x["c"].get_shape().as_list())
   1349 
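  # A minimal sketch (not part of the original suite; the method name is
  # hypothetical) of the filtering that _testKeepInputHelper below relies on:
  # `keep_input` is evaluated per example, and only examples whose flag is
  # True ever reach the output batch, which is why the helper can assert
  # that every batched counter is even.
  def DISABLED_exampleMaybeBatchJoinParity(self):
    with self.test_session() as sess:
      counter = variables.Variable(0).count_up_to(20)
      keep_even = math_ops.equal(0, math_ops.mod(counter, 2))
      batched = inp.maybe_batch_join([[counter]], keep_even, batch_size=5)
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      threads = queue_runner_impl.start_queue_runners()
      results = sess.run(batched)
      # Only the even counters 0, 2, 4, ... survive the filter.
      self.assertAllEqual([0] * 5, np.mod(results[0], 2))
      for thread in threads:
        thread.join()
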
   1350   def _testKeepInputHelper(self, num_threads, enqueue_many,
   1351                            keep_input_vector=False):
   1352     with self.test_session() as sess:
   1353       batch_size = 5
   1354       num_batches = 4
   1355       examples = variables.Variable(0)
   1356       counter = examples.count_up_to(num_batches * batch_size * 2)
   1357       sparse_counter = sparse_tensor.SparseTensor(
   1358           indices=array_ops.zeros(
   1359               [1, 1], dtype=dtypes.int64),
   1360           values=array_ops.stack([math_ops.cast(counter, dtypes.float32)]),
   1361           dense_shape=[1])
   1362       to_batch = [counter, sparse_counter, "string"]
   1363       if enqueue_many:
   1364         to_batch = inp.batch(to_batch, 4 if keep_input_vector else 1)
   1365       keep_input = array_ops.squeeze(
   1366           math_ops.equal(0, math_ops.mod(to_batch[0], 2)))
   1367       batched = inp.maybe_batch_join(
   1368           [to_batch] * num_threads,
   1369           keep_input,
   1370           batch_size,
   1371           enqueue_many=enqueue_many)
    1372       variables.global_variables_initializer().run()
    1373       variables.local_variables_initializer().run()
   1374       threads = queue_runner_impl.start_queue_runners()
   1375 
   1376       for _ in range(num_batches):
   1377         results = sess.run(batched)
    1378         self.assertAllEqual([0] * batch_size, np.mod(results[0], 2))
    1381         self.assertAllEqual([0] * batch_size, np.mod(results[1].values, 2))
   1384         self.assertAllEqual([b"string"] * batch_size, results[2])
   1385 
   1386       # Reached the limit.
   1387       with self.assertRaises(errors_impl.OutOfRangeError):
   1388         sess.run(batched)
   1389       for thread in threads:
   1390         thread.join()
   1391 
   1392   def testSingleThreadKeepInput(self):
   1393     self._testKeepInputHelper(1, False)
   1394 
   1395   def testSingleThreadKeepInputEnqueueMany(self):
   1396     self._testKeepInputHelper(1, True)
   1397 
   1398   def testMultipleThreadKeepInput(self):
   1399     self._testKeepInputHelper(5, False)
   1400 
   1401   def testMultipleThreadKeepInputEnqueueMany(self):
   1402     self._testKeepInputHelper(5, True)
   1403 
   1404   def testSingleThreadKeepInputPerExample(self):
   1405     self._testKeepInputHelper(1, True, keep_input_vector=True)
   1406 
   1407   def testMultipleThreadKeepInputPerExample(self):
   1408     self._testKeepInputHelper(5, True, keep_input_vector=True)
   1409 
   1410   def testInvalidKeepInputVector(self):
   1411     # Can't have vector `keep_input` with `enqueue_many=False`.
   1412     with self.assertRaisesRegexp(ValueError, "`keep_input` cannot be a vector"):
   1413       inp.maybe_batch_join([[array_ops.zeros(5)]],
   1414                            keep_input=constant_op.constant([True, False]),
   1415                            batch_size=1,
   1416                            enqueue_many=False)
   1417     # Can't have `keep_input` with more than one dimension.
   1418     with self.assertRaisesRegexp(ValueError, "must be 0 or 1 dimensions"):
   1419       inp.maybe_batch_join([[array_ops.zeros(5)]],
   1420                            keep_input=constant_op.constant([[True], [False]]),
   1421                            batch_size=1,
   1422                            enqueue_many=True)
   1423     # `keep_input` must have dimensions determined at graph construction.
   1424     with self.assertRaisesRegexp(ValueError,
   1425                                  "must be known at graph construction"):
   1426       inp.maybe_batch_join([[array_ops.zeros(5)]],
   1427                            keep_input=array_ops.placeholder(dtypes.bool),
   1428                            batch_size=1,
   1429                            enqueue_many=True)
   1430 
   1431   def testMaybeBatchedSparseTensorInferredShape(self):
   1432     sparse = sparse_tensor.SparseTensor(
   1433         indices=[[0]], values=[1.0], dense_shape=[1])
   1434     self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
   1435     batched = inp.maybe_batch_join([[sparse]], keep_input=True, batch_size=2)
   1436     self.assertAllEqual((2,), batched.dense_shape.get_shape().as_list())
   1437 
   1438   def testMaybeBatchedSparseTensorInferredShapeEnqueueMany(self):
   1439     sparse = sparse_tensor.SparseTensor(
   1440         indices=[[0]], values=[1.0], dense_shape=[1])
   1441     self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
   1442     batched = inp.maybe_batch_join(
   1443         [[sparse]], keep_input=True, batch_size=2, enqueue_many=True)
   1444     self.assertAllEqual((1,), batched.dense_shape.get_shape().as_list())
   1445 
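  # In the two tests above: batching normally prepends a batch dimension, so
  # a rank-1 SparseTensor comes back rank 2 and dense_shape has shape (2,);
  # with enqueue_many=True the input's leading dimension already is the batch
  # dimension, so the rank (and the shape of dense_shape) is unchanged.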
   1446   def testMaybeBatchedSparseTensorInferredShapeEnqueueManyPerExample(self):
   1447     sparse = sparse_tensor.SparseTensor(
   1448         indices=[[0], [0]], values=[1.0, 2.0], dense_shape=[2])
   1449     self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
   1450     batched = inp.maybe_batch_join(
   1451         [[sparse]], keep_input=[True, False], batch_size=2, enqueue_many=True)
   1452     self.assertAllEqual((1,), batched.dense_shape.get_shape().as_list())
   1453 
   1454   def testMaybeBatchedSparseTensorInferredShapeUnknownRank(self):
   1455     sparse = sparse_tensor.SparseTensor(
   1456         indices=array_ops.placeholder(dtypes.int64),
   1457         values=array_ops.placeholder(dtypes.float32),
   1458         dense_shape=array_ops.placeholder(dtypes.int64))
   1459     self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
   1460     batched = inp.maybe_batch_join([[sparse]], keep_input=True, batch_size=2)
   1461     self.assertIs(None, batched.dense_shape.get_shape().num_elements())
   1462 
   1463   def testMaybeBatchedSparseTensorInferredShapeUnknownRankEnqueueMany(self):
   1464     sparse = sparse_tensor.SparseTensor(
   1465         indices=array_ops.placeholder(dtypes.int64),
   1466         values=array_ops.placeholder(dtypes.float32),
   1467         dense_shape=array_ops.placeholder(dtypes.int64))
   1468     self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
   1469     batched = inp.maybe_batch_join(
   1470         [[sparse]], keep_input=True, batch_size=2, enqueue_many=True)
   1471     self.assertIs(None, batched.dense_shape.get_shape().num_elements())
   1472 
   1473   def testMaybeBatchedSparseTensorInferredShapeUnknownRankPerExample(self):
   1474     sparse = sparse_tensor.SparseTensor(
   1475         indices=array_ops.placeholder(dtypes.int64),
   1476         values=array_ops.placeholder(dtypes.float32),
   1477         dense_shape=array_ops.placeholder(dtypes.int64))
   1478     self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
   1479     batched = inp.maybe_batch_join(
   1480         [[sparse]], keep_input=[True, False], batch_size=2, enqueue_many=True)
   1481     self.assertIs(None, batched.dense_shape.get_shape().num_elements())
   1482 
   1483   def testMaybeBatchCorrectValues(self):
   1484     sparse = sparse_tensor.SparseTensor(
   1485         indices=[[0, 1], [0, 2], [1, 0], [1, 3]],
   1486         dense_shape=[2, 4],
   1487         values=[5, 4, 7, 2])
   1488     keep = constant_op.constant([True, False])
   1489     batched = inp.maybe_batch_join(
   1490         [[sparse]], keep_input=keep, batch_size=1, enqueue_many=True)
   1491 
   1492     with self.test_session():
   1493       coord = coordinator.Coordinator()
   1494       threads = queue_runner_impl.start_queue_runners(coord=coord)
   1495 
   1496       batched_np = batched.eval()
   1497 
   1498       coord.request_stop()
   1499       for thread in threads:
   1500         thread.join()
   1501 
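    # keep_input=[True, False] with enqueue_many=True drops row 1 of the
    # rank-2 input, and batch_size=1 re-emits the surviving row 0, so only
    # the entries [0, 1] -> 5 and [0, 2] -> 4 remain.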
   1502     self.assertAllEqual([[0, 1], [0, 2]], batched_np.indices)
   1503     self.assertAllEqual([5, 4], batched_np.values)
   1504     self.assertAllEqual([1, 4], batched_np.dense_shape)
   1505 
   1506 
   1507 class ShuffleBatchTest(test_lib.TestCase):
   1508 
   1509   def _testOneThreadHelper(self, use_dict):
   1510     with self.test_session() as sess:
   1511       batch_size = 10
   1512       num_batches = 3
   1513       zero64 = constant_op.constant(0, dtype=dtypes.int64)
   1514       examples = variables.Variable(zero64)
   1515       counter = examples.count_up_to(num_batches * batch_size)
   1516       sparse_counter = sparse_tensor.SparseTensor(
   1517           indices=array_ops.reshape(zero64, [1, 1]),
   1518           values=array_ops.stack([math_ops.cast(counter, dtypes.float32)]),
   1519           dense_shape=[1])
   1520       if use_dict:
   1521         batched = inp.shuffle_batch(
   1522             {
   1523                 "c": counter,
   1524                 "s": sparse_counter,
   1525                 "S": "string"
   1526             },
   1527             batch_size=batch_size,
   1528             capacity=32,
   1529             min_after_dequeue=16,
   1530             seed=141421)
   1531         batched_fetch = [batched["c"], batched["s"], batched["S"]]
   1532       else:
   1533         batched = inp.shuffle_batch(
   1534             [counter, sparse_counter, "string"],
   1535             batch_size=batch_size,
   1536             capacity=32,
   1537             min_after_dequeue=16,
   1538             seed=141421)
   1539         batched_fetch = batched
   1540       variables.global_variables_initializer().run()
   1541       variables.local_variables_initializer().run()
   1542       threads = queue_runner_impl.start_queue_runners()
   1543 
   1544       all_counts = []
   1545       for i in range(num_batches):
   1546         results = sess.run(batched_fetch)
   1547         self.assertEqual(len(results[0]), batch_size)
   1548         all_counts.extend(results[0])
   1549         self.assertAllEqual(
   1550             results[1].indices,
   1551             np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
   1552         self.assertAllEqual(results[0], results[1].values)
   1553         self.assertAllEqual(results[1].dense_shape, [batch_size, 1])
   1554         self.assertAllEqual(results[2], [b"string"] * batch_size)
   1555       # Results scrambled, but include all the expected numbers.
   1556       deltas = [
   1557           all_counts[i + 1] - all_counts[i] for i in range(len(all_counts) - 1)
   1558       ]
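      # An unshuffled single-thread stream would be consecutive (every delta
      # equal to 1), so any unequal delta is evidence the queue reordered
      # the elements.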
   1559       self.assertFalse(all(d == deltas[0] for d in deltas))
   1560       self.assertItemsEqual(all_counts, range(num_batches * batch_size))
   1561 
   1562       # Reached the limit.
   1563       with self.assertRaises(errors_impl.OutOfRangeError):
   1564         sess.run(batched_fetch)
   1565       for thread in threads:
   1566         thread.join()
   1567 
   1568   def testOneThread(self):
   1569     self._testOneThreadHelper(use_dict=False)
   1570 
   1571   def testOneThreadDict(self):
   1572     self._testOneThreadHelper(use_dict=True)
   1573 
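  # A minimal sketch (not part of the original suite; the method name is
  # hypothetical) of the queue knobs used throughout this class: `capacity`
  # bounds the shuffle queue, and `min_after_dequeue` is the floor kept in
  # the queue to guarantee a mixing pool; capacity must be strictly larger
  # than min_after_dequeue, and the margin bounds how far ahead the queue
  # runners can prefetch.
  def DISABLED_exampleShuffleBatchKnobs(self):
    with self.test_session() as sess:
      counter = variables.Variable(0).count_up_to(30)
      batched = inp.shuffle_batch(
          [counter], batch_size=10, capacity=32, min_after_dequeue=16,
          seed=141421)
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()
      threads = queue_runner_impl.start_queue_runners()
      results = sess.run(batched)
      self.assertEqual(10, len(results[0]))
      for thread in threads:
        thread.join()
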
   1574   def testOneThreadSmallerBatch(self):
   1575     with self.test_session() as sess:
   1576       batch_size = 10
   1577       num_batches = 3
   1578       extra_elements = 5
   1579       zero64 = constant_op.constant(0, dtype=dtypes.int64)
   1580       examples = variables.Variable(zero64)
   1581       total_elements = num_batches * batch_size + extra_elements
   1582       counter = examples.count_up_to(total_elements)
   1583       sparse_counter = sparse_tensor.SparseTensor(
   1584           indices=array_ops.reshape(zero64, [1, 1]),
   1585           values=array_ops.stack([math_ops.cast(counter, dtypes.float32)]),
   1586           dense_shape=[1])
   1587       batched = inp.shuffle_batch(
   1588           [counter, sparse_counter, "string"],
   1589           batch_size=batch_size,
   1590           capacity=32,
   1591           min_after_dequeue=16,
   1592           seed=141421,
   1593           allow_smaller_final_batch=True)
   1594       batched_fetch = batched
   1595       variables.global_variables_initializer().run()
   1596       variables.local_variables_initializer().run()
   1597       threads = queue_runner_impl.start_queue_runners()
   1598 
   1599       all_counts = []
   1600       for _ in range(num_batches):
   1601         results = sess.run(batched_fetch)
   1602         self.assertEqual(len(results[0]), batch_size)
   1603         all_counts.extend(results[0])
   1604         self.assertAllEqual(
   1605             results[1].indices,
   1606             np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
   1607         self.assertAllEqual(results[0], results[1].values)
   1608         self.assertAllEqual(results[1].dense_shape, [batch_size, 1])
   1609         self.assertAllEqual(results[2], [b"string"] * batch_size)
   1610 
   1611       # Reached the final batch with extra elements.
   1612       results = sess.run(batched)
   1613       self.assertAllEqual(results[1].dense_shape, [extra_elements, 1])
   1614       self.assertAllEqual(results[2], [b"string"] * extra_elements)
   1615       all_counts.extend(results[0])
   1616 
   1617       # Results scrambled, but include all the expected numbers.
   1618       deltas = [
   1619           all_counts[i + 1] - all_counts[i] for i in range(len(all_counts) - 1)
   1620       ]
   1621       self.assertFalse(all(d == deltas[0] for d in deltas))
   1622       self.assertItemsEqual(all_counts, range(total_elements))
   1623 
   1624       # Reached the limit.
   1625       with self.assertRaises(errors_impl.OutOfRangeError):
   1626         sess.run(batched_fetch)
   1627       for thread in threads:
   1628         thread.join()
   1629 
   1630   def testManyThreads(self):
   1631     with self.test_session() as sess:
   1632       batch_size = 10
   1633       num_batches = 3
   1634       zero64 = constant_op.constant(0, dtype=dtypes.int64)
   1635       examples = variables.Variable(zero64)
   1636       counter = examples.count_up_to(num_batches * batch_size)
   1637       sparse_counter = sparse_tensor.SparseTensor(
   1638           indices=array_ops.reshape(zero64, [1, 1]),
   1639           values=array_ops.stack([math_ops.cast(counter, dtypes.float32)]),
   1640           dense_shape=[1])
   1641       batched = inp.shuffle_batch(
   1642           [counter, sparse_counter, "string"],
   1643           batch_size=batch_size,
   1644           capacity=32,
   1645           min_after_dequeue=16,
   1646           seed=173205,
   1647           num_threads=4)
   1648       variables.global_variables_initializer().run()
   1649       variables.local_variables_initializer().run()
   1650       threads = queue_runner_impl.start_queue_runners()
   1651 
   1652       all_counts = []
   1653       for i in range(num_batches):
   1654         results = sess.run(batched)
   1655         tf_logging.info("Batch %d: %s", i, results[0])
   1656         self.assertEqual(len(results[0]), batch_size)
   1657         all_counts.extend(results[0])
   1658         self.assertAllEqual(
   1659             results[1].indices,
   1660             np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
   1661         self.assertAllEqual(results[0], results[1].values)
   1662         self.assertAllEqual(results[1].dense_shape, [batch_size, 1])
   1663         self.assertAllEqual(results[2], [b"string"] * batch_size)
   1664       # Results scrambled, but include all the expected numbers.
   1665       deltas = [
   1666           all_counts[i + 1] - all_counts[i] for i in range(len(all_counts) - 1)
   1667       ]
   1668       self.assertFalse(all(d == deltas[0] for d in deltas))
   1669       self.assertItemsEqual(all_counts, range(num_batches * batch_size))
   1670 
   1671       # Reached the limit.
   1672       with self.assertRaises(errors_impl.OutOfRangeError):
   1673         sess.run(batched)
   1674       for thread in threads:
   1675         thread.join()
   1676 
   1677   def testManyThreadsSmallerBatch(self):
   1678     with self.test_session() as sess:
   1679       batch_size = 10
   1680       num_batches = 3
   1681       extra_elements = 5
   1682       zero64 = constant_op.constant(0, dtype=dtypes.int64)
   1683       examples = variables.Variable(zero64)
   1684       total_elements = num_batches * batch_size + extra_elements
   1685       counter = examples.count_up_to(total_elements)
   1686       sparse_counter = sparse_tensor.SparseTensor(
   1687           indices=array_ops.reshape(zero64, [1, 1]),
   1688           values=array_ops.stack([math_ops.cast(counter, dtypes.float32)]),
   1689           dense_shape=[1])
   1690       batched = inp.shuffle_batch(
   1691           [counter, sparse_counter, "string"],
   1692           batch_size=batch_size,
   1693           capacity=32,
   1694           min_after_dequeue=16,
   1695           seed=173205,
   1696           num_threads=4,
   1697           allow_smaller_final_batch=True)
   1698       variables.global_variables_initializer().run()
   1699       variables.local_variables_initializer().run()
   1700       threads = queue_runner_impl.start_queue_runners()
   1701 
   1702       all_counts = []
   1703       for i in range(num_batches):
   1704         results = sess.run(batched)
   1705         tf_logging.info("Batch %d: %s", i, results[0])
   1706         self.assertEqual(len(results[0]), batch_size)
   1707         all_counts.extend(results[0])
   1708         self.assertAllEqual(
   1709             results[1].indices,
   1710             np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
   1711         self.assertAllEqual(results[0], results[1].values)
   1712         self.assertAllEqual(results[1].dense_shape, [batch_size, 1])
   1713         self.assertAllEqual(results[2], [b"string"] * batch_size)
   1714 
   1715       # Reached the final batch with extra elements.
   1716       results = sess.run(batched)
   1717       self.assertAllEqual(results[0].shape, [extra_elements])
   1718       self.assertAllEqual(results[1].dense_shape, [extra_elements, 1])
   1719       self.assertAllEqual(results[2], [b"string"] * extra_elements)
   1720       all_counts.extend(results[0])
   1721 
   1722       # Results scrambled, but include all the expected numbers.
   1723       deltas = [
   1724           all_counts[i + 1] - all_counts[i] for i in range(len(all_counts) - 1)
   1725       ]
   1726       self.assertFalse(all(d == deltas[0] for d in deltas))
   1727       self.assertItemsEqual(all_counts, range(total_elements))
   1728 
   1729       # Reached the limit.
   1730       with self.assertRaises(errors_impl.OutOfRangeError):
   1731         sess.run(batched)
   1732       for thread in threads:
   1733         thread.join()
   1734 
   1735   def testSharedName(self):
   1736     with self.test_session():
   1737       batch_size = 10
   1738       num_batches = 3
   1739       zero64 = constant_op.constant(0, dtype=dtypes.int64)
   1740       examples = variables.Variable(zero64)
   1741       counter = examples.count_up_to(num_batches * batch_size)
   1742       batched = inp.shuffle_batch(
   1743           [counter, "string"],
   1744           batch_size=batch_size,
   1745           capacity=32,
   1746           min_after_dequeue=10,
   1747           shared_name="SHARED_NAME_XYZ",
   1748           name="Q")
   1749 
   1750       self.assertProtoEquals(
   1751           "s: 'SHARED_NAME_XYZ'",
   1752           batched[0].op.inputs[0].op.node_def.attr["shared_name"])
   1753 
   1754   def _testKeepInputHelper(self, num_threads, enqueue_many,
   1755                            keep_input_vector=False):
   1756     with self.test_session() as sess:
   1757       batch_size = 5
   1758       num_batches = 4
   1759       examples = variables.Variable(0)
   1760       counter = examples.count_up_to(num_batches * batch_size * 2)
   1761       sparse_counter = sparse_tensor.SparseTensor(
   1762           indices=array_ops.zeros(
   1763               [1, 1], dtype=dtypes.int64),
   1764           values=array_ops.stack([math_ops.cast(counter, dtypes.float32)]),
   1765           dense_shape=[1])
   1766       to_batch = [counter, sparse_counter, "string"]
   1767       if enqueue_many:
   1768         to_batch = inp.batch(to_batch, 4 if keep_input_vector else 1)
   1769       keep_input = array_ops.squeeze(
   1770           math_ops.equal(0, math_ops.mod(to_batch[0], 2)))
   1771       batched = inp.maybe_shuffle_batch(
   1772           to_batch,
   1773           batch_size,
   1774           10,
   1775           1,
   1776           keep_input,
   1777           num_threads=num_threads,
   1778           enqueue_many=enqueue_many)
    1779       variables.global_variables_initializer().run()
    1780       variables.local_variables_initializer().run()
   1781       threads = queue_runner_impl.start_queue_runners()
   1782 
   1783       for _ in range(num_batches):
   1784         results = sess.run(batched)
   1785         self.assertAllEqual([0] * batch_size, np.mod(results[0], 2))
   1786         self.assertAllEqual([0] * batch_size, np.mod(results[1].values, 2))
   1787         self.assertAllEqual([b"string"] * batch_size, results[2])
   1788 
   1789       # Reached the limit.
   1790       with self.assertRaises(errors_impl.OutOfRangeError):
   1791         sess.run(batched)
   1792       for thread in threads:
   1793         thread.join()
   1794 
   1795   def testSingleThreadKeepInput(self):
   1796     self._testKeepInputHelper(1, False)
   1797 
   1798   def testSingleThreadKeepInputEnqueueMany(self):
   1799     self._testKeepInputHelper(1, True)
   1800 
   1801   def testMultipleThreadKeepInput(self):
   1802     self._testKeepInputHelper(5, False)
   1803 
   1804   def testMultipleThreadKeepInputEnqueueMany(self):
   1805     self._testKeepInputHelper(5, True)
   1806 
   1807   def testSingleThreadKeepInputPerExample(self):
   1808     self._testKeepInputHelper(1, True, keep_input_vector=True)
   1809 
   1810   def testMultipleThreadKeepInputPerExample(self):
   1811     self._testKeepInputHelper(5, True, keep_input_vector=True)
   1812 
   1813   def testInvalidKeepInputVector(self):
   1814     # Can't have vector `keep_input` with `enqueue_many=False`.
   1815     with self.assertRaisesRegexp(ValueError, "`keep_input` cannot be a vector"):
   1816       inp.maybe_shuffle_batch([array_ops.zeros(5)], 1, 10, 1,
   1817                               keep_input=constant_op.constant([True, False]),
   1818                               enqueue_many=False)
   1819     # Can't have `keep_input` with more than one dimension.
   1820     with self.assertRaisesRegexp(ValueError, "must be 0 or 1 dimensions"):
   1821       inp.maybe_shuffle_batch([array_ops.zeros(5)], 1, 10, 1,
   1822                               keep_input=constant_op.constant([[True]]),
   1823                               enqueue_many=True)
   1824     # `keep_input` must have dimensions determined at graph construction.
   1825     with self.assertRaisesRegexp(ValueError,
   1826                                  "must be known at graph construction"):
   1827       inp.maybe_shuffle_batch([array_ops.zeros(5)], 1, 10, 1,
   1828                               keep_input=array_ops.placeholder(dtypes.bool),
   1829                               enqueue_many=True)
   1830 
   1831   def testMaybeBatchedSparseTensorInferredShape(self):
   1832     sparse = sparse_tensor.SparseTensor(
   1833         indices=[[0]], values=[1.0], dense_shape=[1])
   1834     self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
   1835     batched = inp.maybe_shuffle_batch([sparse], 2, 10, 1, True)
   1836     self.assertAllEqual((2,), batched.dense_shape.get_shape().as_list())
   1837 
   1838   def testMaybeBatchedSparseTensorInferredShapeEnqueueMany(self):
   1839     sparse = sparse_tensor.SparseTensor(
   1840         indices=[[0]], values=[1.0], dense_shape=[1])
   1841     self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
   1842     batched = inp.maybe_shuffle_batch(
   1843         [sparse], 2, 10, 1, True, enqueue_many=True)
   1844     self.assertAllEqual((1,), batched.dense_shape.get_shape().as_list())
   1845 
   1846   def testMaybeBatchedSparseTensorInferredShapeEnqueueManyPerExample(self):
   1847     sparse = sparse_tensor.SparseTensor(
   1848         indices=[[0], [0]], values=[1.0, 2.0], dense_shape=[2])
   1849     self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
   1850     batched = inp.maybe_shuffle_batch(
   1851         [sparse], 2, 10, 1, [True, False], enqueue_many=True)
   1852     self.assertAllEqual((1,), batched.dense_shape.get_shape().as_list())
   1853 
   1854   def testMaybeBatchedSparseTensorInferredShapeUnknownRank(self):
   1855     sparse = sparse_tensor.SparseTensor(
   1856         indices=array_ops.placeholder(dtypes.int64),
   1857         values=array_ops.placeholder(dtypes.float32),
   1858         dense_shape=array_ops.placeholder(dtypes.int64))
   1859     self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
   1860     batched = inp.maybe_shuffle_batch([sparse], 2, 10, 1, True)
   1861     self.assertIs(None, batched.dense_shape.get_shape().num_elements())
   1862 
   1863   def testMaybeBatchedSparseTensorInferredShapeUnknownRankEnqueueMany(self):
   1864     sparse = sparse_tensor.SparseTensor(
   1865         indices=array_ops.placeholder(dtypes.int64),
   1866         values=array_ops.placeholder(dtypes.float32),
   1867         dense_shape=array_ops.placeholder(dtypes.int64))
   1868     self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
   1869     batched = inp.maybe_shuffle_batch(
   1870         [sparse], 2, 10, 1, True, enqueue_many=True)
   1871     self.assertIs(None, batched.dense_shape.get_shape().num_elements())
   1872 
   1873   def testMaybeBatchedSparseTensorInferredShapeUnknownRankPerExample(self):
   1874     sparse = sparse_tensor.SparseTensor(
   1875         indices=array_ops.placeholder(dtypes.int64),
   1876         values=array_ops.placeholder(dtypes.float32),
   1877         dense_shape=array_ops.placeholder(dtypes.int64))
   1878     self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
   1879     batched = inp.maybe_shuffle_batch(
   1880         [sparse], 2, 10, 1, [True, False], enqueue_many=True)
   1881     self.assertIs(None, batched.dense_shape.get_shape().num_elements())
   1882 
   1883 
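# Note on the positional maybe_shuffle_batch* calls above and below: after
# the tensors (or tensors_list for the _join variant), the arguments are
# (batch_size, capacity, min_after_dequeue, keep_input), so e.g.
# maybe_shuffle_batch([sparse], 2, 10, 1, True) asks for batches of 2 from a
# queue of capacity 10 that keeps at least 1 element after each dequeue,
# keeping every example.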
   1884 class ShuffleBatchJoinTest(test_lib.TestCase):
   1885 
   1886   def _testTwoThreadsHelper(self, use_dict):
   1887     with self.test_session() as sess:
   1888       # Two threads, the first generates (0..24, "a").
   1889       num_a = 25
   1890       zero64 = constant_op.constant(0, dtype=dtypes.int64)
   1891       examples = variables.Variable(zero64)
   1892       counter = examples.count_up_to(num_a)
   1893       sparse_counter = sparse_tensor.SparseTensor(
   1894           indices=array_ops.reshape(zero64, [1, 1]),
   1895           values=array_ops.stack([math_ops.cast(counter, dtypes.float32)]),
   1896           dense_shape=[1])
   1897 
   1898       # The second generates (99, "b") 35 times and then stops.
   1899       num_b = 35
   1900       ninety_nine = inp.limit_epochs(
   1901           constant_op.constant(
   1902               99, dtype=dtypes.int64), num_b)
   1903       sparse_ninety_nine = sparse_tensor.SparseTensor(
   1904           indices=array_ops.reshape(zero64, [1, 1]),
   1905           values=array_ops.stack([math_ops.cast(ninety_nine, dtypes.float32)]),
   1906           dense_shape=[1])
   1907 
   1908       # These get joined together and grouped into batches of 5.
   1909       batch_size = 5
   1910       if use_dict:
   1911         batched = inp.shuffle_batch_join(
   1912             [{
   1913                 "c": counter,
   1914                 "s": sparse_counter,
   1915                 "S": "a"
   1916             }, {
   1917                 "c": ninety_nine,
   1918                 "s": sparse_ninety_nine,
   1919                 "S": "b"
   1920             }],
   1921             batch_size=batch_size,
   1922             capacity=32,
   1923             min_after_dequeue=16,
   1924             seed=223607)
   1925         batched_fetch = [batched["c"], batched["s"], batched["S"]]
   1926       else:
   1927         batched = inp.shuffle_batch_join(
   1928             [[counter, sparse_counter, "a"],
   1929              [ninety_nine, sparse_ninety_nine, "b"]],
   1930             batch_size=batch_size,
   1931             capacity=32,
   1932             min_after_dequeue=16,
   1933             seed=223607)
   1934         batched_fetch = batched
   1935 
   1936       # Shapes.
   1937       self.assertEqual(3, len(batched_fetch))
   1938       self.assertAllEqual((batch_size,), batched_fetch[0].get_shape().as_list())
   1939       self.assertAllEqual((None, 2),
   1940                           batched_fetch[1].indices.get_shape().as_list())
   1941       self.assertAllEqual((None,),
   1942                           batched_fetch[1].values.get_shape().as_list())
   1943       self.assertAllEqual((2,),
   1944                           batched_fetch[1].dense_shape.get_shape().as_list())
   1945       self.assertAllEqual((batch_size,), batched_fetch[2].get_shape().as_list())
   1946 
   1947       variables.global_variables_initializer().run()
   1948       variables.local_variables_initializer().run()
   1949       threads = queue_runner_impl.start_queue_runners()
   1950 
   1951       # Should see the "a" and "b" threads mixed together.
   1952       all_a = []
   1953       seen_b = 0
   1954       saw_both = 0
   1955       num_batches = (num_a + num_b) // batch_size
   1956       for i in range(num_batches):
   1957         results = sess.run(batched_fetch)
   1958         self.assertEqual(3, len(results))
   1959         self.assertEqual(len(results[0]), batch_size)
   1960         self.assertEqual(len(results[2]), batch_size)
   1961         self.assertAllEqual(results[0], results[1].values)
   1962         self.assertAllEqual(
   1963             results[1].indices,
   1964             np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
   1965         self.assertAllEqual(results[1].dense_shape, [batch_size, 1])
   1966         which_a = [i for i, s in enumerate(results[2]) if s == b"a"]
   1967         which_b = [i for i, s in enumerate(results[2]) if s == b"b"]
   1968         self.assertEqual(len(which_a) + len(which_b), batch_size)
   1969         if which_a and which_b:
   1970           saw_both += 1
   1971         all_a.extend([results[0][i] for i in which_a])
   1972         seen_b += len(which_b)
   1973         self.assertAllEqual([99] * len(which_b),
   1974                             [results[0][i] for i in which_b])
   1975 
   1976       # Some minimum level of mixing of the results of both threads.
   1977       self.assertGreater(saw_both, 1)
   1978 
   1979       # Saw all the items from "a", but scrambled.
   1980       self.assertItemsEqual(all_a, range(num_a))
   1981       deltas = [all_a[i + 1] - all_a[i] for i in range(len(all_a) - 1)]
   1982       self.assertFalse(all(d == deltas[0] for d in deltas))
   1983       self.assertEqual(seen_b, num_b)
   1984 
   1985       # Reached the limit.
   1986       with self.assertRaises(errors_impl.OutOfRangeError):
   1987         sess.run(batched_fetch)
   1988       for thread in threads:
   1989         thread.join()
   1990 
   1991   def testTwoThreads(self):
   1992     self._testTwoThreadsHelper(use_dict=False)
   1993 
   1994   def testTwoThreadsDict(self):
   1995     self._testTwoThreadsHelper(use_dict=True)
   1996 
   1997   def testTwoThreadsSmallerBatch(self):
   1998     with self.test_session() as sess:
   1999       # Two threads, the first generates (0..26, "a").
   2000       extra_elements = 2
   2001       num_a = 25 + extra_elements
   2002       zero64 = constant_op.constant(0, dtype=dtypes.int64)
   2003       examples = variables.Variable(zero64)
   2004       counter = examples.count_up_to(num_a)
   2005       sparse_counter = sparse_tensor.SparseTensor(
   2006           indices=array_ops.reshape(zero64, [1, 1]),
   2007           values=array_ops.stack([math_ops.cast(counter, dtypes.float32)]),
   2008           dense_shape=[1])
   2009 
   2010       # The second generates (99, "b") 37 times and then stops.
   2011       num_b = 35 + extra_elements
   2012       ninety_nine = inp.limit_epochs(
   2013           constant_op.constant(
   2014               99, dtype=dtypes.int64), num_b)
   2015       sparse_ninety_nine = sparse_tensor.SparseTensor(
   2016           indices=array_ops.reshape(zero64, [1, 1]),
   2017           values=array_ops.stack([math_ops.cast(ninety_nine, dtypes.float32)]),
   2018           dense_shape=[1])
   2019 
   2020       # These get joined together and grouped into batches of 5.
   2021       batch_size = 5
   2022       batched = inp.shuffle_batch_join(
   2023           [[counter, sparse_counter, "a"],
   2024            [ninety_nine, sparse_ninety_nine, "b"]],
   2025           batch_size=batch_size,
   2026           capacity=32,
   2027           min_after_dequeue=16,
   2028           seed=223607,
   2029           allow_smaller_final_batch=True)
   2030 
   2031       # Shapes.
   2032       self.assertEqual(3, len(batched))
   2033       self.assertAllEqual((None,), batched[0].get_shape().as_list())
   2034       self.assertAllEqual((None, 2), batched[1].indices.get_shape().as_list())
   2035       self.assertAllEqual((None,), batched[1].values.get_shape().as_list())
   2036       self.assertAllEqual((2,), batched[1].dense_shape.get_shape().as_list())
   2037       self.assertAllEqual((None,), batched[2].get_shape().as_list())
   2038 
   2039       variables.global_variables_initializer().run()
   2040       variables.local_variables_initializer().run()
   2041       threads = queue_runner_impl.start_queue_runners()
   2042 
   2043       # Should see the "a" and "b" threads mixed together.
   2044       all_a = []
   2045       seen_b = 0
   2046       saw_both = 0
   2047       num_batches = (num_a + num_b) // batch_size
   2048       for i in range(num_batches):
   2049         results = sess.run(batched)
   2050         tf_logging.info("Batch %d: %s", i, results[0])
   2051         self.assertEqual(len(results[0]), batch_size)
   2052         self.assertEqual(len(results[2]), batch_size)
   2053         self.assertAllEqual(results[0], results[1].values)
   2054         self.assertAllEqual(
   2055             results[1].indices,
   2056             np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
   2057         self.assertAllEqual(results[1].dense_shape, [batch_size, 1])
   2058         which_a = [i for i, s in enumerate(results[2]) if s == b"a"]
   2059         which_b = [i for i, s in enumerate(results[2]) if s == b"b"]
   2060         self.assertEqual(len(which_a) + len(which_b), batch_size)
   2061         if which_a and which_b:
   2062           saw_both += 1
   2063         all_a.extend([results[0][i] for i in which_a])
   2064         seen_b += len(which_b)
   2065         self.assertAllEqual([99] * len(which_b),
   2066                             [results[0][i] for i in which_b])
   2067 
    2068       # Reached the final batch with 2 * extra_elements left.
   2069       results = sess.run(batched)
   2070       self.assertEqual(len(results[0]), 2 * extra_elements)
   2071       self.assertAllEqual(results[1].dense_shape, [2 * extra_elements, 1])
   2072       self.assertEqual(len(results[2]), 2 * extra_elements)
   2073       self.assertAllEqual(results[0], results[1].values)
   2074       self.assertAllEqual(results[1].indices,
   2075                           np.vstack((np.arange(2 * extra_elements),
   2076                                      np.zeros(2 * extra_elements))).T)
   2077       which_a = [i for i, s in enumerate(results[2]) if s == b"a"]
   2078       which_b = [i for i, s in enumerate(results[2]) if s == b"b"]
   2079       self.assertEqual(len(which_a) + len(which_b), 2 * extra_elements)
   2080       if which_a and which_b:
   2081         saw_both += 1
   2082       all_a.extend([results[0][i] for i in which_a])
   2083       seen_b += len(which_b)
   2084 
   2085       # Some minimum level of mixing of the results of both threads.
   2086       self.assertGreater(saw_both, 1)
   2087 
   2088       # Saw all the items from "a", but scrambled, including extras.
   2089       self.assertItemsEqual(all_a, range(num_a))
   2090       deltas = [all_a[i + 1] - all_a[i] for i in range(len(all_a) - 1)]
   2091       self.assertFalse(all(d == deltas[0] for d in deltas))
   2092       self.assertEqual(seen_b, num_b)
   2093 
   2094       # Reached the limit.
   2095       with self.assertRaises(errors_impl.OutOfRangeError):
   2096         sess.run(batched)
   2097       for thread in threads:
   2098         thread.join()
   2099 
   2100   def testMismatchedDictKeys(self):
   2101     with self.assertRaisesRegexp(ValueError, "must have the same keys"):
   2102       inp.shuffle_batch_join(
   2103           [{
   2104               "c": 12,
   2105               "s": 123,
   2106               "S": "a"
   2107           }, {
   2108               "cool": -12,
   2109               "s": 99,
   2110               "S": "b"
   2111           }],
   2112           batch_size=8,
   2113           capacity=32,
   2114           min_after_dequeue=16,
   2115           seed=223607)
   2116 
   2117   def testSharedName(self):
   2118     with self.test_session():
   2119       batch_size = 10
   2120       num_batches = 3
   2121       zero64 = constant_op.constant(0, dtype=dtypes.int64)
   2122       examples = variables.Variable(zero64)
   2123       counter = examples.count_up_to(num_batches * batch_size)
   2124       batched = inp.shuffle_batch_join(
   2125           [[counter, "string"]],
   2126           batch_size=batch_size,
   2127           capacity=32,
   2128           min_after_dequeue=10,
   2129           shared_name="SHARED_NAME_XYZ",
   2130           name="Q")
   2131 
   2132       # Shapes.
   2133       self.assertEqual(2, len(batched))
   2134       self.assertAllEqual((batch_size,), batched[0].get_shape().as_list())
   2135       self.assertAllEqual((batch_size,), batched[1].get_shape().as_list())
   2136 
   2137       self.assertProtoEquals(
   2138           "s: 'SHARED_NAME_XYZ'",
   2139           batched[0].op.inputs[0].op.node_def.attr["shared_name"])
   2140 
   2141   def _testKeepInputHelper(self, num_threads, enqueue_many,
   2142                            keep_input_vector=False):
   2143     with self.test_session() as sess:
   2144       batch_size = 5
   2145       num_batches = 4
   2146       examples = variables.Variable(0)
   2147       counter = examples.count_up_to(num_batches * batch_size * 2)
   2148       sparse_counter = sparse_tensor.SparseTensor(
   2149           indices=array_ops.zeros(
   2150               [1, 1], dtype=dtypes.int64),
   2151           values=array_ops.stack([math_ops.cast(counter, dtypes.float32)]),
   2152           dense_shape=[1])
   2153       to_batch = [counter, sparse_counter, "string"]
   2154       if enqueue_many:
   2155         to_batch = inp.batch(to_batch, 4 if keep_input_vector else 1)
   2156       keep_input = array_ops.squeeze(
   2157           math_ops.equal(0, math_ops.mod(to_batch[0], 2)))
   2158       batched = inp.maybe_shuffle_batch_join(
   2159           [to_batch] * num_threads,
   2160           batch_size,
   2161           10,
   2162           1,
   2163           keep_input,
   2164           enqueue_many=enqueue_many)
    2165       variables.global_variables_initializer().run()
    2166       variables.local_variables_initializer().run()
   2167       threads = queue_runner_impl.start_queue_runners()
   2168 
   2169       for _ in range(num_batches):
   2170         results = sess.run(batched)
   2171         self.assertAllEqual([0] * batch_size, np.mod(results[0], 2))
   2172         self.assertAllEqual([0] * batch_size, np.mod(results[1].values, 2))
   2173         self.assertAllEqual([b"string"] * batch_size, results[2])
   2174 
   2175       # Reached the limit.
   2176       with self.assertRaises(errors_impl.OutOfRangeError):
   2177         sess.run(batched)
   2178       for thread in threads:
   2179         thread.join()
   2180 
   2181   def testSingleThreadKeepInput(self):
   2182     self._testKeepInputHelper(1, False)
   2183 
   2184   def testSingleThreadKeepInputEnqueueMany(self):
   2185     self._testKeepInputHelper(1, True)
   2186 
   2187   def testMultipleThreadKeepInput(self):
   2188     self._testKeepInputHelper(5, False)
   2189 
   2190   def testMultipleThreadKeepInputEnqueueMany(self):
   2191     self._testKeepInputHelper(5, True)
   2192 
   2193   def testSingleThreadKeepInputPerExample(self):
   2194     self._testKeepInputHelper(1, True, keep_input_vector=True)
   2195 
   2196   def testMultipleThreadKeepInputPerExample(self):
   2197     self._testKeepInputHelper(5, True, keep_input_vector=True)
   2198 
   2199   def testInvalidKeepInputVector(self):
   2200     # Can't have vector `keep_input` with `enqueue_many=False`.
   2201     with self.assertRaisesRegexp(ValueError, "`keep_input` cannot be a vector"):
   2202       inp.maybe_shuffle_batch_join(
   2203           [[array_ops.zeros(5)]], 1, 10, 1,
   2204           keep_input=constant_op.constant([True, False]),
   2205           enqueue_many=False)
   2206     # Can't have `keep_input` with more than one dimension.
   2207     with self.assertRaisesRegexp(ValueError, "must be 0 or 1 dimensions"):
   2208       inp.maybe_shuffle_batch_join(
   2209           [[array_ops.zeros(5)]], 1, 10, 1,
   2210           keep_input=constant_op.constant([[True]]),
   2211           enqueue_many=True)
   2212     # `keep_input` must have dimensions determined at graph construction.
   2213     with self.assertRaisesRegexp(ValueError,
   2214                                  "must be known at graph construction"):
   2215       inp.maybe_shuffle_batch_join(
   2216           [[array_ops.zeros(5)]], 1, 10, 1,
   2217           keep_input=array_ops.placeholder(dtypes.bool),
   2218           enqueue_many=True)
   2219 
   2220   def testMaybeBatchedSparseTensorInferredShape(self):
   2221     sparse = sparse_tensor.SparseTensor(
   2222         indices=[[0]], values=[1.0], dense_shape=[1])
   2223     self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
   2224     batched = inp.maybe_shuffle_batch_join([[sparse]], 2, 10, 1, True)
   2225     self.assertAllEqual((2,), batched.dense_shape.get_shape().as_list())
   2226 
   2227   def testMaybeBatchedSparseTensorInferredShapeEnqueueMany(self):
   2228     sparse = sparse_tensor.SparseTensor(
   2229         indices=[[0]], values=[1.0], dense_shape=[1])
   2230     self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
   2231     batched = inp.maybe_shuffle_batch_join(
   2232         [[sparse]], 2, 10, 1, True, enqueue_many=True)
   2233     self.assertAllEqual((1,), batched.dense_shape.get_shape().as_list())
   2234 
   2235   def testMaybeBatchedSparseTensorInferredShapeEnqueueManyPerExample(self):
   2236     sparse = sparse_tensor.SparseTensor(
   2237         indices=[[0], [0]], values=[1.0, 2.0], dense_shape=[2])
   2238     self.assertAllEqual((1,), sparse.dense_shape.get_shape().as_list())
   2239     batched = inp.maybe_shuffle_batch_join(
   2240         [[sparse]], 2, 10, 1, [True, False], enqueue_many=True)
   2241     self.assertAllEqual((1,), batched.dense_shape.get_shape().as_list())
   2242 
   2243   def testMaybeBatchedSparseTensorInferredShapeUnknownRank(self):
   2244     sparse = sparse_tensor.SparseTensor(
   2245         indices=array_ops.placeholder(dtypes.int64),
   2246         values=array_ops.placeholder(dtypes.float32),
   2247         dense_shape=array_ops.placeholder(dtypes.int64))
   2248     self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
   2249     batched = inp.maybe_shuffle_batch_join([[sparse]], 2, 10, 1, True)
   2250     self.assertIs(None, batched.dense_shape.get_shape().num_elements())
   2251 
   2252   def testMaybeBatchedSparseTensorInferredShapeUnknownRankEnqueueMany(self):
   2253     sparse = sparse_tensor.SparseTensor(
   2254         indices=array_ops.placeholder(dtypes.int64),
   2255         values=array_ops.placeholder(dtypes.float32),
   2256         dense_shape=array_ops.placeholder(dtypes.int64))
   2257     self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
   2258     batched = inp.maybe_shuffle_batch_join(
   2259         [[sparse]], 2, 10, 1, True, enqueue_many=True)
   2260     self.assertIs(None, batched.dense_shape.get_shape().num_elements())
   2261 
   2262   def testMaybeBatchedSparseTensorInferredShapeUnknownRankPerExample(self):
   2263     sparse = sparse_tensor.SparseTensor(
   2264         indices=array_ops.placeholder(dtypes.int64),
   2265         values=array_ops.placeholder(dtypes.float32),
   2266         dense_shape=array_ops.placeholder(dtypes.int64))
   2267     self.assertIs(None, sparse.dense_shape.get_shape().num_elements())
   2268     batched = inp.maybe_shuffle_batch_join(
   2269         [[sparse]], 2, 10, 1, [True, False], enqueue_many=True)
   2270     self.assertIs(None, batched.dense_shape.get_shape().num_elements())
   2271 
   2272 
   2273 if __name__ == "__main__":
   2274   test_lib.main()
   2275