Home | History | Annotate | Download | only in training
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for Coordinator."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import sys
     22 import threading
     23 import time
     24 
     25 from tensorflow.python.framework import errors_impl
     26 from tensorflow.python.platform import test
     27 from tensorflow.python.training import coordinator
     28 
     29 
     30 def StopOnEvent(coord, wait_for_stop, set_when_stopped):
     31   wait_for_stop.wait()
     32   coord.request_stop()
     33   set_when_stopped.set()
     34 
     35 
     36 def RaiseOnEvent(coord, wait_for_stop, set_when_stopped, ex, report_exception):
     37   try:
     38     wait_for_stop.wait()
     39     raise ex
     40   except RuntimeError as e:
     41     if report_exception:
     42       coord.request_stop(e)
     43     else:
     44       coord.request_stop(sys.exc_info())
     45   finally:
     46     if set_when_stopped:
     47       set_when_stopped.set()
     48 
     49 
     50 def RaiseOnEventUsingContextHandler(coord, wait_for_stop, set_when_stopped, ex):
     51   with coord.stop_on_exception():
     52     wait_for_stop.wait()
     53     raise ex
     54   if set_when_stopped:
     55     set_when_stopped.set()
     56 
     57 
     58 def SleepABit(n_secs, coord=None):
     59   if coord:
     60     coord.register_thread(threading.current_thread())
     61   time.sleep(n_secs)
     62 
     63 
     64 def WaitForThreadsToRegister(coord, num_threads):
     65   while True:
     66     with coord._lock:
     67       if len(coord._registered_threads) == num_threads:
     68         break
     69     time.sleep(0.001)
     70 
     71 
     72 class CoordinatorTest(test.TestCase):
     73 
     74   def testStopAPI(self):
     75     coord = coordinator.Coordinator()
     76     self.assertFalse(coord.should_stop())
     77     self.assertFalse(coord.wait_for_stop(0.01))
     78     coord.request_stop()
     79     self.assertTrue(coord.should_stop())
     80     self.assertTrue(coord.wait_for_stop(0.01))
     81 
     82   def testStopAsync(self):
     83     coord = coordinator.Coordinator()
     84     self.assertFalse(coord.should_stop())
     85     self.assertFalse(coord.wait_for_stop(0.1))
     86     wait_for_stop_ev = threading.Event()
     87     has_stopped_ev = threading.Event()
     88     t = threading.Thread(
     89         target=StopOnEvent, args=(coord, wait_for_stop_ev, has_stopped_ev))
     90     t.start()
     91     self.assertFalse(coord.should_stop())
     92     self.assertFalse(coord.wait_for_stop(0.01))
     93     wait_for_stop_ev.set()
     94     has_stopped_ev.wait()
     95     self.assertTrue(coord.wait_for_stop(0.05))
     96     self.assertTrue(coord.should_stop())
     97 
     98   def testJoin(self):
     99     coord = coordinator.Coordinator()
    100     threads = [
    101         threading.Thread(target=SleepABit, args=(0.01,)),
    102         threading.Thread(target=SleepABit, args=(0.02,)),
    103         threading.Thread(target=SleepABit, args=(0.01,))
    104     ]
    105     for t in threads:
    106       t.start()
    107     coord.join(threads)
    108     for t in threads:
    109       self.assertFalse(t.is_alive())
    110 
    111   def testJoinAllRegistered(self):
    112     coord = coordinator.Coordinator()
    113     threads = [
    114         threading.Thread(target=SleepABit, args=(0.01, coord)),
    115         threading.Thread(target=SleepABit, args=(0.02, coord)),
    116         threading.Thread(target=SleepABit, args=(0.01, coord))
    117     ]
    118     for t in threads:
    119       t.start()
    120     WaitForThreadsToRegister(coord, 3)
    121     coord.join()
    122     for t in threads:
    123       self.assertFalse(t.is_alive())
    124 
    125   def testJoinSomeRegistered(self):
    126     coord = coordinator.Coordinator()
    127     threads = [
    128         threading.Thread(target=SleepABit, args=(0.01, coord)),
    129         threading.Thread(target=SleepABit, args=(0.02,)),
    130         threading.Thread(target=SleepABit, args=(0.01, coord))
    131     ]
    132     for t in threads:
    133       t.start()
    134     WaitForThreadsToRegister(coord, 2)
    135     # threads[1] is not registered we must pass it in.
    136     coord.join(threads[1:1])
    137     for t in threads:
    138       self.assertFalse(t.is_alive())
    139 
    140   def testJoinGraceExpires(self):
    141 
    142     def TestWithGracePeriod(stop_grace_period):
    143       coord = coordinator.Coordinator()
    144       wait_for_stop_ev = threading.Event()
    145       has_stopped_ev = threading.Event()
    146       threads = [
    147           threading.Thread(
    148               target=StopOnEvent,
    149               args=(coord, wait_for_stop_ev, has_stopped_ev)),
    150           threading.Thread(target=SleepABit, args=(10.0,))
    151       ]
    152       for t in threads:
    153         t.daemon = True
    154         t.start()
    155       wait_for_stop_ev.set()
    156       has_stopped_ev.wait()
    157       with self.assertRaisesRegexp(RuntimeError, "threads still running"):
    158         coord.join(threads, stop_grace_period_secs=stop_grace_period)
    159 
    160     TestWithGracePeriod(1e-10)
    161     TestWithGracePeriod(0.002)
    162     TestWithGracePeriod(1.0)
    163 
    164   def testJoinWithoutGraceExpires(self):
    165     coord = coordinator.Coordinator()
    166     wait_for_stop_ev = threading.Event()
    167     has_stopped_ev = threading.Event()
    168     threads = [
    169         threading.Thread(
    170             target=StopOnEvent, args=(coord, wait_for_stop_ev, has_stopped_ev)),
    171         threading.Thread(target=SleepABit, args=(10.0,))
    172     ]
    173     for t in threads:
    174       t.daemon = True
    175       t.start()
    176     wait_for_stop_ev.set()
    177     has_stopped_ev.wait()
    178     coord.join(threads, stop_grace_period_secs=1., ignore_live_threads=True)
    179 
    180   def testJoinRaiseReportExcInfo(self):
    181     coord = coordinator.Coordinator()
    182     ev_1 = threading.Event()
    183     ev_2 = threading.Event()
    184     threads = [
    185         threading.Thread(
    186             target=RaiseOnEvent,
    187             args=(coord, ev_1, ev_2, RuntimeError("First"), False)),
    188         threading.Thread(
    189             target=RaiseOnEvent,
    190             args=(coord, ev_2, None, RuntimeError("Too late"), False))
    191     ]
    192     for t in threads:
    193       t.start()
    194 
    195     ev_1.set()
    196 
    197     with self.assertRaisesRegexp(RuntimeError, "First"):
    198       coord.join(threads)
    199 
    200   def testJoinRaiseReportException(self):
    201     coord = coordinator.Coordinator()
    202     ev_1 = threading.Event()
    203     ev_2 = threading.Event()
    204     threads = [
    205         threading.Thread(
    206             target=RaiseOnEvent,
    207             args=(coord, ev_1, ev_2, RuntimeError("First"), True)),
    208         threading.Thread(
    209             target=RaiseOnEvent,
    210             args=(coord, ev_2, None, RuntimeError("Too late"), True))
    211     ]
    212     for t in threads:
    213       t.start()
    214 
    215     ev_1.set()
    216     with self.assertRaisesRegexp(RuntimeError, "First"):
    217       coord.join(threads)
    218 
    219   def testJoinIgnoresOutOfRange(self):
    220     coord = coordinator.Coordinator()
    221     ev_1 = threading.Event()
    222     threads = [
    223         threading.Thread(
    224             target=RaiseOnEvent,
    225             args=(coord, ev_1, None,
    226                   errors_impl.OutOfRangeError(None, None, "First"), True))
    227     ]
    228     for t in threads:
    229       t.start()
    230 
    231     ev_1.set()
    232     coord.join(threads)
    233 
    234   def testJoinIgnoresMyExceptionType(self):
    235     coord = coordinator.Coordinator(clean_stop_exception_types=(ValueError,))
    236     ev_1 = threading.Event()
    237     threads = [
    238         threading.Thread(
    239             target=RaiseOnEvent,
    240             args=(coord, ev_1, None, ValueError("Clean stop"), True))
    241     ]
    242     for t in threads:
    243       t.start()
    244 
    245     ev_1.set()
    246     coord.join(threads)
    247 
    248   def testJoinRaiseReportExceptionUsingHandler(self):
    249     coord = coordinator.Coordinator()
    250     ev_1 = threading.Event()
    251     ev_2 = threading.Event()
    252     threads = [
    253         threading.Thread(
    254             target=RaiseOnEventUsingContextHandler,
    255             args=(coord, ev_1, ev_2, RuntimeError("First"))),
    256         threading.Thread(
    257             target=RaiseOnEventUsingContextHandler,
    258             args=(coord, ev_2, None, RuntimeError("Too late")))
    259     ]
    260     for t in threads:
    261       t.start()
    262 
    263     ev_1.set()
    264     with self.assertRaisesRegexp(RuntimeError, "First"):
    265       coord.join(threads)
    266 
    267   def testClearStopClearsExceptionToo(self):
    268     coord = coordinator.Coordinator()
    269     ev_1 = threading.Event()
    270     threads = [
    271         threading.Thread(
    272             target=RaiseOnEvent,
    273             args=(coord, ev_1, None, RuntimeError("First"), True)),
    274     ]
    275     for t in threads:
    276       t.start()
    277 
    278     with self.assertRaisesRegexp(RuntimeError, "First"):
    279       ev_1.set()
    280       coord.join(threads)
    281     coord.clear_stop()
    282     threads = [
    283         threading.Thread(
    284             target=RaiseOnEvent,
    285             args=(coord, ev_1, None, RuntimeError("Second"), True)),
    286     ]
    287     for t in threads:
    288       t.start()
    289     with self.assertRaisesRegexp(RuntimeError, "Second"):
    290       ev_1.set()
    291       coord.join(threads)
    292 
    293   def testRequestStopRaisesIfJoined(self):
    294     coord = coordinator.Coordinator()
    295     # Join the coordinator right away.
    296     coord.join([])
    297     reported = False
    298     with self.assertRaisesRegexp(RuntimeError, "Too late"):
    299       try:
    300         raise RuntimeError("Too late")
    301       except RuntimeError as e:
    302         reported = True
    303         coord.request_stop(e)
    304     self.assertTrue(reported)
    305     # If we clear_stop the exceptions are handled normally.
    306     coord.clear_stop()
    307     try:
    308       raise RuntimeError("After clear")
    309     except RuntimeError as e:
    310       coord.request_stop(e)
    311     with self.assertRaisesRegexp(RuntimeError, "After clear"):
    312       coord.join([])
    313 
    314   def testRequestStopRaisesIfJoined_ExcInfo(self):
    315     # Same as testRequestStopRaisesIfJoined but using syc.exc_info().
    316     coord = coordinator.Coordinator()
    317     # Join the coordinator right away.
    318     coord.join([])
    319     reported = False
    320     with self.assertRaisesRegexp(RuntimeError, "Too late"):
    321       try:
    322         raise RuntimeError("Too late")
    323       except RuntimeError:
    324         reported = True
    325         coord.request_stop(sys.exc_info())
    326     self.assertTrue(reported)
    327     # If we clear_stop the exceptions are handled normally.
    328     coord.clear_stop()
    329     try:
    330       raise RuntimeError("After clear")
    331     except RuntimeError:
    332       coord.request_stop(sys.exc_info())
    333     with self.assertRaisesRegexp(RuntimeError, "After clear"):
    334       coord.join([])
    335 
    336 
    337 def _StopAt0(coord, n):
    338   if n[0] == 0:
    339     coord.request_stop()
    340   else:
    341     n[0] -= 1
    342 
    343 
    344 class LooperTest(test.TestCase):
    345 
    346   def testTargetArgs(self):
    347     n = [3]
    348     coord = coordinator.Coordinator()
    349     thread = coordinator.LooperThread.loop(
    350         coord, 0, target=_StopAt0, args=(coord, n))
    351     coord.join([thread])
    352     self.assertEqual(0, n[0])
    353 
    354   def testTargetKwargs(self):
    355     n = [3]
    356     coord = coordinator.Coordinator()
    357     thread = coordinator.LooperThread.loop(
    358         coord, 0, target=_StopAt0, kwargs={
    359             "coord": coord,
    360             "n": n
    361         })
    362     coord.join([thread])
    363     self.assertEqual(0, n[0])
    364 
    365   def testTargetMixedArgs(self):
    366     n = [3]
    367     coord = coordinator.Coordinator()
    368     thread = coordinator.LooperThread.loop(
    369         coord, 0, target=_StopAt0, args=(coord,), kwargs={
    370             "n": n
    371         })
    372     coord.join([thread])
    373     self.assertEqual(0, n[0])
    374 
    375 
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
  test.main()
    378