Home | History | Annotate | Download | only in test
      1 """
      2 TestCases for python DB duplicate and Btree key comparison function.
      3 """
      4 
      5 import sys, os, re
      6 import test_all
      7 from cStringIO import StringIO
      8 
      9 import unittest
     10 
     11 from test_all import db, dbshelve, test_support, \
     12         get_new_environment_path, get_new_database_path
     13 
     14 
     15 # Needed for python 3. "cmp" vanished in 3.0.1
     16 def cmp(a, b) :
     17     if a==b : return 0
     18     if a<b : return -1
     19     return 1
     20 
     21 lexical_cmp = cmp
     22 
     23 def lowercase_cmp(left, right) :
     24     return cmp(left.lower(), right.lower())
     25 
     26 def make_reverse_comparator(cmp) :
     27     def reverse(left, right, delegate=cmp) :
     28         return - delegate(left, right)
     29     return reverse
     30 
     31 _expected_lexical_test_data = ['', 'CCCP', 'a', 'aaa', 'b', 'c', 'cccce', 'ccccf']
     32 _expected_lowercase_test_data = ['', 'a', 'aaa', 'b', 'c', 'CC', 'cccce', 'ccccf', 'CCCP']
     33 
     34 class ComparatorTests(unittest.TestCase) :
     35     def comparator_test_helper(self, comparator, expected_data) :
     36         data = expected_data[:]
     37 
     38         import sys
     39         if sys.version_info < (2, 6) :
     40             data.sort(cmp=comparator)
     41         else :  # Insertion Sort. Please, improve
     42             data2 = []
     43             for i in data :
     44                 for j, k in enumerate(data2) :
     45                     r = comparator(k, i)
     46                     if r == 1 :
     47                         data2.insert(j, i)
     48                         break
     49                 else :
     50                     data2.append(i)
     51             data = data2
     52 
     53         self.assertEqual(data, expected_data,
     54                          "comparator `%s' is not right: %s vs. %s"
     55                          % (comparator, expected_data, data))
     56     def test_lexical_comparator(self) :
     57         self.comparator_test_helper(lexical_cmp, _expected_lexical_test_data)
     58     def test_reverse_lexical_comparator(self) :
     59         rev = _expected_lexical_test_data[:]
     60         rev.reverse()
     61         self.comparator_test_helper(make_reverse_comparator(lexical_cmp),
     62                                      rev)
     63     def test_lowercase_comparator(self) :
     64         self.comparator_test_helper(lowercase_cmp,
     65                                      _expected_lowercase_test_data)
     66 
     67 class AbstractBtreeKeyCompareTestCase(unittest.TestCase) :
     68     env = None
     69     db = None
     70 
     71     if (sys.version_info < (2, 7)) or ((sys.version_info >= (3,0)) and
     72             (sys.version_info < (3, 2))) :
     73         def assertLess(self, a, b, msg=None) :
     74             return self.assertTrue(a<b, msg=msg)
     75 
     76     def setUp(self) :
     77         self.filename = self.__class__.__name__ + '.db'
     78         self.homeDir = get_new_environment_path()
     79         env = db.DBEnv()
     80         env.open(self.homeDir,
     81                   db.DB_CREATE | db.DB_INIT_MPOOL
     82                   | db.DB_INIT_LOCK | db.DB_THREAD)
     83         self.env = env
     84 
     85     def tearDown(self) :
     86         self.closeDB()
     87         if self.env is not None:
     88             self.env.close()
     89             self.env = None
     90         test_support.rmtree(self.homeDir)
     91 
     92     def addDataToDB(self, data) :
     93         i = 0
     94         for item in data:
     95             self.db.put(item, str(i))
     96             i = i + 1
     97 
     98     def createDB(self, key_comparator) :
     99         self.db = db.DB(self.env)
    100         self.setupDB(key_comparator)
    101         self.db.open(self.filename, "test", db.DB_BTREE, db.DB_CREATE)
    102 
    103     def setupDB(self, key_comparator) :
    104         self.db.set_bt_compare(key_comparator)
    105 
    106     def closeDB(self) :
    107         if self.db is not None:
    108             self.db.close()
    109             self.db = None
    110 
    111     def startTest(self) :
    112         pass
    113 
    114     def finishTest(self, expected = None) :
    115         if expected is not None:
    116             self.check_results(expected)
    117         self.closeDB()
    118 
    119     def check_results(self, expected) :
    120         curs = self.db.cursor()
    121         try:
    122             index = 0
    123             rec = curs.first()
    124             while rec:
    125                 key, ignore = rec
    126                 self.assertLess(index, len(expected),
    127                                  "to many values returned from cursor")
    128                 self.assertEqual(expected[index], key,
    129                                  "expected value `%s' at %d but got `%s'"
    130                                  % (expected[index], index, key))
    131                 index = index + 1
    132                 rec = curs.next()
    133             self.assertEqual(index, len(expected),
    134                              "not enough values returned from cursor")
    135         finally:
    136             curs.close()
    137 
    138 class BtreeKeyCompareTestCase(AbstractBtreeKeyCompareTestCase) :
    139     def runCompareTest(self, comparator, data) :
    140         self.startTest()
    141         self.createDB(comparator)
    142         self.addDataToDB(data)
    143         self.finishTest(data)
    144 
    145     def test_lexical_ordering(self) :
    146         self.runCompareTest(lexical_cmp, _expected_lexical_test_data)
    147 
    148     def test_reverse_lexical_ordering(self) :
    149         expected_rev_data = _expected_lexical_test_data[:]
    150         expected_rev_data.reverse()
    151         self.runCompareTest(make_reverse_comparator(lexical_cmp),
    152                              expected_rev_data)
    153 
    154     def test_compare_function_useless(self) :
    155         self.startTest()
    156         def socialist_comparator(l, r) :
    157             return 0
    158         self.createDB(socialist_comparator)
    159         self.addDataToDB(['b', 'a', 'd'])
    160         # all things being equal the first key will be the only key
    161         # in the database...  (with the last key's value fwiw)
    162         self.finishTest(['b'])
    163 
    164 
    165 class BtreeExceptionsTestCase(AbstractBtreeKeyCompareTestCase) :
    166     def test_raises_non_callable(self) :
    167         self.startTest()
    168         self.assertRaises(TypeError, self.createDB, 'abc')
    169         self.assertRaises(TypeError, self.createDB, None)
    170         self.finishTest()
    171 
    172     def test_set_bt_compare_with_function(self) :
    173         self.startTest()
    174         self.createDB(lexical_cmp)
    175         self.finishTest()
    176 
    177     def check_results(self, results) :
    178         pass
    179 
    180     def test_compare_function_incorrect(self) :
    181         self.startTest()
    182         def bad_comparator(l, r) :
    183             return 1
    184         # verify that set_bt_compare checks that comparator('', '') == 0
    185         self.assertRaises(TypeError, self.createDB, bad_comparator)
    186         self.finishTest()
    187 
    188     def verifyStderr(self, method, successRe) :
    189         """
    190         Call method() while capturing sys.stderr output internally and
    191         call self.fail() if successRe.search() does not match the stderr
    192         output.  This is used to test for uncatchable exceptions.
    193         """
    194         stdErr = sys.stderr
    195         sys.stderr = StringIO()
    196         try:
    197             method()
    198         finally:
    199             temp = sys.stderr
    200             sys.stderr = stdErr
    201             errorOut = temp.getvalue()
    202             if not successRe.search(errorOut) :
    203                 self.fail("unexpected stderr output:\n"+errorOut)
    204         if sys.version_info < (3, 0) :  # XXX: How to do this in Py3k ???
    205             sys.exc_traceback = sys.last_traceback = None
    206 
    207     def _test_compare_function_exception(self) :
    208         self.startTest()
    209         def bad_comparator(l, r) :
    210             if l == r:
    211                 # pass the set_bt_compare test
    212                 return 0
    213             raise RuntimeError, "i'm a naughty comparison function"
    214         self.createDB(bad_comparator)
    215         #print "\n*** test should print 2 uncatchable tracebacks ***"
    216         self.addDataToDB(['a', 'b', 'c'])  # this should raise, but...
    217         self.finishTest()
    218 
    219     def test_compare_function_exception(self) :
    220         self.verifyStderr(
    221                 self._test_compare_function_exception,
    222                 re.compile('(^RuntimeError:.* naughty.*){2}', re.M|re.S)
    223         )
    224 
    225     def _test_compare_function_bad_return(self) :
    226         self.startTest()
    227         def bad_comparator(l, r) :
    228             if l == r:
    229                 # pass the set_bt_compare test
    230                 return 0
    231             return l
    232         self.createDB(bad_comparator)
    233         #print "\n*** test should print 2 errors about returning an int ***"
    234         self.addDataToDB(['a', 'b', 'c'])  # this should raise, but...
    235         self.finishTest()
    236 
    237     def test_compare_function_bad_return(self) :
    238         self.verifyStderr(
    239                 self._test_compare_function_bad_return,
    240                 re.compile('(^TypeError:.* return an int.*){2}', re.M|re.S)
    241         )
    242 
    243 
    244     def test_cannot_assign_twice(self) :
    245 
    246         def my_compare(a, b) :
    247             return 0
    248 
    249         self.startTest()
    250         self.createDB(my_compare)
    251         self.assertRaises(RuntimeError, self.db.set_bt_compare, my_compare)
    252 
    253 class AbstractDuplicateCompareTestCase(unittest.TestCase) :
    254     env = None
    255     db = None
    256 
    257     if (sys.version_info < (2, 7)) or ((sys.version_info >= (3,0)) and
    258             (sys.version_info < (3, 2))) :
    259         def assertLess(self, a, b, msg=None) :
    260             return self.assertTrue(a<b, msg=msg)
    261 
    262     def setUp(self) :
    263         self.filename = self.__class__.__name__ + '.db'
    264         self.homeDir = get_new_environment_path()
    265         env = db.DBEnv()
    266         env.open(self.homeDir,
    267                   db.DB_CREATE | db.DB_INIT_MPOOL
    268                   | db.DB_INIT_LOCK | db.DB_THREAD)
    269         self.env = env
    270 
    271     def tearDown(self) :
    272         self.closeDB()
    273         if self.env is not None:
    274             self.env.close()
    275             self.env = None
    276         test_support.rmtree(self.homeDir)
    277 
    278     def addDataToDB(self, data) :
    279         for item in data:
    280             self.db.put("key", item)
    281 
    282     def createDB(self, dup_comparator) :
    283         self.db = db.DB(self.env)
    284         self.setupDB(dup_comparator)
    285         self.db.open(self.filename, "test", db.DB_BTREE, db.DB_CREATE)
    286 
    287     def setupDB(self, dup_comparator) :
    288         self.db.set_flags(db.DB_DUPSORT)
    289         self.db.set_dup_compare(dup_comparator)
    290 
    291     def closeDB(self) :
    292         if self.db is not None:
    293             self.db.close()
    294             self.db = None
    295 
    296     def startTest(self) :
    297         pass
    298 
    299     def finishTest(self, expected = None) :
    300         if expected is not None:
    301             self.check_results(expected)
    302         self.closeDB()
    303 
    304     def check_results(self, expected) :
    305         curs = self.db.cursor()
    306         try:
    307             index = 0
    308             rec = curs.first()
    309             while rec:
    310                 ignore, data = rec
    311                 self.assertLess(index, len(expected),
    312                                  "to many values returned from cursor")
    313                 self.assertEqual(expected[index], data,
    314                                  "expected value `%s' at %d but got `%s'"
    315                                  % (expected[index], index, data))
    316                 index = index + 1
    317                 rec = curs.next()
    318             self.assertEqual(index, len(expected),
    319                              "not enough values returned from cursor")
    320         finally:
    321             curs.close()
    322 
    323 class DuplicateCompareTestCase(AbstractDuplicateCompareTestCase) :
    324     def runCompareTest(self, comparator, data) :
    325         self.startTest()
    326         self.createDB(comparator)
    327         self.addDataToDB(data)
    328         self.finishTest(data)
    329 
    330     def test_lexical_ordering(self) :
    331         self.runCompareTest(lexical_cmp, _expected_lexical_test_data)
    332 
    333     def test_reverse_lexical_ordering(self) :
    334         expected_rev_data = _expected_lexical_test_data[:]
    335         expected_rev_data.reverse()
    336         self.runCompareTest(make_reverse_comparator(lexical_cmp),
    337                              expected_rev_data)
    338 
    339 class DuplicateExceptionsTestCase(AbstractDuplicateCompareTestCase) :
    340     def test_raises_non_callable(self) :
    341         self.startTest()
    342         self.assertRaises(TypeError, self.createDB, 'abc')
    343         self.assertRaises(TypeError, self.createDB, None)
    344         self.finishTest()
    345 
    346     def test_set_dup_compare_with_function(self) :
    347         self.startTest()
    348         self.createDB(lexical_cmp)
    349         self.finishTest()
    350 
    351     def check_results(self, results) :
    352         pass
    353 
    354     def test_compare_function_incorrect(self) :
    355         self.startTest()
    356         def bad_comparator(l, r) :
    357             return 1
    358         # verify that set_dup_compare checks that comparator('', '') == 0
    359         self.assertRaises(TypeError, self.createDB, bad_comparator)
    360         self.finishTest()
    361 
    362     def test_compare_function_useless(self) :
    363         self.startTest()
    364         def socialist_comparator(l, r) :
    365             return 0
    366         self.createDB(socialist_comparator)
    367         # DUPSORT does not allow "duplicate duplicates"
    368         self.assertRaises(db.DBKeyExistError, self.addDataToDB, ['b', 'a', 'd'])
    369         self.finishTest()
    370 
    371     def verifyStderr(self, method, successRe) :
    372         """
    373         Call method() while capturing sys.stderr output internally and
    374         call self.fail() if successRe.search() does not match the stderr
    375         output.  This is used to test for uncatchable exceptions.
    376         """
    377         stdErr = sys.stderr
    378         sys.stderr = StringIO()
    379         try:
    380             method()
    381         finally:
    382             temp = sys.stderr
    383             sys.stderr = stdErr
    384             errorOut = temp.getvalue()
    385             if not successRe.search(errorOut) :
    386                 self.fail("unexpected stderr output:\n"+errorOut)
    387         if sys.version_info < (3, 0) :  # XXX: How to do this in Py3k ???
    388             sys.exc_traceback = sys.last_traceback = None
    389 
    390     def _test_compare_function_exception(self) :
    391         self.startTest()
    392         def bad_comparator(l, r) :
    393             if l == r:
    394                 # pass the set_dup_compare test
    395                 return 0
    396             raise RuntimeError, "i'm a naughty comparison function"
    397         self.createDB(bad_comparator)
    398         #print "\n*** test should print 2 uncatchable tracebacks ***"
    399         self.addDataToDB(['a', 'b', 'c'])  # this should raise, but...
    400         self.finishTest()
    401 
    402     def test_compare_function_exception(self) :
    403         self.verifyStderr(
    404                 self._test_compare_function_exception,
    405                 re.compile('(^RuntimeError:.* naughty.*){2}', re.M|re.S)
    406         )
    407 
    408     def _test_compare_function_bad_return(self) :
    409         self.startTest()
    410         def bad_comparator(l, r) :
    411             if l == r:
    412                 # pass the set_dup_compare test
    413                 return 0
    414             return l
    415         self.createDB(bad_comparator)
    416         #print "\n*** test should print 2 errors about returning an int ***"
    417         self.addDataToDB(['a', 'b', 'c'])  # this should raise, but...
    418         self.finishTest()
    419 
    420     def test_compare_function_bad_return(self) :
    421         self.verifyStderr(
    422                 self._test_compare_function_bad_return,
    423                 re.compile('(^TypeError:.* return an int.*){2}', re.M|re.S)
    424         )
    425 
    426 
    427     def test_cannot_assign_twice(self) :
    428 
    429         def my_compare(a, b) :
    430             return 0
    431 
    432         self.startTest()
    433         self.createDB(my_compare)
    434         self.assertRaises(RuntimeError, self.db.set_dup_compare, my_compare)
    435 
    436 def test_suite() :
    437     res = unittest.TestSuite()
    438 
    439     res.addTest(unittest.makeSuite(ComparatorTests))
    440     res.addTest(unittest.makeSuite(BtreeExceptionsTestCase))
    441     res.addTest(unittest.makeSuite(BtreeKeyCompareTestCase))
    442     res.addTest(unittest.makeSuite(DuplicateExceptionsTestCase))
    443     res.addTest(unittest.makeSuite(DuplicateCompareTestCase))
    444     return res
    445 
    446 if __name__ == '__main__':
    447     unittest.main(defaultTest = 'suite')
    448