Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for tensorflow.ops.data_flow_ops.PriorityQueue."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import copy
     22 import random
     23 import threading
     24 
     25 import numpy as np
     26 
     27 from tensorflow.python.framework import constant_op
     28 from tensorflow.python.framework import dtypes
     29 from tensorflow.python.framework import errors_impl
     30 from tensorflow.python.ops import array_ops
     31 from tensorflow.python.ops import data_flow_ops
     32 import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
     33 from tensorflow.python.platform import test
     34 
     35 
     36 class PriorityQueueTest(test.TestCase):
     37 
     38   def testRoundTripInsertReadOnceSorts(self):
     39     with self.test_session() as sess:
     40       q = data_flow_ops.PriorityQueue(2000, (dtypes.string, dtypes.string), (
     41           (), ()))
     42       elem = np.random.randint(-5, 5, size=100).astype(np.int64)
     43       side_value_0 = np.random.rand(100).astype(bytes)
     44       side_value_1 = np.random.rand(100).astype(bytes)
     45       enq_list = [
     46           q.enqueue((e, constant_op.constant(v0), constant_op.constant(v1)))
     47           for e, v0, v1 in zip(elem, side_value_0, side_value_1)
     48       ]
     49       for enq in enq_list:
     50         enq.run()
     51 
     52       deq = q.dequeue_many(100)
     53       deq_elem, deq_value_0, deq_value_1 = sess.run(deq)
     54 
     55       allowed = {}
     56       missed = set()
     57       for e, v0, v1 in zip(elem, side_value_0, side_value_1):
     58         if e not in allowed:
     59           allowed[e] = set()
     60         allowed[e].add((v0, v1))
     61         missed.add((v0, v1))
     62 
     63       self.assertAllEqual(deq_elem, sorted(elem))
     64       for e, dv0, dv1 in zip(deq_elem, deq_value_0, deq_value_1):
     65         self.assertTrue((dv0, dv1) in allowed[e])
     66         missed.remove((dv0, dv1))
     67       self.assertEqual(missed, set())
     68 
     69   def testRoundTripInsertMultiThreadedReadOnceSorts(self):
     70     with self.test_session() as sess:
     71       q = data_flow_ops.PriorityQueue(2000, (dtypes.string, dtypes.string), (
     72           (), ()))
     73       elem = np.random.randint(-5, 5, size=100).astype(np.int64)
     74       side_value_0 = np.random.rand(100).astype(bytes)
     75       side_value_1 = np.random.rand(100).astype(bytes)
     76 
     77       enqueue_ops = [
     78           q.enqueue((e, constant_op.constant(v0), constant_op.constant(v1)))
     79           for e, v0, v1 in zip(elem, side_value_0, side_value_1)
     80       ]
     81 
     82       # Run one producer thread for each element in elems.
     83       def enqueue(enqueue_op):
     84         sess.run(enqueue_op)
     85 
     86       dequeue_op = q.dequeue_many(100)
     87 
     88       enqueue_threads = [
     89           self.checkedThread(
     90               target=enqueue, args=(op,)) for op in enqueue_ops
     91       ]
     92 
     93       for t in enqueue_threads:
     94         t.start()
     95 
     96       deq_elem, deq_value_0, deq_value_1 = sess.run(dequeue_op)
     97 
     98       for t in enqueue_threads:
     99         t.join()
    100 
    101       allowed = {}
    102       missed = set()
    103       for e, v0, v1 in zip(elem, side_value_0, side_value_1):
    104         if e not in allowed:
    105           allowed[e] = set()
    106         allowed[e].add((v0, v1))
    107         missed.add((v0, v1))
    108 
    109       self.assertAllEqual(deq_elem, sorted(elem))
    110       for e, dv0, dv1 in zip(deq_elem, deq_value_0, deq_value_1):
    111         self.assertTrue((dv0, dv1) in allowed[e])
    112         missed.remove((dv0, dv1))
    113       self.assertEqual(missed, set())
    114 
    115   def testRoundTripFillsCapacityMultiThreadedEnqueueAndDequeue(self):
    116     with self.test_session() as sess:
    117       q = data_flow_ops.PriorityQueue(10, (dtypes.int64), (()))
    118 
    119       num_threads = 40
    120       enqueue_counts = np.random.randint(10, size=num_threads)
    121       enqueue_values = [
    122           np.random.randint(
    123               5, size=count) for count in enqueue_counts
    124       ]
    125       enqueue_ops = [
    126           q.enqueue_many((values, values)) for values in enqueue_values
    127       ]
    128       shuffled_counts = copy.deepcopy(enqueue_counts)
    129       random.shuffle(shuffled_counts)
    130       dequeue_ops = [q.dequeue_many(count) for count in shuffled_counts]
    131       all_enqueued_values = np.hstack(enqueue_values)
    132 
    133       # Run one producer thread for each element in elems.
    134       def enqueue(enqueue_op):
    135         sess.run(enqueue_op)
    136 
    137       dequeued = []
    138 
    139       def dequeue(dequeue_op):
    140         (dequeue_indices, dequeue_values) = sess.run(dequeue_op)
    141         self.assertAllEqual(dequeue_indices, dequeue_values)
    142         dequeued.extend(dequeue_indices)
    143 
    144       enqueue_threads = [
    145           self.checkedThread(
    146               target=enqueue, args=(op,)) for op in enqueue_ops
    147       ]
    148       dequeue_threads = [
    149           self.checkedThread(
    150               target=dequeue, args=(op,)) for op in dequeue_ops
    151       ]
    152 
    153       # Dequeue and check
    154       for t in dequeue_threads:
    155         t.start()
    156       for t in enqueue_threads:
    157         t.start()
    158       for t in enqueue_threads:
    159         t.join()
    160       for t in dequeue_threads:
    161         t.join()
    162 
    163       self.assertAllEqual(sorted(dequeued), sorted(all_enqueued_values))
    164 
    165   def testRoundTripInsertManyMultiThreadedReadManyMultithreadedSorts(self):
    166     with self.test_session() as sess:
    167       q = data_flow_ops.PriorityQueue(2000, (dtypes.int64), (()))
    168 
    169       num_threads = 40
    170       enqueue_counts = np.random.randint(10, size=num_threads)
    171       enqueue_values = [
    172           np.random.randint(
    173               5, size=count) for count in enqueue_counts
    174       ]
    175       enqueue_ops = [
    176           q.enqueue_many((values, values)) for values in enqueue_values
    177       ]
    178       shuffled_counts = copy.deepcopy(enqueue_counts)
    179       random.shuffle(shuffled_counts)
    180       dequeue_ops = [q.dequeue_many(count) for count in shuffled_counts]
    181       all_enqueued_values = np.hstack(enqueue_values)
    182 
    183       dequeue_wait = threading.Condition()
    184 
    185       # Run one producer thread for each element in elems.
    186       def enqueue(enqueue_op):
    187         sess.run(enqueue_op)
    188 
    189       def dequeue(dequeue_op, dequeued):
    190         (dequeue_indices, dequeue_values) = sess.run(dequeue_op)
    191         self.assertAllEqual(dequeue_indices, dequeue_values)
    192         dequeue_wait.acquire()
    193         dequeued.extend(dequeue_indices)
    194         dequeue_wait.release()
    195 
    196       dequeued = []
    197       enqueue_threads = [
    198           self.checkedThread(
    199               target=enqueue, args=(op,)) for op in enqueue_ops
    200       ]
    201       dequeue_threads = [
    202           self.checkedThread(
    203               target=dequeue, args=(op, dequeued)) for op in dequeue_ops
    204       ]
    205 
    206       for t in enqueue_threads:
    207         t.start()
    208       for t in enqueue_threads:
    209         t.join()
    210       # Dequeue and check
    211       for t in dequeue_threads:
    212         t.start()
    213       for t in dequeue_threads:
    214         t.join()
    215 
    216       # We can't guarantee full sorting because we can't guarantee
    217       # that the dequeued.extend() call runs immediately after the
    218       # sess.run() call.  Here we're just happy everything came out.
    219       self.assertAllEqual(set(dequeued), set(all_enqueued_values))
    220 
    221   def testRoundTripInsertManyMultiThreadedReadOnceSorts(self):
    222     with self.test_session() as sess:
    223       q = data_flow_ops.PriorityQueue(2000, (dtypes.string, dtypes.string), (
    224           (), ()))
    225       elem = np.random.randint(-5, 5, size=100).astype(np.int64)
    226       side_value_0 = np.random.rand(100).astype(bytes)
    227       side_value_1 = np.random.rand(100).astype(bytes)
    228 
    229       batch = 5
    230       enqueue_ops = [
    231           q.enqueue_many((elem[i * batch:(i + 1) * batch],
    232                           side_value_0[i * batch:(i + 1) * batch],
    233                           side_value_1[i * batch:(i + 1) * batch]))
    234           for i in range(20)
    235       ]
    236 
    237       # Run one producer thread for each element in elems.
    238       def enqueue(enqueue_op):
    239         sess.run(enqueue_op)
    240 
    241       dequeue_op = q.dequeue_many(100)
    242 
    243       enqueue_threads = [
    244           self.checkedThread(
    245               target=enqueue, args=(op,)) for op in enqueue_ops
    246       ]
    247 
    248       for t in enqueue_threads:
    249         t.start()
    250 
    251       deq_elem, deq_value_0, deq_value_1 = sess.run(dequeue_op)
    252 
    253       for t in enqueue_threads:
    254         t.join()
    255 
    256       allowed = {}
    257       missed = set()
    258       for e, v0, v1 in zip(elem, side_value_0, side_value_1):
    259         if e not in allowed:
    260           allowed[e] = set()
    261         allowed[e].add((v0, v1))
    262         missed.add((v0, v1))
    263 
    264       self.assertAllEqual(deq_elem, sorted(elem))
    265       for e, dv0, dv1 in zip(deq_elem, deq_value_0, deq_value_1):
    266         self.assertTrue((dv0, dv1) in allowed[e])
    267         missed.remove((dv0, dv1))
    268       self.assertEqual(missed, set())
    269 
    270   def testRoundTripInsertOnceReadOnceSorts(self):
    271     with self.test_session() as sess:
    272       q = data_flow_ops.PriorityQueue(2000, (dtypes.string, dtypes.string), (
    273           (), ()))
    274       elem = np.random.randint(-100, 100, size=1000).astype(np.int64)
    275       side_value_0 = np.random.rand(1000).astype(bytes)
    276       side_value_1 = np.random.rand(1000).astype(bytes)
    277       q.enqueue_many((elem, side_value_0, side_value_1)).run()
    278       deq = q.dequeue_many(1000)
    279       deq_elem, deq_value_0, deq_value_1 = sess.run(deq)
    280 
    281       allowed = {}
    282       for e, v0, v1 in zip(elem, side_value_0, side_value_1):
    283         if e not in allowed:
    284           allowed[e] = set()
    285         allowed[e].add((v0, v1))
    286 
    287       self.assertAllEqual(deq_elem, sorted(elem))
    288       for e, dv0, dv1 in zip(deq_elem, deq_value_0, deq_value_1):
    289         self.assertTrue((dv0, dv1) in allowed[e])
    290 
    291   def testRoundTripInsertOnceReadManySorts(self):
    292     with self.test_session():
    293       q = data_flow_ops.PriorityQueue(2000, (dtypes.int64), (()))
    294       elem = np.random.randint(-100, 100, size=1000).astype(np.int64)
    295       q.enqueue_many((elem, elem)).run()
    296       deq_values = np.hstack((q.dequeue_many(100)[0].eval() for _ in range(10)))
    297       self.assertAllEqual(deq_values, sorted(elem))
    298 
    299   def testRoundTripInsertOnceReadOnceLotsSorts(self):
    300     with self.test_session():
    301       q = data_flow_ops.PriorityQueue(2000, (dtypes.int64), (()))
    302       elem = np.random.randint(-100, 100, size=1000).astype(np.int64)
    303       q.enqueue_many((elem, elem)).run()
    304       dequeue_op = q.dequeue()
    305       deq_values = np.hstack(dequeue_op[0].eval() for _ in range(1000))
    306       self.assertAllEqual(deq_values, sorted(elem))
    307 
    308   def testInsertingNonInt64Fails(self):
    309     with self.test_session():
    310       q = data_flow_ops.PriorityQueue(2000, (dtypes.string), (()))
    311       with self.assertRaises(TypeError):
    312         q.enqueue_many((["a", "b", "c"], ["a", "b", "c"])).run()
    313 
    314   def testInsertingNonScalarFails(self):
    315     with self.test_session() as sess:
    316       input_priority = array_ops.placeholder(dtypes.int64)
    317       input_other = array_ops.placeholder(dtypes.string)
    318       q = data_flow_ops.PriorityQueue(2000, (dtypes.string,), (()))
    319 
    320       with self.assertRaisesRegexp(
    321           errors_impl.InvalidArgumentError,
    322           r"Shape mismatch in tuple component 0. Expected \[\], got \[2\]"):
    323         sess.run([q.enqueue((input_priority, input_other))],
    324                  feed_dict={
    325                      input_priority: np.array(
    326                          [0, 2], dtype=np.int64),
    327                      input_other: np.random.rand(3, 5).astype(bytes)
    328                  })
    329 
    330       with self.assertRaisesRegexp(
    331           errors_impl.InvalidArgumentError,
    332           r"Shape mismatch in tuple component 0. Expected \[2\], got \[2,2\]"):
    333         sess.run(
    334             [q.enqueue_many((input_priority, input_other))],
    335             feed_dict={
    336                 input_priority: np.array(
    337                     [[0, 2], [3, 4]], dtype=np.int64),
    338                 input_other: np.random.rand(2, 3).astype(bytes)
    339             })
    340 
    341 
    342 if __name__ == "__main__":
    343   test.main()
    344