1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for tensorflow.ops.data_flow_ops.PriorityQueue.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import copy 22 import random 23 import threading 24 25 import numpy as np 26 27 from tensorflow.python.framework import constant_op 28 from tensorflow.python.framework import dtypes 29 from tensorflow.python.framework import errors_impl 30 from tensorflow.python.ops import array_ops 31 from tensorflow.python.ops import data_flow_ops 32 import tensorflow.python.ops.nn_grad # pylint: disable=unused-import 33 from tensorflow.python.platform import test 34 35 36 class PriorityQueueTest(test.TestCase): 37 38 def testRoundTripInsertReadOnceSorts(self): 39 with self.test_session() as sess: 40 q = data_flow_ops.PriorityQueue(2000, (dtypes.string, dtypes.string), ( 41 (), ())) 42 elem = np.random.randint(-5, 5, size=100).astype(np.int64) 43 side_value_0 = np.random.rand(100).astype(bytes) 44 side_value_1 = np.random.rand(100).astype(bytes) 45 enq_list = [ 46 q.enqueue((e, constant_op.constant(v0), constant_op.constant(v1))) 47 for e, v0, v1 in zip(elem, side_value_0, side_value_1) 48 ] 49 for enq in enq_list: 50 enq.run() 51 52 deq = q.dequeue_many(100) 53 deq_elem, deq_value_0, deq_value_1 = sess.run(deq) 54 55 allowed = {} 56 missed = set() 57 for e, v0, v1 in zip(elem, side_value_0, side_value_1): 58 if e not in allowed: 59 allowed[e] = set() 60 allowed[e].add((v0, v1)) 61 missed.add((v0, v1)) 62 63 self.assertAllEqual(deq_elem, sorted(elem)) 64 for e, dv0, dv1 in zip(deq_elem, deq_value_0, deq_value_1): 65 self.assertTrue((dv0, dv1) in allowed[e]) 66 missed.remove((dv0, dv1)) 67 self.assertEqual(missed, set()) 68 69 def testRoundTripInsertMultiThreadedReadOnceSorts(self): 70 with self.test_session() as sess: 71 q = data_flow_ops.PriorityQueue(2000, (dtypes.string, dtypes.string), ( 72 (), ())) 73 elem = np.random.randint(-5, 5, size=100).astype(np.int64) 74 side_value_0 = np.random.rand(100).astype(bytes) 75 side_value_1 = np.random.rand(100).astype(bytes) 76 77 enqueue_ops = [ 78 q.enqueue((e, constant_op.constant(v0), constant_op.constant(v1))) 79 for e, v0, v1 in zip(elem, side_value_0, side_value_1) 80 ] 81 82 # Run one producer thread for each element in elems. 83 def enqueue(enqueue_op): 84 sess.run(enqueue_op) 85 86 dequeue_op = q.dequeue_many(100) 87 88 enqueue_threads = [ 89 self.checkedThread( 90 target=enqueue, args=(op,)) for op in enqueue_ops 91 ] 92 93 for t in enqueue_threads: 94 t.start() 95 96 deq_elem, deq_value_0, deq_value_1 = sess.run(dequeue_op) 97 98 for t in enqueue_threads: 99 t.join() 100 101 allowed = {} 102 missed = set() 103 for e, v0, v1 in zip(elem, side_value_0, side_value_1): 104 if e not in allowed: 105 allowed[e] = set() 106 allowed[e].add((v0, v1)) 107 missed.add((v0, v1)) 108 109 self.assertAllEqual(deq_elem, sorted(elem)) 110 for e, dv0, dv1 in zip(deq_elem, deq_value_0, deq_value_1): 111 self.assertTrue((dv0, dv1) in allowed[e]) 112 missed.remove((dv0, dv1)) 113 self.assertEqual(missed, set()) 114 115 def testRoundTripFillsCapacityMultiThreadedEnqueueAndDequeue(self): 116 with self.test_session() as sess: 117 q = data_flow_ops.PriorityQueue(10, (dtypes.int64), (())) 118 119 num_threads = 40 120 enqueue_counts = np.random.randint(10, size=num_threads) 121 enqueue_values = [ 122 np.random.randint( 123 5, size=count) for count in enqueue_counts 124 ] 125 enqueue_ops = [ 126 q.enqueue_many((values, values)) for values in enqueue_values 127 ] 128 shuffled_counts = copy.deepcopy(enqueue_counts) 129 random.shuffle(shuffled_counts) 130 dequeue_ops = [q.dequeue_many(count) for count in shuffled_counts] 131 all_enqueued_values = np.hstack(enqueue_values) 132 133 # Run one producer thread for each element in elems. 134 def enqueue(enqueue_op): 135 sess.run(enqueue_op) 136 137 dequeued = [] 138 139 def dequeue(dequeue_op): 140 (dequeue_indices, dequeue_values) = sess.run(dequeue_op) 141 self.assertAllEqual(dequeue_indices, dequeue_values) 142 dequeued.extend(dequeue_indices) 143 144 enqueue_threads = [ 145 self.checkedThread( 146 target=enqueue, args=(op,)) for op in enqueue_ops 147 ] 148 dequeue_threads = [ 149 self.checkedThread( 150 target=dequeue, args=(op,)) for op in dequeue_ops 151 ] 152 153 # Dequeue and check 154 for t in dequeue_threads: 155 t.start() 156 for t in enqueue_threads: 157 t.start() 158 for t in enqueue_threads: 159 t.join() 160 for t in dequeue_threads: 161 t.join() 162 163 self.assertAllEqual(sorted(dequeued), sorted(all_enqueued_values)) 164 165 def testRoundTripInsertManyMultiThreadedReadManyMultithreadedSorts(self): 166 with self.test_session() as sess: 167 q = data_flow_ops.PriorityQueue(2000, (dtypes.int64), (())) 168 169 num_threads = 40 170 enqueue_counts = np.random.randint(10, size=num_threads) 171 enqueue_values = [ 172 np.random.randint( 173 5, size=count) for count in enqueue_counts 174 ] 175 enqueue_ops = [ 176 q.enqueue_many((values, values)) for values in enqueue_values 177 ] 178 shuffled_counts = copy.deepcopy(enqueue_counts) 179 random.shuffle(shuffled_counts) 180 dequeue_ops = [q.dequeue_many(count) for count in shuffled_counts] 181 all_enqueued_values = np.hstack(enqueue_values) 182 183 dequeue_wait = threading.Condition() 184 185 # Run one producer thread for each element in elems. 186 def enqueue(enqueue_op): 187 sess.run(enqueue_op) 188 189 def dequeue(dequeue_op, dequeued): 190 (dequeue_indices, dequeue_values) = sess.run(dequeue_op) 191 self.assertAllEqual(dequeue_indices, dequeue_values) 192 dequeue_wait.acquire() 193 dequeued.extend(dequeue_indices) 194 dequeue_wait.release() 195 196 dequeued = [] 197 enqueue_threads = [ 198 self.checkedThread( 199 target=enqueue, args=(op,)) for op in enqueue_ops 200 ] 201 dequeue_threads = [ 202 self.checkedThread( 203 target=dequeue, args=(op, dequeued)) for op in dequeue_ops 204 ] 205 206 for t in enqueue_threads: 207 t.start() 208 for t in enqueue_threads: 209 t.join() 210 # Dequeue and check 211 for t in dequeue_threads: 212 t.start() 213 for t in dequeue_threads: 214 t.join() 215 216 # We can't guarantee full sorting because we can't guarantee 217 # that the dequeued.extend() call runs immediately after the 218 # sess.run() call. Here we're just happy everything came out. 219 self.assertAllEqual(set(dequeued), set(all_enqueued_values)) 220 221 def testRoundTripInsertManyMultiThreadedReadOnceSorts(self): 222 with self.test_session() as sess: 223 q = data_flow_ops.PriorityQueue(2000, (dtypes.string, dtypes.string), ( 224 (), ())) 225 elem = np.random.randint(-5, 5, size=100).astype(np.int64) 226 side_value_0 = np.random.rand(100).astype(bytes) 227 side_value_1 = np.random.rand(100).astype(bytes) 228 229 batch = 5 230 enqueue_ops = [ 231 q.enqueue_many((elem[i * batch:(i + 1) * batch], 232 side_value_0[i * batch:(i + 1) * batch], 233 side_value_1[i * batch:(i + 1) * batch])) 234 for i in range(20) 235 ] 236 237 # Run one producer thread for each element in elems. 238 def enqueue(enqueue_op): 239 sess.run(enqueue_op) 240 241 dequeue_op = q.dequeue_many(100) 242 243 enqueue_threads = [ 244 self.checkedThread( 245 target=enqueue, args=(op,)) for op in enqueue_ops 246 ] 247 248 for t in enqueue_threads: 249 t.start() 250 251 deq_elem, deq_value_0, deq_value_1 = sess.run(dequeue_op) 252 253 for t in enqueue_threads: 254 t.join() 255 256 allowed = {} 257 missed = set() 258 for e, v0, v1 in zip(elem, side_value_0, side_value_1): 259 if e not in allowed: 260 allowed[e] = set() 261 allowed[e].add((v0, v1)) 262 missed.add((v0, v1)) 263 264 self.assertAllEqual(deq_elem, sorted(elem)) 265 for e, dv0, dv1 in zip(deq_elem, deq_value_0, deq_value_1): 266 self.assertTrue((dv0, dv1) in allowed[e]) 267 missed.remove((dv0, dv1)) 268 self.assertEqual(missed, set()) 269 270 def testRoundTripInsertOnceReadOnceSorts(self): 271 with self.test_session() as sess: 272 q = data_flow_ops.PriorityQueue(2000, (dtypes.string, dtypes.string), ( 273 (), ())) 274 elem = np.random.randint(-100, 100, size=1000).astype(np.int64) 275 side_value_0 = np.random.rand(1000).astype(bytes) 276 side_value_1 = np.random.rand(1000).astype(bytes) 277 q.enqueue_many((elem, side_value_0, side_value_1)).run() 278 deq = q.dequeue_many(1000) 279 deq_elem, deq_value_0, deq_value_1 = sess.run(deq) 280 281 allowed = {} 282 for e, v0, v1 in zip(elem, side_value_0, side_value_1): 283 if e not in allowed: 284 allowed[e] = set() 285 allowed[e].add((v0, v1)) 286 287 self.assertAllEqual(deq_elem, sorted(elem)) 288 for e, dv0, dv1 in zip(deq_elem, deq_value_0, deq_value_1): 289 self.assertTrue((dv0, dv1) in allowed[e]) 290 291 def testRoundTripInsertOnceReadManySorts(self): 292 with self.test_session(): 293 q = data_flow_ops.PriorityQueue(2000, (dtypes.int64), (())) 294 elem = np.random.randint(-100, 100, size=1000).astype(np.int64) 295 q.enqueue_many((elem, elem)).run() 296 deq_values = np.hstack((q.dequeue_many(100)[0].eval() for _ in range(10))) 297 self.assertAllEqual(deq_values, sorted(elem)) 298 299 def testRoundTripInsertOnceReadOnceLotsSorts(self): 300 with self.test_session(): 301 q = data_flow_ops.PriorityQueue(2000, (dtypes.int64), (())) 302 elem = np.random.randint(-100, 100, size=1000).astype(np.int64) 303 q.enqueue_many((elem, elem)).run() 304 dequeue_op = q.dequeue() 305 deq_values = np.hstack(dequeue_op[0].eval() for _ in range(1000)) 306 self.assertAllEqual(deq_values, sorted(elem)) 307 308 def testInsertingNonInt64Fails(self): 309 with self.test_session(): 310 q = data_flow_ops.PriorityQueue(2000, (dtypes.string), (())) 311 with self.assertRaises(TypeError): 312 q.enqueue_many((["a", "b", "c"], ["a", "b", "c"])).run() 313 314 def testInsertingNonScalarFails(self): 315 with self.test_session() as sess: 316 input_priority = array_ops.placeholder(dtypes.int64) 317 input_other = array_ops.placeholder(dtypes.string) 318 q = data_flow_ops.PriorityQueue(2000, (dtypes.string,), (())) 319 320 with self.assertRaisesRegexp( 321 errors_impl.InvalidArgumentError, 322 r"Shape mismatch in tuple component 0. Expected \[\], got \[2\]"): 323 sess.run([q.enqueue((input_priority, input_other))], 324 feed_dict={ 325 input_priority: np.array( 326 [0, 2], dtype=np.int64), 327 input_other: np.random.rand(3, 5).astype(bytes) 328 }) 329 330 with self.assertRaisesRegexp( 331 errors_impl.InvalidArgumentError, 332 r"Shape mismatch in tuple component 0. Expected \[2\], got \[2,2\]"): 333 sess.run( 334 [q.enqueue_many((input_priority, input_other))], 335 feed_dict={ 336 input_priority: np.array( 337 [[0, 2], [3, 4]], dtype=np.int64), 338 input_other: np.random.rand(2, 3).astype(bytes) 339 }) 340 341 342 if __name__ == "__main__": 343 test.main() 344