Home | History | Annotate | Download | only in test
      1 import unittest
      2 from doctest import DocTestSuite
      3 from test import test_support as support
      4 import weakref
      5 import gc
      6 
      7 # Modules under test
      8 _thread = support.import_module('thread')
      9 threading = support.import_module('threading')
     10 import _threading_local
     11 
     12 
     13 class Weak(object):
     14     pass
     15 
     16 def target(local, weaklist):
     17     weak = Weak()
     18     local.weak = weak
     19     weaklist.append(weakref.ref(weak))
     20 
     21 class BaseLocalTest:
     22 
     23     def test_local_refs(self):
     24         self._local_refs(20)
     25         self._local_refs(50)
     26         self._local_refs(100)
     27 
     28     def _local_refs(self, n):
     29         local = self._local()
     30         weaklist = []
     31         for i in range(n):
     32             t = threading.Thread(target=target, args=(local, weaklist))
     33             t.start()
     34             t.join()
     35         del t
     36 
     37         gc.collect()
     38         self.assertEqual(len(weaklist), n)
     39 
     40         # XXX _threading_local keeps the local of the last stopped thread alive.
     41         deadlist = [weak for weak in weaklist if weak() is None]
     42         self.assertIn(len(deadlist), (n-1, n))
     43 
     44         # Assignment to the same thread local frees it sometimes (!)
     45         local.someothervar = None
     46         gc.collect()
     47         deadlist = [weak for weak in weaklist if weak() is None]
     48         self.assertIn(len(deadlist), (n-1, n), (n, len(deadlist)))
     49 
     50     def test_derived(self):
     51         # Issue 3088: if there is a threads switch inside the __init__
     52         # of a threading.local derived class, the per-thread dictionary
     53         # is created but not correctly set on the object.
     54         # The first member set may be bogus.
     55         import time
     56         class Local(self._local):
     57             def __init__(self):
     58                 time.sleep(0.01)
     59         local = Local()
     60 
     61         def f(i):
     62             local.x = i
     63             # Simply check that the variable is correctly set
     64             self.assertEqual(local.x, i)
     65 
     66         with support.start_threads(threading.Thread(target=f, args=(i,))
     67                                    for i in range(10)):
     68             pass
     69 
     70     def test_derived_cycle_dealloc(self):
     71         # http://bugs.python.org/issue6990
     72         class Local(self._local):
     73             pass
     74         locals = None
     75         passed = [False]
     76         e1 = threading.Event()
     77         e2 = threading.Event()
     78 
     79         def f():
     80             # 1) Involve Local in a cycle
     81             cycle = [Local()]
     82             cycle.append(cycle)
     83             cycle[0].foo = 'bar'
     84 
     85             # 2) GC the cycle (triggers threadmodule.c::local_clear
     86             # before local_dealloc)
     87             del cycle
     88             gc.collect()
     89             e1.set()
     90             e2.wait()
     91 
     92             # 4) New Locals should be empty
     93             passed[0] = all(not hasattr(local, 'foo') for local in locals)
     94 
     95         t = threading.Thread(target=f)
     96         t.start()
     97         e1.wait()
     98 
     99         # 3) New Locals should recycle the original's address. Creating
    100         # them in the thread overwrites the thread state and avoids the
    101         # bug
    102         locals = [Local() for i in range(10)]
    103         e2.set()
    104         t.join()
    105 
    106         self.assertTrue(passed[0])
    107 
    108     def test_arguments(self):
    109         # Issue 1522237
    110         from thread import _local as local
    111         from _threading_local import local as py_local
    112 
    113         for cls in (local, py_local):
    114             class MyLocal(cls):
    115                 def __init__(self, *args, **kwargs):
    116                     pass
    117 
    118             MyLocal(a=1)
    119             MyLocal(1)
    120             self.assertRaises(TypeError, cls, a=1)
    121             self.assertRaises(TypeError, cls, 1)
    122 
    123     def _test_one_class(self, c):
    124         self._failed = "No error message set or cleared."
    125         obj = c()
    126         e1 = threading.Event()
    127         e2 = threading.Event()
    128 
    129         def f1():
    130             obj.x = 'foo'
    131             obj.y = 'bar'
    132             del obj.y
    133             e1.set()
    134             e2.wait()
    135 
    136         def f2():
    137             try:
    138                 foo = obj.x
    139             except AttributeError:
    140                 # This is expected -- we haven't set obj.x in this thread yet!
    141                 self._failed = ""  # passed
    142             else:
    143                 self._failed = ('Incorrectly got value %r from class %r\n' %
    144                                 (foo, c))
    145                 sys.stderr.write(self._failed)
    146 
    147         t1 = threading.Thread(target=f1)
    148         t1.start()
    149         e1.wait()
    150         t2 = threading.Thread(target=f2)
    151         t2.start()
    152         t2.join()
    153         # The test is done; just let t1 know it can exit, and wait for it.
    154         e2.set()
    155         t1.join()
    156 
    157         self.assertFalse(self._failed, self._failed)
    158 
    159     def test_threading_local(self):
    160         self._test_one_class(self._local)
    161 
    162     def test_threading_local_subclass(self):
    163         class LocalSubclass(self._local):
    164             """To test that subclasses behave properly."""
    165         self._test_one_class(LocalSubclass)
    166 
    167     def _test_dict_attribute(self, cls):
    168         obj = cls()
    169         obj.x = 5
    170         self.assertEqual(obj.__dict__, {'x': 5})
    171         with self.assertRaises(AttributeError):
    172             obj.__dict__ = {}
    173         with self.assertRaises(AttributeError):
    174             del obj.__dict__
    175 
    176     def test_dict_attribute(self):
    177         self._test_dict_attribute(self._local)
    178 
    179     def test_dict_attribute_subclass(self):
    180         class LocalSubclass(self._local):
    181             """To test that subclasses behave properly."""
    182         self._test_dict_attribute(LocalSubclass)
    183 
    184 
    185 class ThreadLocalTest(unittest.TestCase, BaseLocalTest):
    186     _local = _thread._local
    187 
    188     # Fails for the pure Python implementation
    189     def test_cycle_collection(self):
    190         class X:
    191             pass
    192 
    193         x = X()
    194         x.local = self._local()
    195         x.local.x = x
    196         wr = weakref.ref(x)
    197         del x
    198         gc.collect()
    199         self.assertIsNone(wr())
    200 
# Runs the tests inherited from BaseLocalTest against the local class
# provided by the _threading_local module.
class PyThreadingLocalTest(unittest.TestCase, BaseLocalTest):
    _local = _threading_local.local
    203 
    204 
    205 def test_main():
    206     suite = unittest.TestSuite()
    207     suite.addTest(DocTestSuite('_threading_local'))
    208     suite.addTest(unittest.makeSuite(ThreadLocalTest))
    209     suite.addTest(unittest.makeSuite(PyThreadingLocalTest))
    210 
    211     try:
    212         from thread import _local
    213     except ImportError:
    214         pass
    215     else:
    216         import _threading_local
    217         local_orig = _threading_local.local
    218         def setUp(test):
    219             _threading_local.local = _local
    220         def tearDown(test):
    221             _threading_local.local = local_orig
    222         suite.addTest(DocTestSuite('_threading_local',
    223                                    setUp=setUp, tearDown=tearDown)
    224                       )
    225 
    226     support.run_unittest(suite)
    227 
# Run the whole suite when this file is executed as a script.
if __name__ == '__main__':
    test_main()
    230