Home | History | Annotate | Download | only in test
      1 """TestCases for multi-threaded access to a DB.
      2 """
      3 
      4 import os
      5 import sys
      6 import time
      7 import errno
      8 from random import random
      9 
     10 DASH = '-'
     11 
     12 try:
     13     WindowsError
     14 except NameError:
     15     class WindowsError(Exception):
     16         pass
     17 
     18 import unittest
     19 from test_all import db, dbutils, test_support, verbose, have_threads, \
     20         get_new_environment_path, get_new_database_path
     21 
     22 if have_threads :
     23     from threading import Thread
     24     if sys.version_info[0] < 3 :
     25         from threading import currentThread
     26     else :
     27         from threading import current_thread as currentThread
     28 
     29 
     30 #----------------------------------------------------------------------
     31 
     32 class BaseThreadedTestCase(unittest.TestCase):
     33     dbtype       = db.DB_UNKNOWN  # must be set in derived class
     34     dbopenflags  = 0
     35     dbsetflags   = 0
     36     envflags     = 0
     37 
     38     def setUp(self):
     39         if verbose:
     40             dbutils._deadlock_VerboseFile = sys.stdout
     41 
     42         self.homeDir = get_new_environment_path()
     43         self.env = db.DBEnv()
     44         self.setEnvOpts()
     45         self.env.open(self.homeDir, self.envflags | db.DB_CREATE)
     46 
     47         self.filename = self.__class__.__name__ + '.db'
     48         self.d = db.DB(self.env)
     49         if self.dbsetflags:
     50             self.d.set_flags(self.dbsetflags)
     51         self.d.open(self.filename, self.dbtype, self.dbopenflags|db.DB_CREATE)
     52 
     53     def tearDown(self):
     54         self.d.close()
     55         self.env.close()
     56         test_support.rmtree(self.homeDir)
     57 
     58     def setEnvOpts(self):
     59         pass
     60 
     61     def makeData(self, key):
     62         return DASH.join([key] * 5)
     63 
     64 
     65 #----------------------------------------------------------------------
     66 
     67 
     68 class ConcurrentDataStoreBase(BaseThreadedTestCase):
     69     dbopenflags = db.DB_THREAD
     70     envflags    = db.DB_THREAD | db.DB_INIT_CDB | db.DB_INIT_MPOOL
     71     readers     = 0 # derived class should set
     72     writers     = 0
     73     records     = 1000
     74 
     75     def test01_1WriterMultiReaders(self):
     76         if verbose:
     77             print '\n', '-=' * 30
     78             print "Running %s.test01_1WriterMultiReaders..." % \
     79                   self.__class__.__name__
     80 
     81         keys=range(self.records)
     82         import random
     83         random.shuffle(keys)
     84         records_per_writer=self.records//self.writers
     85         readers_per_writer=self.readers//self.writers
     86         self.assertEqual(self.records,self.writers*records_per_writer)
     87         self.assertEqual(self.readers,self.writers*readers_per_writer)
     88         self.assertTrue((records_per_writer%readers_per_writer)==0)
     89         readers = []
     90 
     91         for x in xrange(self.readers):
     92             rt = Thread(target = self.readerThread,
     93                         args = (self.d, x),
     94                         name = 'reader %d' % x,
     95                         )#verbose = verbose)
     96             if sys.version_info[0] < 3 :
     97                 rt.setDaemon(True)
     98             else :
     99                 rt.daemon = True
    100             readers.append(rt)
    101 
    102         writers=[]
    103         for x in xrange(self.writers):
    104             a=keys[records_per_writer*x:records_per_writer*(x+1)]
    105             a.sort()  # Generate conflicts
    106             b=readers[readers_per_writer*x:readers_per_writer*(x+1)]
    107             wt = Thread(target = self.writerThread,
    108                         args = (self.d, a, b),
    109                         name = 'writer %d' % x,
    110                         )#verbose = verbose)
    111             writers.append(wt)
    112 
    113         for t in writers:
    114             if sys.version_info[0] < 3 :
    115                 t.setDaemon(True)
    116             else :
    117                 t.daemon = True
    118             t.start()
    119 
    120         for t in writers:
    121             t.join()
    122         for t in readers:
    123             t.join()
    124 
    125     def writerThread(self, d, keys, readers):
    126         if sys.version_info[0] < 3 :
    127             name = currentThread().getName()
    128         else :
    129             name = currentThread().name
    130 
    131         if verbose:
    132             print "%s: creating records %d - %d" % (name, start, stop)
    133 
    134         count=len(keys)//len(readers)
    135         count2=count
    136         for x in keys :
    137             key = '%04d' % x
    138             dbutils.DeadlockWrap(d.put, key, self.makeData(key),
    139                                  max_retries=12)
    140             if verbose and x % 100 == 0:
    141                 print "%s: records %d - %d finished" % (name, start, x)
    142 
    143             count2-=1
    144             if not count2 :
    145                 readers.pop().start()
    146                 count2=count
    147 
    148         if verbose:
    149             print "%s: finished creating records" % name
    150 
    151         if verbose:
    152             print "%s: thread finished" % name
    153 
    154     def readerThread(self, d, readerNum):
    155         if sys.version_info[0] < 3 :
    156             name = currentThread().getName()
    157         else :
    158             name = currentThread().name
    159 
    160         for i in xrange(5) :
    161             c = d.cursor()
    162             count = 0
    163             rec = c.first()
    164             while rec:
    165                 count += 1
    166                 key, data = rec
    167                 self.assertEqual(self.makeData(key), data)
    168                 rec = c.next()
    169             if verbose:
    170                 print "%s: found %d records" % (name, count)
    171             c.close()
    172 
    173         if verbose:
    174             print "%s: thread finished" % name
    175 
    176 
    177 class BTreeConcurrentDataStore(ConcurrentDataStoreBase):
    178     dbtype  = db.DB_BTREE
    179     writers = 2
    180     readers = 10
    181     records = 1000
    182 
    183 
    184 class HashConcurrentDataStore(ConcurrentDataStoreBase):
    185     dbtype  = db.DB_HASH
    186     writers = 2
    187     readers = 10
    188     records = 1000
    189 
    190 
    191 #----------------------------------------------------------------------
    192 
    193 class SimpleThreadedBase(BaseThreadedTestCase):
    194     dbopenflags = db.DB_THREAD
    195     envflags    = db.DB_THREAD | db.DB_INIT_MPOOL | db.DB_INIT_LOCK
    196     readers = 10
    197     writers = 2
    198     records = 1000
    199 
    200     def setEnvOpts(self):
    201         self.env.set_lk_detect(db.DB_LOCK_DEFAULT)
    202 
    203     def test02_SimpleLocks(self):
    204         if verbose:
    205             print '\n', '-=' * 30
    206             print "Running %s.test02_SimpleLocks..." % self.__class__.__name__
    207 
    208 
    209         keys=range(self.records)
    210         import random
    211         random.shuffle(keys)
    212         records_per_writer=self.records//self.writers
    213         readers_per_writer=self.readers//self.writers
    214         self.assertEqual(self.records,self.writers*records_per_writer)
    215         self.assertEqual(self.readers,self.writers*readers_per_writer)
    216         self.assertTrue((records_per_writer%readers_per_writer)==0)
    217 
    218         readers = []
    219         for x in xrange(self.readers):
    220             rt = Thread(target = self.readerThread,
    221                         args = (self.d, x),
    222                         name = 'reader %d' % x,
    223                         )#verbose = verbose)
    224             if sys.version_info[0] < 3 :
    225                 rt.setDaemon(True)
    226             else :
    227                 rt.daemon = True
    228             readers.append(rt)
    229 
    230         writers = []
    231         for x in xrange(self.writers):
    232             a=keys[records_per_writer*x:records_per_writer*(x+1)]
    233             a.sort()  # Generate conflicts
    234             b=readers[readers_per_writer*x:readers_per_writer*(x+1)]
    235             wt = Thread(target = self.writerThread,
    236                         args = (self.d, a, b),
    237                         name = 'writer %d' % x,
    238                         )#verbose = verbose)
    239             writers.append(wt)
    240 
    241         for t in writers:
    242             if sys.version_info[0] < 3 :
    243                 t.setDaemon(True)
    244             else :
    245                 t.daemon = True
    246             t.start()
    247 
    248         for t in writers:
    249             t.join()
    250         for t in readers:
    251             t.join()
    252 
    253     def writerThread(self, d, keys, readers):
    254         if sys.version_info[0] < 3 :
    255             name = currentThread().getName()
    256         else :
    257             name = currentThread().name
    258         if verbose:
    259             print "%s: creating records %d - %d" % (name, start, stop)
    260 
    261         count=len(keys)//len(readers)
    262         count2=count
    263         for x in keys :
    264             key = '%04d' % x
    265             dbutils.DeadlockWrap(d.put, key, self.makeData(key),
    266                                  max_retries=12)
    267 
    268             if verbose and x % 100 == 0:
    269                 print "%s: records %d - %d finished" % (name, start, x)
    270 
    271             count2-=1
    272             if not count2 :
    273                 readers.pop().start()
    274                 count2=count
    275 
    276         if verbose:
    277             print "%s: thread finished" % name
    278 
    279     def readerThread(self, d, readerNum):
    280         if sys.version_info[0] < 3 :
    281             name = currentThread().getName()
    282         else :
    283             name = currentThread().name
    284 
    285         c = d.cursor()
    286         count = 0
    287         rec = dbutils.DeadlockWrap(c.first, max_retries=10)
    288         while rec:
    289             count += 1
    290             key, data = rec
    291             self.assertEqual(self.makeData(key), data)
    292             rec = dbutils.DeadlockWrap(c.next, max_retries=10)
    293         if verbose:
    294             print "%s: found %d records" % (name, count)
    295         c.close()
    296 
    297         if verbose:
    298             print "%s: thread finished" % name
    299 
    300 
    301 class BTreeSimpleThreaded(SimpleThreadedBase):
    302     dbtype = db.DB_BTREE
    303 
    304 
    305 class HashSimpleThreaded(SimpleThreadedBase):
    306     dbtype = db.DB_HASH
    307 
    308 
    309 #----------------------------------------------------------------------
    310 
    311 
    312 class ThreadedTransactionsBase(BaseThreadedTestCase):
    313     dbopenflags = db.DB_THREAD | db.DB_AUTO_COMMIT
    314     envflags    = (db.DB_THREAD |
    315                    db.DB_INIT_MPOOL |
    316                    db.DB_INIT_LOCK |
    317                    db.DB_INIT_LOG |
    318                    db.DB_INIT_TXN
    319                    )
    320     readers = 0
    321     writers = 0
    322     records = 2000
    323     txnFlag = 0
    324 
    325     def setEnvOpts(self):
    326         #self.env.set_lk_detect(db.DB_LOCK_DEFAULT)
    327         pass
    328 
    329     def test03_ThreadedTransactions(self):
    330         if verbose:
    331             print '\n', '-=' * 30
    332             print "Running %s.test03_ThreadedTransactions..." % \
    333                   self.__class__.__name__
    334 
    335         keys=range(self.records)
    336         import random
    337         random.shuffle(keys)
    338         records_per_writer=self.records//self.writers
    339         readers_per_writer=self.readers//self.writers
    340         self.assertEqual(self.records,self.writers*records_per_writer)
    341         self.assertEqual(self.readers,self.writers*readers_per_writer)
    342         self.assertTrue((records_per_writer%readers_per_writer)==0)
    343 
    344         readers=[]
    345         for x in xrange(self.readers):
    346             rt = Thread(target = self.readerThread,
    347                         args = (self.d, x),
    348                         name = 'reader %d' % x,
    349                         )#verbose = verbose)
    350             if sys.version_info[0] < 3 :
    351                 rt.setDaemon(True)
    352             else :
    353                 rt.daemon = True
    354             readers.append(rt)
    355 
    356         writers = []
    357         for x in xrange(self.writers):
    358             a=keys[records_per_writer*x:records_per_writer*(x+1)]
    359             b=readers[readers_per_writer*x:readers_per_writer*(x+1)]
    360             wt = Thread(target = self.writerThread,
    361                         args = (self.d, a, b),
    362                         name = 'writer %d' % x,
    363                         )#verbose = verbose)
    364             writers.append(wt)
    365 
    366         dt = Thread(target = self.deadlockThread)
    367         if sys.version_info[0] < 3 :
    368             dt.setDaemon(True)
    369         else :
    370             dt.daemon = True
    371         dt.start()
    372 
    373         for t in writers:
    374             if sys.version_info[0] < 3 :
    375                 t.setDaemon(True)
    376             else :
    377                 t.daemon = True
    378             t.start()
    379 
    380         for t in writers:
    381             t.join()
    382         for t in readers:
    383             t.join()
    384 
    385         self.doLockDetect = False
    386         dt.join()
    387 
    388     def writerThread(self, d, keys, readers):
    389         if sys.version_info[0] < 3 :
    390             name = currentThread().getName()
    391         else :
    392             name = currentThread().name
    393 
    394         count=len(keys)//len(readers)
    395         while len(keys):
    396             try:
    397                 txn = self.env.txn_begin(None, self.txnFlag)
    398                 keys2=keys[:count]
    399                 for x in keys2 :
    400                     key = '%04d' % x
    401                     d.put(key, self.makeData(key), txn)
    402                     if verbose and x % 100 == 0:
    403                         print "%s: records %d - %d finished" % (name, start, x)
    404                 txn.commit()
    405                 keys=keys[count:]
    406                 readers.pop().start()
    407             except (db.DBLockDeadlockError, db.DBLockNotGrantedError), val:
    408                 if verbose:
    409                     if sys.version_info < (2, 6) :
    410                         print "%s: Aborting transaction (%s)" % (name, val[1])
    411                     else :
    412                         print "%s: Aborting transaction (%s)" % (name,
    413                                 val.args[1])
    414                 txn.abort()
    415 
    416         if verbose:
    417             print "%s: thread finished" % name
    418 
    419     def readerThread(self, d, readerNum):
    420         if sys.version_info[0] < 3 :
    421             name = currentThread().getName()
    422         else :
    423             name = currentThread().name
    424 
    425         finished = False
    426         while not finished:
    427             try:
    428                 txn = self.env.txn_begin(None, self.txnFlag)
    429                 c = d.cursor(txn)
    430                 count = 0
    431                 rec = c.first()
    432                 while rec:
    433                     count += 1
    434                     key, data = rec
    435                     self.assertEqual(self.makeData(key), data)
    436                     rec = c.next()
    437                 if verbose: print "%s: found %d records" % (name, count)
    438                 c.close()
    439                 txn.commit()
    440                 finished = True
    441             except (db.DBLockDeadlockError, db.DBLockNotGrantedError), val:
    442                 if verbose:
    443                     if sys.version_info < (2, 6) :
    444                         print "%s: Aborting transaction (%s)" % (name, val[1])
    445                     else :
    446                         print "%s: Aborting transaction (%s)" % (name,
    447                                 val.args[1])
    448                 c.close()
    449                 txn.abort()
    450 
    451         if verbose:
    452             print "%s: thread finished" % name
    453 
    454     def deadlockThread(self):
    455         self.doLockDetect = True
    456         while self.doLockDetect:
    457             time.sleep(0.05)
    458             try:
    459                 aborted = self.env.lock_detect(
    460                     db.DB_LOCK_RANDOM, db.DB_LOCK_CONFLICT)
    461                 if verbose and aborted:
    462                     print "deadlock: Aborted %d deadlocked transaction(s)" \
    463                           % aborted
    464             except db.DBError:
    465                 pass
    466 
    467 
    468 class BTreeThreadedTransactions(ThreadedTransactionsBase):
    469     dbtype = db.DB_BTREE
    470     writers = 2
    471     readers = 10
    472     records = 1000
    473 
    474 class HashThreadedTransactions(ThreadedTransactionsBase):
    475     dbtype = db.DB_HASH
    476     writers = 2
    477     readers = 10
    478     records = 1000
    479 
    480 class BTreeThreadedNoWaitTransactions(ThreadedTransactionsBase):
    481     dbtype = db.DB_BTREE
    482     writers = 2
    483     readers = 10
    484     records = 1000
    485     txnFlag = db.DB_TXN_NOWAIT
    486 
    487 class HashThreadedNoWaitTransactions(ThreadedTransactionsBase):
    488     dbtype = db.DB_HASH
    489     writers = 2
    490     readers = 10
    491     records = 1000
    492     txnFlag = db.DB_TXN_NOWAIT
    493 
    494 
    495 #----------------------------------------------------------------------
    496 
    497 def test_suite():
    498     suite = unittest.TestSuite()
    499 
    500     if have_threads:
    501         suite.addTest(unittest.makeSuite(BTreeConcurrentDataStore))
    502         suite.addTest(unittest.makeSuite(HashConcurrentDataStore))
    503         suite.addTest(unittest.makeSuite(BTreeSimpleThreaded))
    504         suite.addTest(unittest.makeSuite(HashSimpleThreaded))
    505         suite.addTest(unittest.makeSuite(BTreeThreadedTransactions))
    506         suite.addTest(unittest.makeSuite(HashThreadedTransactions))
    507         suite.addTest(unittest.makeSuite(BTreeThreadedNoWaitTransactions))
    508         suite.addTest(unittest.makeSuite(HashThreadedNoWaitTransactions))
    509 
    510     else:
    511         print "Threads not available, skipping thread tests."
    512 
    513     return suite
    514 
    515 
    516 if __name__ == '__main__':
    517     unittest.main(defaultTest='test_suite')
    518