1 """TestCases for multi-threaded access to a DB. 2 """ 3 4 import os 5 import sys 6 import time 7 import errno 8 from random import random 9 10 DASH = '-' 11 12 try: 13 WindowsError 14 except NameError: 15 class WindowsError(Exception): 16 pass 17 18 import unittest 19 from test_all import db, dbutils, test_support, verbose, have_threads, \ 20 get_new_environment_path, get_new_database_path 21 22 if have_threads : 23 from threading import Thread 24 if sys.version_info[0] < 3 : 25 from threading import currentThread 26 else : 27 from threading import current_thread as currentThread 28 29 30 #---------------------------------------------------------------------- 31 32 class BaseThreadedTestCase(unittest.TestCase): 33 dbtype = db.DB_UNKNOWN # must be set in derived class 34 dbopenflags = 0 35 dbsetflags = 0 36 envflags = 0 37 38 def setUp(self): 39 if verbose: 40 dbutils._deadlock_VerboseFile = sys.stdout 41 42 self.homeDir = get_new_environment_path() 43 self.env = db.DBEnv() 44 self.setEnvOpts() 45 self.env.open(self.homeDir, self.envflags | db.DB_CREATE) 46 47 self.filename = self.__class__.__name__ + '.db' 48 self.d = db.DB(self.env) 49 if self.dbsetflags: 50 self.d.set_flags(self.dbsetflags) 51 self.d.open(self.filename, self.dbtype, self.dbopenflags|db.DB_CREATE) 52 53 def tearDown(self): 54 self.d.close() 55 self.env.close() 56 test_support.rmtree(self.homeDir) 57 58 def setEnvOpts(self): 59 pass 60 61 def makeData(self, key): 62 return DASH.join([key] * 5) 63 64 65 #---------------------------------------------------------------------- 66 67 68 class ConcurrentDataStoreBase(BaseThreadedTestCase): 69 dbopenflags = db.DB_THREAD 70 envflags = db.DB_THREAD | db.DB_INIT_CDB | db.DB_INIT_MPOOL 71 readers = 0 # derived class should set 72 writers = 0 73 records = 1000 74 75 def test01_1WriterMultiReaders(self): 76 if verbose: 77 print '\n', '-=' * 30 78 print "Running %s.test01_1WriterMultiReaders..." % \ 79 self.__class__.__name__ 80 81 keys=range(self.records) 82 import random 83 random.shuffle(keys) 84 records_per_writer=self.records//self.writers 85 readers_per_writer=self.readers//self.writers 86 self.assertEqual(self.records,self.writers*records_per_writer) 87 self.assertEqual(self.readers,self.writers*readers_per_writer) 88 self.assertTrue((records_per_writer%readers_per_writer)==0) 89 readers = [] 90 91 for x in xrange(self.readers): 92 rt = Thread(target = self.readerThread, 93 args = (self.d, x), 94 name = 'reader %d' % x, 95 )#verbose = verbose) 96 if sys.version_info[0] < 3 : 97 rt.setDaemon(True) 98 else : 99 rt.daemon = True 100 readers.append(rt) 101 102 writers=[] 103 for x in xrange(self.writers): 104 a=keys[records_per_writer*x:records_per_writer*(x+1)] 105 a.sort() # Generate conflicts 106 b=readers[readers_per_writer*x:readers_per_writer*(x+1)] 107 wt = Thread(target = self.writerThread, 108 args = (self.d, a, b), 109 name = 'writer %d' % x, 110 )#verbose = verbose) 111 writers.append(wt) 112 113 for t in writers: 114 if sys.version_info[0] < 3 : 115 t.setDaemon(True) 116 else : 117 t.daemon = True 118 t.start() 119 120 for t in writers: 121 t.join() 122 for t in readers: 123 t.join() 124 125 def writerThread(self, d, keys, readers): 126 if sys.version_info[0] < 3 : 127 name = currentThread().getName() 128 else : 129 name = currentThread().name 130 131 if verbose: 132 print "%s: creating records %d - %d" % (name, start, stop) 133 134 count=len(keys)//len(readers) 135 count2=count 136 for x in keys : 137 key = '%04d' % x 138 dbutils.DeadlockWrap(d.put, key, self.makeData(key), 139 max_retries=12) 140 if verbose and x % 100 == 0: 141 print "%s: records %d - %d finished" % (name, start, x) 142 143 count2-=1 144 if not count2 : 145 readers.pop().start() 146 count2=count 147 148 if verbose: 149 print "%s: finished creating records" % name 150 151 if verbose: 152 print "%s: thread finished" % name 153 154 def readerThread(self, d, readerNum): 155 if sys.version_info[0] < 3 : 156 name = currentThread().getName() 157 else : 158 name = currentThread().name 159 160 for i in xrange(5) : 161 c = d.cursor() 162 count = 0 163 rec = c.first() 164 while rec: 165 count += 1 166 key, data = rec 167 self.assertEqual(self.makeData(key), data) 168 rec = c.next() 169 if verbose: 170 print "%s: found %d records" % (name, count) 171 c.close() 172 173 if verbose: 174 print "%s: thread finished" % name 175 176 177 class BTreeConcurrentDataStore(ConcurrentDataStoreBase): 178 dbtype = db.DB_BTREE 179 writers = 2 180 readers = 10 181 records = 1000 182 183 184 class HashConcurrentDataStore(ConcurrentDataStoreBase): 185 dbtype = db.DB_HASH 186 writers = 2 187 readers = 10 188 records = 1000 189 190 191 #---------------------------------------------------------------------- 192 193 class SimpleThreadedBase(BaseThreadedTestCase): 194 dbopenflags = db.DB_THREAD 195 envflags = db.DB_THREAD | db.DB_INIT_MPOOL | db.DB_INIT_LOCK 196 readers = 10 197 writers = 2 198 records = 1000 199 200 def setEnvOpts(self): 201 self.env.set_lk_detect(db.DB_LOCK_DEFAULT) 202 203 def test02_SimpleLocks(self): 204 if verbose: 205 print '\n', '-=' * 30 206 print "Running %s.test02_SimpleLocks..." % self.__class__.__name__ 207 208 209 keys=range(self.records) 210 import random 211 random.shuffle(keys) 212 records_per_writer=self.records//self.writers 213 readers_per_writer=self.readers//self.writers 214 self.assertEqual(self.records,self.writers*records_per_writer) 215 self.assertEqual(self.readers,self.writers*readers_per_writer) 216 self.assertTrue((records_per_writer%readers_per_writer)==0) 217 218 readers = [] 219 for x in xrange(self.readers): 220 rt = Thread(target = self.readerThread, 221 args = (self.d, x), 222 name = 'reader %d' % x, 223 )#verbose = verbose) 224 if sys.version_info[0] < 3 : 225 rt.setDaemon(True) 226 else : 227 rt.daemon = True 228 readers.append(rt) 229 230 writers = [] 231 for x in xrange(self.writers): 232 a=keys[records_per_writer*x:records_per_writer*(x+1)] 233 a.sort() # Generate conflicts 234 b=readers[readers_per_writer*x:readers_per_writer*(x+1)] 235 wt = Thread(target = self.writerThread, 236 args = (self.d, a, b), 237 name = 'writer %d' % x, 238 )#verbose = verbose) 239 writers.append(wt) 240 241 for t in writers: 242 if sys.version_info[0] < 3 : 243 t.setDaemon(True) 244 else : 245 t.daemon = True 246 t.start() 247 248 for t in writers: 249 t.join() 250 for t in readers: 251 t.join() 252 253 def writerThread(self, d, keys, readers): 254 if sys.version_info[0] < 3 : 255 name = currentThread().getName() 256 else : 257 name = currentThread().name 258 if verbose: 259 print "%s: creating records %d - %d" % (name, start, stop) 260 261 count=len(keys)//len(readers) 262 count2=count 263 for x in keys : 264 key = '%04d' % x 265 dbutils.DeadlockWrap(d.put, key, self.makeData(key), 266 max_retries=12) 267 268 if verbose and x % 100 == 0: 269 print "%s: records %d - %d finished" % (name, start, x) 270 271 count2-=1 272 if not count2 : 273 readers.pop().start() 274 count2=count 275 276 if verbose: 277 print "%s: thread finished" % name 278 279 def readerThread(self, d, readerNum): 280 if sys.version_info[0] < 3 : 281 name = currentThread().getName() 282 else : 283 name = currentThread().name 284 285 c = d.cursor() 286 count = 0 287 rec = dbutils.DeadlockWrap(c.first, max_retries=10) 288 while rec: 289 count += 1 290 key, data = rec 291 self.assertEqual(self.makeData(key), data) 292 rec = dbutils.DeadlockWrap(c.next, max_retries=10) 293 if verbose: 294 print "%s: found %d records" % (name, count) 295 c.close() 296 297 if verbose: 298 print "%s: thread finished" % name 299 300 301 class BTreeSimpleThreaded(SimpleThreadedBase): 302 dbtype = db.DB_BTREE 303 304 305 class HashSimpleThreaded(SimpleThreadedBase): 306 dbtype = db.DB_HASH 307 308 309 #---------------------------------------------------------------------- 310 311 312 class ThreadedTransactionsBase(BaseThreadedTestCase): 313 dbopenflags = db.DB_THREAD | db.DB_AUTO_COMMIT 314 envflags = (db.DB_THREAD | 315 db.DB_INIT_MPOOL | 316 db.DB_INIT_LOCK | 317 db.DB_INIT_LOG | 318 db.DB_INIT_TXN 319 ) 320 readers = 0 321 writers = 0 322 records = 2000 323 txnFlag = 0 324 325 def setEnvOpts(self): 326 #self.env.set_lk_detect(db.DB_LOCK_DEFAULT) 327 pass 328 329 def test03_ThreadedTransactions(self): 330 if verbose: 331 print '\n', '-=' * 30 332 print "Running %s.test03_ThreadedTransactions..." % \ 333 self.__class__.__name__ 334 335 keys=range(self.records) 336 import random 337 random.shuffle(keys) 338 records_per_writer=self.records//self.writers 339 readers_per_writer=self.readers//self.writers 340 self.assertEqual(self.records,self.writers*records_per_writer) 341 self.assertEqual(self.readers,self.writers*readers_per_writer) 342 self.assertTrue((records_per_writer%readers_per_writer)==0) 343 344 readers=[] 345 for x in xrange(self.readers): 346 rt = Thread(target = self.readerThread, 347 args = (self.d, x), 348 name = 'reader %d' % x, 349 )#verbose = verbose) 350 if sys.version_info[0] < 3 : 351 rt.setDaemon(True) 352 else : 353 rt.daemon = True 354 readers.append(rt) 355 356 writers = [] 357 for x in xrange(self.writers): 358 a=keys[records_per_writer*x:records_per_writer*(x+1)] 359 b=readers[readers_per_writer*x:readers_per_writer*(x+1)] 360 wt = Thread(target = self.writerThread, 361 args = (self.d, a, b), 362 name = 'writer %d' % x, 363 )#verbose = verbose) 364 writers.append(wt) 365 366 dt = Thread(target = self.deadlockThread) 367 if sys.version_info[0] < 3 : 368 dt.setDaemon(True) 369 else : 370 dt.daemon = True 371 dt.start() 372 373 for t in writers: 374 if sys.version_info[0] < 3 : 375 t.setDaemon(True) 376 else : 377 t.daemon = True 378 t.start() 379 380 for t in writers: 381 t.join() 382 for t in readers: 383 t.join() 384 385 self.doLockDetect = False 386 dt.join() 387 388 def writerThread(self, d, keys, readers): 389 if sys.version_info[0] < 3 : 390 name = currentThread().getName() 391 else : 392 name = currentThread().name 393 394 count=len(keys)//len(readers) 395 while len(keys): 396 try: 397 txn = self.env.txn_begin(None, self.txnFlag) 398 keys2=keys[:count] 399 for x in keys2 : 400 key = '%04d' % x 401 d.put(key, self.makeData(key), txn) 402 if verbose and x % 100 == 0: 403 print "%s: records %d - %d finished" % (name, start, x) 404 txn.commit() 405 keys=keys[count:] 406 readers.pop().start() 407 except (db.DBLockDeadlockError, db.DBLockNotGrantedError), val: 408 if verbose: 409 if sys.version_info < (2, 6) : 410 print "%s: Aborting transaction (%s)" % (name, val[1]) 411 else : 412 print "%s: Aborting transaction (%s)" % (name, 413 val.args[1]) 414 txn.abort() 415 416 if verbose: 417 print "%s: thread finished" % name 418 419 def readerThread(self, d, readerNum): 420 if sys.version_info[0] < 3 : 421 name = currentThread().getName() 422 else : 423 name = currentThread().name 424 425 finished = False 426 while not finished: 427 try: 428 txn = self.env.txn_begin(None, self.txnFlag) 429 c = d.cursor(txn) 430 count = 0 431 rec = c.first() 432 while rec: 433 count += 1 434 key, data = rec 435 self.assertEqual(self.makeData(key), data) 436 rec = c.next() 437 if verbose: print "%s: found %d records" % (name, count) 438 c.close() 439 txn.commit() 440 finished = True 441 except (db.DBLockDeadlockError, db.DBLockNotGrantedError), val: 442 if verbose: 443 if sys.version_info < (2, 6) : 444 print "%s: Aborting transaction (%s)" % (name, val[1]) 445 else : 446 print "%s: Aborting transaction (%s)" % (name, 447 val.args[1]) 448 c.close() 449 txn.abort() 450 451 if verbose: 452 print "%s: thread finished" % name 453 454 def deadlockThread(self): 455 self.doLockDetect = True 456 while self.doLockDetect: 457 time.sleep(0.05) 458 try: 459 aborted = self.env.lock_detect( 460 db.DB_LOCK_RANDOM, db.DB_LOCK_CONFLICT) 461 if verbose and aborted: 462 print "deadlock: Aborted %d deadlocked transaction(s)" \ 463 % aborted 464 except db.DBError: 465 pass 466 467 468 class BTreeThreadedTransactions(ThreadedTransactionsBase): 469 dbtype = db.DB_BTREE 470 writers = 2 471 readers = 10 472 records = 1000 473 474 class HashThreadedTransactions(ThreadedTransactionsBase): 475 dbtype = db.DB_HASH 476 writers = 2 477 readers = 10 478 records = 1000 479 480 class BTreeThreadedNoWaitTransactions(ThreadedTransactionsBase): 481 dbtype = db.DB_BTREE 482 writers = 2 483 readers = 10 484 records = 1000 485 txnFlag = db.DB_TXN_NOWAIT 486 487 class HashThreadedNoWaitTransactions(ThreadedTransactionsBase): 488 dbtype = db.DB_HASH 489 writers = 2 490 readers = 10 491 records = 1000 492 txnFlag = db.DB_TXN_NOWAIT 493 494 495 #---------------------------------------------------------------------- 496 497 def test_suite(): 498 suite = unittest.TestSuite() 499 500 if have_threads: 501 suite.addTest(unittest.makeSuite(BTreeConcurrentDataStore)) 502 suite.addTest(unittest.makeSuite(HashConcurrentDataStore)) 503 suite.addTest(unittest.makeSuite(BTreeSimpleThreaded)) 504 suite.addTest(unittest.makeSuite(HashSimpleThreaded)) 505 suite.addTest(unittest.makeSuite(BTreeThreadedTransactions)) 506 suite.addTest(unittest.makeSuite(HashThreadedTransactions)) 507 suite.addTest(unittest.makeSuite(BTreeThreadedNoWaitTransactions)) 508 suite.addTest(unittest.makeSuite(HashThreadedNoWaitTransactions)) 509 510 else: 511 print "Threads not available, skipping thread tests." 512 513 return suite 514 515 516 if __name__ == '__main__': 517 unittest.main(defaultTest='test_suite') 518