1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 16 """Tests for the currently experimental in-graph batch ops.""" 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import threading 22 import time 23 24 from tensorflow.contrib.batching.python.ops import batch_ops 25 from tensorflow.python.framework import dtypes 26 from tensorflow.python.ops import array_ops 27 from tensorflow.python.ops import gradients_impl 28 from tensorflow.python.ops import script_ops 29 from tensorflow.python.platform import test 30 31 32 def delayed_plus1(x): 33 """Sleeps for 100ms then returns x+1.""" 34 time.sleep(0.1) 35 return x + 1 36 37 38 class BatchOpsTest(test.TestCase): 39 """Tests for batch_ops.{un,}batch.""" 40 41 def testBasicBatch(self): 42 """Tests that a single batched tensor executes together and only once.""" 43 with self.test_session() as sess: 44 inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) 45 batched, index, _ = batch_ops.batch( 46 [inp], num_batch_threads=1, max_batch_size=2, 47 batch_timeout_micros=36000000, grad_timeout_micros=0, 48 batching_queue="") 49 thread_results = [] 50 51 def worker(): 52 thread_results.extend( 53 sess.run([batched, index], feed_dict={inp: [1]})) 54 55 worker_thread = threading.Thread(target=worker) 56 worker_thread.start() 57 main_results = sess.run([batched, index], feed_dict={inp: [2]}) 58 worker_thread.join() 59 60 # At this point either the thread or the main did the batch and the other 61 # should have empty results. 62 if list(thread_results[0][0]): 63 batch_t = thread_results[0][0] 64 index_t = thread_results[1] 65 empty_b = main_results[0][0] 66 empty_m = main_results[1] 67 else: 68 batch_t = main_results[0][0] 69 index_t = main_results[1] 70 empty_b = thread_results[0][0] 71 empty_m = thread_results[1] 72 73 # Check that both the inputs made it out exactly once. 74 self.assertAllEqual(sorted(batch_t), (1, 2)) 75 # Check that we get 2 rows in the index tensor. 76 self.assertEqual(len(index_t), 2) 77 # Check that the other ones are empty. 78 self.assertEqual(len(empty_b), 0) 79 self.assertEqual(len(empty_m), 0) 80 81 def testBatchWithPadding(self): 82 """Test that batching with padding up to an allowed batch size works.""" 83 with self.test_session() as sess: 84 inp = array_ops.placeholder(dtype=dtypes.int32, shape=[2]) 85 batched, index, _ = batch_ops.batch( 86 [inp], num_batch_threads=1, max_batch_size=10, 87 batch_timeout_micros=100000, # 100ms 88 allowed_batch_sizes=[5, 10], 89 grad_timeout_micros=0, batching_queue="") 90 thread_results = [] 91 92 def worker(): 93 thread_results.extend( 94 sess.run([batched, index], feed_dict={inp: [1, 3]})) 95 96 worker_thread = threading.Thread(target=worker) 97 worker_thread.start() 98 main_results = sess.run([batched, index], feed_dict={inp: [2, 4]}) 99 worker_thread.join() 100 101 # At this point either the thread or the main did the batch and the other 102 # should have empty results. 103 if list(thread_results[0][0]): 104 batch_t = thread_results[0][0] 105 else: 106 batch_t = main_results[0][0] 107 108 # Check that the batch tensor incorporates the padding. 109 self.assertEqual(len(batch_t), 5) 110 111 def testMultipleBatch(self): 112 """Tests that multiple batched tensors execute together.""" 113 with self.test_session() as sess: 114 inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) 115 inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) 116 batched, _, _ = batch_ops.batch( 117 [inp0, inp1], 118 num_batch_threads=1, 119 max_batch_size=2, 120 batch_timeout_micros=36000000, 121 grad_timeout_micros=0, 122 batching_queue="") 123 thread_results = [] 124 125 def worker(): 126 thread_results.extend( 127 sess.run([batched], feed_dict={inp0: [1], 128 inp1: [2]})) 129 130 worker_thread = threading.Thread(target=worker) 131 worker_thread.start() 132 main_results = sess.run([batched], feed_dict={inp0: [2], inp1: [3]}) 133 worker_thread.join() 134 135 # At this point either the thread or the main did the batch and the other 136 # should have empty results. 137 if list(thread_results[0][0]): 138 batch_t = thread_results[0] 139 empty_t = main_results[0] 140 else: 141 batch_t = main_results[0] 142 empty_t = thread_results[0] 143 144 # Assert that the tensors were batched together. 145 self.assertAllEqual(sorted(batch_t[0]), [1, 2]) 146 self.assertAllEqual(sorted(batch_t[1]), [2, 3]) 147 self.assertAllEqual(empty_t[0], []) 148 self.assertAllEqual(empty_t[1], []) 149 150 def testIllegalBatchDifferentDim0Sizes(self): 151 """Tests illegally feeding tensors with different dim0 sizes.""" 152 with self.test_session() as sess: 153 inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) 154 inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[2]) 155 batched, index, _ = batch_ops.batch( 156 [inp0, inp1], num_batch_threads=1, max_batch_size=2, 157 batch_timeout_micros=0, grad_timeout_micros=0, batching_queue="") 158 with self.assertRaises(Exception) as raised: 159 _ = sess.run([batched, index], feed_dict={inp0: [0], inp1: [1, 2]}) 160 self.assertGreater( 161 raised.exception.message.find("must have equal 0th-dimension size"), 162 0) 163 164 def testBasicUnbatch(self): 165 """Tests that batch and unbatch work together.""" 166 with self.test_session() as sess: 167 inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) 168 batched, index, id_t = batch_ops.batch( 169 [inp], num_batch_threads=1, max_batch_size=10, 170 batch_timeout_micros=100000, # 100ms 171 allowed_batch_sizes=[3, 10], 172 grad_timeout_micros=0, batching_queue="") 173 computation = batched[0] + 1 174 result = batch_ops.unbatch(computation, index, id_t, 175 timeout_micros=1000000, shared_name="unbatch") 176 thread_results = [] 177 178 def worker(): 179 thread_results.extend(sess.run([result], feed_dict={inp: [1]})) 180 181 worker_thread = threading.Thread(target=worker) 182 worker_thread.start() 183 main_results = sess.run([result], feed_dict={inp: [2]}) 184 worker_thread.join() 185 self.assertEqual(thread_results[0], [2]) 186 self.assertEqual(main_results[0], [3]) 187 188 def testBasicUnbatchDecorated(self): 189 """Tests that the batch_function decorator works.""" 190 with self.test_session() as sess: 191 @batch_ops.batch_function(1, 10, 100000) 192 def computation(in_t): 193 return in_t + 1 194 inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) 195 result = computation(inp) 196 thread_results = [] 197 198 def worker(): 199 thread_results.extend(sess.run([result], feed_dict={inp: [1]})) 200 201 worker_thread = threading.Thread(target=worker) 202 worker_thread.start() 203 main_results = sess.run([result], feed_dict={inp: [2]}) 204 worker_thread.join() 205 self.assertEqual(thread_results[0], [2]) 206 self.assertEqual(main_results[0], [3]) 207 208 def testUnbatchTimeout(self): 209 """Tests that the unbatch timeout works.""" 210 with self.test_session() as sess: 211 inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) 212 batched, index, id_t = batch_ops.batch( 213 [inp], num_batch_threads=1, max_batch_size=2, 214 batch_timeout_micros=36000000, grad_timeout_micros=0, 215 batching_queue="") 216 computation = batched[0] + 1 217 timeout_micros = 10 218 result = batch_ops.unbatch(computation, index, id_t, timeout_micros, 219 shared_name="shared_unbatch") 220 # Set up a parallel pipeline that delays the computation, but uses the 221 # same unbatch resource object as the non-delayed pipeline. 222 computation_delayed = script_ops.py_func(delayed_plus1, 223 [batched[0]], 224 dtypes.int32) 225 result_delayed = batch_ops.unbatch(computation_delayed, 226 index, 227 id_t, 228 timeout_micros, 229 shared_name="shared_unbatch") 230 231 thread_results = [] 232 def worker(): 233 # A first call using the non-delayed pipeline. The batcher will send an 234 # empty tensor along the non-delayed pipeline. 235 thread_results.extend(sess.run([result], feed_dict={inp: [1]})) 236 worker_thread = threading.Thread(target=worker) 237 worker_thread.start() 238 time.sleep(0.1) # Ensure the thread's call starts first. 239 # A second call using the delayed pipeline. The batcher will send the 240 # batched tensor along the delayed pipeline, thus delaying the arrival of 241 # the batched tensor at the unbatch op, relative to the empty tensor. 242 # 243 # TODO(olston, apassos): Avoid relying on the order in which the batch op 244 # emits the empty tensor versus the batched one. 245 _ = sess.run([result_delayed], feed_dict={inp: [2]}) 246 worker_thread.join() 247 # The thread's call should hit the timeout, and thus get 0 results. 248 self.assertEqual(len(thread_results), 0) 249 250 def testUnbatchGrad(self): 251 """Tests that batch and unbatch are differentiable.""" 252 with self.test_session() as sess: 253 inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) 254 batched, index, id_t = batch_ops.batch( 255 [inp], num_batch_threads=1, max_batch_size=2, 256 batch_timeout_micros=36000000, grad_timeout_micros=1000000, 257 batching_queue="") 258 computation = batched[0] * batched[0] 259 result = batch_ops.unbatch(computation, index, id_t, 260 timeout_micros=1000000, shared_name="unbatch") 261 grad = gradients_impl.gradients(result, inp) 262 thread_results = [] 263 264 def worker(): 265 thread_results.extend(sess.run([grad], feed_dict={inp: [1]})) 266 267 worker_thread = threading.Thread(target=worker) 268 worker_thread.start() 269 main_results = sess.run([grad], feed_dict={inp: [2]}) 270 worker_thread.join() 271 self.assertEqual(thread_results[0], [2]) 272 self.assertEqual(main_results[0], [4]) 273 274 275 if __name__ == "__main__": 276 test.main() 277