# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for the currently experimental in-graph batch ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import threading
import time

from tensorflow.contrib.batching.python.ops import batch_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import script_ops
from tensorflow.python.platform import test
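
# The tests below exercise variations of the following wiring; as a rough
# sketch (mirroring testBasicUnbatch below, not an additional test):
#
#   batched, index, id_t = batch_ops.batch(
#       [inp], num_batch_threads=1, max_batch_size=10,
#       batch_timeout_micros=100000, grad_timeout_micros=0, batching_queue="")
#   computation = batched[0] + 1  # any graph computation on the batched input
#   result = batch_ops.unbatch(computation, index, id_t,
#                              timeout_micros=1000000, shared_name="unbatch")
#
# batch() concatenates the inputs of concurrent Session.run calls along
# dimension 0, while index and id_t let unbatch() route each output row back
# to the call that fed the corresponding input.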


def delayed_plus1(x):
  """Sleeps for 100ms then returns x+1."""
  time.sleep(0.1)
  return x + 1


class BatchOpsTest(test.TestCase):
  """Tests for batch_ops.{un,}batch."""

  def testBasicBatch(self):
    """Tests that a single batched tensor executes together and only once."""
    with self.test_session() as sess:
      inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
      batched, index, _ = batch_ops.batch(
          [inp], num_batch_threads=1, max_batch_size=2,
          batch_timeout_micros=36000000, grad_timeout_micros=0,
          batching_queue="")
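      # batch() returns one batched tensor per input, an index tensor for
      # splitting results back out, and a batch id (unused here). With a 36s
      # batch timeout and max_batch_size=2, the two concurrent size-1 feeds
      # below should end up in a single batch.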
      thread_results = []

      def worker():
        thread_results.extend(
            sess.run([batched, index], feed_dict={inp: [1]}))

      worker_thread = threading.Thread(target=worker)
      worker_thread.start()
      main_results = sess.run([batched, index], feed_dict={inp: [2]})
      worker_thread.join()

      # At this point either the worker thread or the main thread has run the
      # batch, and the other should have gotten empty results.
      if list(thread_results[0][0]):
        batch_t = thread_results[0][0]
        index_t = thread_results[1]
        empty_b = main_results[0][0]
        empty_m = main_results[1]
      else:
        batch_t = main_results[0][0]
        index_t = main_results[1]
        empty_b = thread_results[0][0]
        empty_m = thread_results[1]

      # Check that both inputs made it out exactly once.
      self.assertAllEqual(sorted(batch_t), (1, 2))
      # Check that we get 2 rows in the index tensor.
      self.assertEqual(len(index_t), 2)
      # Check that the other results are empty.
      self.assertEqual(len(empty_b), 0)
      self.assertEqual(len(empty_m), 0)

  def testBatchWithPadding(self):
    """Tests that batching with padding up to an allowed batch size works."""
    with self.test_session() as sess:
      inp = array_ops.placeholder(dtype=dtypes.int32, shape=[2])
      batched, index, _ = batch_ops.batch(
          [inp], num_batch_threads=1, max_batch_size=10,
          batch_timeout_micros=100000,  # 100ms
          allowed_batch_sizes=[5, 10],
          grad_timeout_micros=0, batching_queue="")
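      # The two concurrent size-2 feeds below combine into a batch of 4 rows;
      # 4 is not an allowed batch size, so the op should pad the batch up to
      # the next allowed size, 5 (asserted at the end of the test).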
      thread_results = []

      def worker():
        thread_results.extend(
            sess.run([batched, index], feed_dict={inp: [1, 3]}))

      worker_thread = threading.Thread(target=worker)
      worker_thread.start()
      main_results = sess.run([batched, index], feed_dict={inp: [2, 4]})
      worker_thread.join()

      # At this point either the worker thread or the main thread has run the
      # batch, and the other should have gotten empty results.
      if list(thread_results[0][0]):
        batch_t = thread_results[0][0]
      else:
        batch_t = main_results[0][0]

      # Check that the batch tensor incorporates the padding.
      self.assertEqual(len(batch_t), 5)

  def testMultipleBatch(self):
    """Tests that multiple batched tensors execute together."""
    with self.test_session() as sess:
      inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
      inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
      batched, _, _ = batch_ops.batch(
          [inp0, inp1],
          num_batch_threads=1,
          max_batch_size=2,
          batch_timeout_micros=36000000,
          grad_timeout_micros=0,
          batching_queue="")
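      # Both placeholders go through a single batch op, so the values a given
      # Session.run call feeds to inp0 and inp1 should stay aligned across the
      # two batched outputs.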
      thread_results = []

      def worker():
        thread_results.extend(
            sess.run([batched], feed_dict={inp0: [1],
                                           inp1: [2]}))

      worker_thread = threading.Thread(target=worker)
      worker_thread.start()
      main_results = sess.run([batched], feed_dict={inp0: [2], inp1: [3]})
      worker_thread.join()

      # At this point either the worker thread or the main thread has run the
      # batch, and the other should have gotten empty results.
      if list(thread_results[0][0]):
        batch_t = thread_results[0]
        empty_t = main_results[0]
      else:
        batch_t = main_results[0]
        empty_t = thread_results[0]

      # Assert that the tensors were batched together.
      self.assertAllEqual(sorted(batch_t[0]), [1, 2])
      self.assertAllEqual(sorted(batch_t[1]), [2, 3])
      self.assertAllEqual(empty_t[0], [])
      self.assertAllEqual(empty_t[1], [])

  def testIllegalBatchDifferentDim0Sizes(self):
    """Tests illegally feeding tensors with different dim0 sizes."""
    with self.test_session() as sess:
      inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
      inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[2])
      batched, index, _ = batch_ops.batch(
          [inp0, inp1], num_batch_threads=1, max_batch_size=2,
          batch_timeout_micros=0, grad_timeout_micros=0, batching_queue="")
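      # Tensors batched together in one call must agree in their 0th-dimension
      # size, since their rows are paired up; feeding shapes [1] and [2] below
      # should therefore raise.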
      with self.assertRaises(Exception) as raised:
        _ = sess.run([batched, index], feed_dict={inp0: [0], inp1: [1, 2]})
      self.assertIn("must have equal 0th-dimension size",
                    str(raised.exception))

  def testBasicUnbatch(self):
    """Tests that batch and unbatch work together."""
    with self.test_session() as sess:
      inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
      batched, index, id_t = batch_ops.batch(
          [inp], num_batch_threads=1, max_batch_size=10,
          batch_timeout_micros=100000,  # 100ms
          allowed_batch_sizes=[3, 10],
          grad_timeout_micros=0, batching_queue="")
      computation = batched[0] + 1
      result = batch_ops.unbatch(computation, index, id_t,
                                 timeout_micros=1000000, shared_name="unbatch")
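      # unbatch() uses index and id_t to return to each Session.run call only
      # the rows of `computation` derived from that call's own input, so each
      # caller below should see exactly its input plus one.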
      thread_results = []

      def worker():
        thread_results.extend(sess.run([result], feed_dict={inp: [1]}))

      worker_thread = threading.Thread(target=worker)
      worker_thread.start()
      main_results = sess.run([result], feed_dict={inp: [2]})
      worker_thread.join()
      self.assertEqual(thread_results[0], [2])
      self.assertEqual(main_results[0], [3])

  def testBasicUnbatchDecorated(self):
    """Tests that the batch_function decorator works."""
    with self.test_session() as sess:
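      # The decorator's positional arguments are presumably num_batch_threads,
      # max_batch_size and batch_timeout_micros (here 1, 10 and 100000, i.e.
      # 100ms), matching the keyword arguments of the batch() calls above.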
      @batch_ops.batch_function(1, 10, 100000)
      def computation(in_t):
        return in_t + 1
      inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
      result = computation(inp)
      thread_results = []

      def worker():
        thread_results.extend(sess.run([result], feed_dict={inp: [1]}))

      worker_thread = threading.Thread(target=worker)
      worker_thread.start()
      main_results = sess.run([result], feed_dict={inp: [2]})
      worker_thread.join()
      self.assertEqual(thread_results[0], [2])
      self.assertEqual(main_results[0], [3])

  def testUnbatchTimeout(self):
    """Tests that the unbatch timeout works."""
    with self.test_session() as sess:
      inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
      batched, index, id_t = batch_ops.batch(
          [inp], num_batch_threads=1, max_batch_size=2,
          batch_timeout_micros=36000000, grad_timeout_micros=0,
          batching_queue="")
      computation = batched[0] + 1
      timeout_micros = 10
      result = batch_ops.unbatch(computation, index, id_t, timeout_micros,
                                 shared_name="shared_unbatch")
      # Set up a parallel pipeline that delays the computation, but uses the
      # same unbatch resource object as the non-delayed pipeline.
      computation_delayed = script_ops.py_func(delayed_plus1,
                                               [batched[0]],
                                               dtypes.int32)
      result_delayed = batch_ops.unbatch(computation_delayed,
                                         index,
                                         id_t,
                                         timeout_micros,
                                         shared_name="shared_unbatch")

      thread_results = []

      def worker():
        # A first call using the non-delayed pipeline. The batcher will send an
        # empty tensor along the non-delayed pipeline.
        thread_results.extend(sess.run([result], feed_dict={inp: [1]}))

      worker_thread = threading.Thread(target=worker)
      worker_thread.start()
      time.sleep(0.1)  # Ensure the thread's call starts first.
      # A second call using the delayed pipeline. The batcher will send the
      # batched tensor along the delayed pipeline, thus delaying the arrival of
      # the batched tensor at the unbatch op, relative to the empty tensor.
      #
      # TODO(olston, apassos): Avoid relying on the order in which the batch op
      # emits the empty tensor versus the batched one.
      _ = sess.run([result_delayed], feed_dict={inp: [2]})
      worker_thread.join()
      # The thread's call should hit the timeout, and thus get 0 results.
      self.assertEqual(len(thread_results), 0)

  def testUnbatchGrad(self):
    """Tests that batch and unbatch are differentiable."""
    with self.test_session() as sess:
      inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
      batched, index, id_t = batch_ops.batch(
          [inp], num_batch_threads=1, max_batch_size=2,
          batch_timeout_micros=36000000, grad_timeout_micros=1000000,
          batching_queue="")
      computation = batched[0] * batched[0]
      result = batch_ops.unbatch(computation, index, id_t,
                                 timeout_micros=1000000, shared_name="unbatch")
      grad = gradients_impl.gradients(result, inp)
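      # The computation is x*x, so d/dx = 2x: feeding 1 should produce a
      # gradient of 2, and feeding 2 a gradient of 4 (asserted below).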
      thread_results = []

      def worker():
        thread_results.extend(sess.run([grad], feed_dict={inp: [1]}))

      worker_thread = threading.Thread(target=worker)
      worker_thread.start()
      main_results = sess.run([grad], feed_dict={inp: [2]})
      worker_thread.join()
      self.assertEqual(thread_results[0], [2])
      self.assertEqual(main_results[0], [4])


if __name__ == "__main__":
  test.main()