Home | History | Annotate | Download | only in test
      1 import unittest
      2 from doctest import DocTestSuite
      3 from test import test_support
      4 import weakref
      5 import gc
      6 
      7 # Modules under test
      8 _thread = test_support.import_module('thread')
      9 threading = test_support.import_module('threading')
     10 import _threading_local
     11 
     12 
     13 class Weak(object):
     14     pass
     15 
     16 def target(local, weaklist):
     17     weak = Weak()
     18     local.weak = weak
     19     weaklist.append(weakref.ref(weak))
     20 
     21 class BaseLocalTest:
     22 
     23     def test_local_refs(self):
     24         self._local_refs(20)
     25         self._local_refs(50)
     26         self._local_refs(100)
     27 
     28     def _local_refs(self, n):
     29         local = self._local()
     30         weaklist = []
     31         for i in range(n):
     32             t = threading.Thread(target=target, args=(local, weaklist))
     33             t.start()
     34             t.join()
     35         del t
     36 
     37         gc.collect()
     38         self.assertEqual(len(weaklist), n)
     39 
     40         # XXX _threading_local keeps the local of the last stopped thread alive.
     41         deadlist = [weak for weak in weaklist if weak() is None]
     42         self.assertIn(len(deadlist), (n-1, n))
     43 
     44         # Assignment to the same thread local frees it sometimes (!)
     45         local.someothervar = None
     46         gc.collect()
     47         deadlist = [weak for weak in weaklist if weak() is None]
     48         self.assertIn(len(deadlist), (n-1, n), (n, len(deadlist)))
     49 
     50     def test_derived(self):
     51         # Issue 3088: if there is a threads switch inside the __init__
     52         # of a threading.local derived class, the per-thread dictionary
     53         # is created but not correctly set on the object.
     54         # The first member set may be bogus.
     55         import time
     56         class Local(self._local):
     57             def __init__(self):
     58                 time.sleep(0.01)
     59         local = Local()
     60 
     61         def f(i):
     62             local.x = i
     63             # Simply check that the variable is correctly set
     64             self.assertEqual(local.x, i)
     65 
     66         threads= []
     67         for i in range(10):
     68             t = threading.Thread(target=f, args=(i,))
     69             t.start()
     70             threads.append(t)
     71 
     72         for t in threads:
     73             t.join()
     74 
     75     def test_derived_cycle_dealloc(self):
     76         # http://bugs.python.org/issue6990
     77         class Local(self._local):
     78             pass
     79         locals = None
     80         passed = [False]
     81         e1 = threading.Event()
     82         e2 = threading.Event()
     83 
     84         def f():
     85             # 1) Involve Local in a cycle
     86             cycle = [Local()]
     87             cycle.append(cycle)
     88             cycle[0].foo = 'bar'
     89 
     90             # 2) GC the cycle (triggers threadmodule.c::local_clear
     91             # before local_dealloc)
     92             del cycle
     93             gc.collect()
     94             e1.set()
     95             e2.wait()
     96 
     97             # 4) New Locals should be empty
     98             passed[0] = all(not hasattr(local, 'foo') for local in locals)
     99 
    100         t = threading.Thread(target=f)
    101         t.start()
    102         e1.wait()
    103 
    104         # 3) New Locals should recycle the original's address. Creating
    105         # them in the thread overwrites the thread state and avoids the
    106         # bug
    107         locals = [Local() for i in range(10)]
    108         e2.set()
    109         t.join()
    110 
    111         self.assertTrue(passed[0])
    112 
    113     def test_arguments(self):
    114         # Issue 1522237
    115         from thread import _local as local
    116         from _threading_local import local as py_local
    117 
    118         for cls in (local, py_local):
    119             class MyLocal(cls):
    120                 def __init__(self, *args, **kwargs):
    121                     pass
    122 
    123             MyLocal(a=1)
    124             MyLocal(1)
    125             self.assertRaises(TypeError, cls, a=1)
    126             self.assertRaises(TypeError, cls, 1)
    127 
    128     def _test_one_class(self, c):
    129         self._failed = "No error message set or cleared."
    130         obj = c()
    131         e1 = threading.Event()
    132         e2 = threading.Event()
    133 
    134         def f1():
    135             obj.x = 'foo'
    136             obj.y = 'bar'
    137             del obj.y
    138             e1.set()
    139             e2.wait()
    140 
    141         def f2():
    142             try:
    143                 foo = obj.x
    144             except AttributeError:
    145                 # This is expected -- we haven't set obj.x in this thread yet!
    146                 self._failed = ""  # passed
    147             else:
    148                 self._failed = ('Incorrectly got value %r from class %r\n' %
    149                                 (foo, c))
    150                 sys.stderr.write(self._failed)
    151 
    152         t1 = threading.Thread(target=f1)
    153         t1.start()
    154         e1.wait()
    155         t2 = threading.Thread(target=f2)
    156         t2.start()
    157         t2.join()
    158         # The test is done; just let t1 know it can exit, and wait for it.
    159         e2.set()
    160         t1.join()
    161 
    162         self.assertFalse(self._failed, self._failed)
    163 
    164     def test_threading_local(self):
    165         self._test_one_class(self._local)
    166 
    167     def test_threading_local_subclass(self):
    168         class LocalSubclass(self._local):
    169             """To test that subclasses behave properly."""
    170         self._test_one_class(LocalSubclass)
    171 
    172     def _test_dict_attribute(self, cls):
    173         obj = cls()
    174         obj.x = 5
    175         self.assertEqual(obj.__dict__, {'x': 5})
    176         with self.assertRaises(AttributeError):
    177             obj.__dict__ = {}
    178         with self.assertRaises(AttributeError):
    179             del obj.__dict__
    180 
    181     def test_dict_attribute(self):
    182         self._test_dict_attribute(self._local)
    183 
    184     def test_dict_attribute_subclass(self):
    185         class LocalSubclass(self._local):
    186             """To test that subclasses behave properly."""
    187         self._test_dict_attribute(LocalSubclass)
    188 
    189 
    190 class ThreadLocalTest(unittest.TestCase, BaseLocalTest):
    191     _local = _thread._local
    192 
    193     # Fails for the pure Python implementation
    194     def test_cycle_collection(self):
    195         class X:
    196             pass
    197 
    198         x = X()
    199         x.local = self._local()
    200         x.local.x = x
    201         wr = weakref.ref(x)
    202         del x
    203         gc.collect()
    204         self.assertIs(wr(), None)
    205 
    206 class PyThreadingLocalTest(unittest.TestCase, BaseLocalTest):
    207     _local = _threading_local.local
    208 
    209 
    210 def test_main():
    211     suite = unittest.TestSuite()
    212     suite.addTest(DocTestSuite('_threading_local'))
    213     suite.addTest(unittest.makeSuite(ThreadLocalTest))
    214     suite.addTest(unittest.makeSuite(PyThreadingLocalTest))
    215 
    216     try:
    217         from thread import _local
    218     except ImportError:
    219         pass
    220     else:
    221         import _threading_local
    222         local_orig = _threading_local.local
    223         def setUp(test):
    224             _threading_local.local = _local
    225         def tearDown(test):
    226             _threading_local.local = local_orig
    227         suite.addTest(DocTestSuite('_threading_local',
    228                                    setUp=setUp, tearDown=tearDown)
    229                       )
    230 
    231     test_support.run_unittest(suite)
    232 
    233 if __name__ == '__main__':
    234     test_main()
    235