Home | History | Annotate | Download | only in test
      1 #-*- coding: iso-8859-1 -*-
      2 # pysqlite2/test/userfunctions.py: tests for user-defined functions and
      3 #                                  aggregates.
      4 #
      5 # Copyright (C) 2005-2007 Gerhard Häring <gh@ghaering.de>
      6 #
      7 # This file is part of pysqlite.
      8 #
      9 # This software is provided 'as-is', without any express or implied
     10 # warranty.  In no event will the authors be held liable for any damages
     11 # arising from the use of this software.
     12 #
     13 # Permission is granted to anyone to use this software for any purpose,
     14 # including commercial applications, and to alter it and redistribute it
     15 # freely, subject to the following restrictions:
     16 #
     17 # 1. The origin of this software must not be misrepresented; you must not
     18 #    claim that you wrote the original software. If you use this software
     19 #    in a product, an acknowledgment in the product documentation would be
     20 #    appreciated but is not required.
     21 # 2. Altered source versions must be plainly marked as such, and must not be
     22 #    misrepresented as being the original software.
     23 # 3. This notice may not be removed or altered from any source distribution.
     24 
     25 import unittest
     26 import sqlite3 as sqlite
     27 
     28 def func_returntext():
     29     return "foo"
     30 def func_returnunicode():
     31     return "bar"
     32 def func_returnint():
     33     return 42
     34 def func_returnfloat():
     35     return 3.14
     36 def func_returnnull():
     37     return None
     38 def func_returnblob():
     39     return b"blob"
     40 def func_returnlonglong():
     41     return 1<<31
     42 def func_raiseexception():
     43     5/0
     44 
     45 def func_isstring(v):
     46     return type(v) is str
     47 def func_isint(v):
     48     return type(v) is int
     49 def func_isfloat(v):
     50     return type(v) is float
     51 def func_isnone(v):
     52     return type(v) is type(None)
     53 def func_isblob(v):
     54     return isinstance(v, (bytes, memoryview))
     55 def func_islonglong(v):
     56     return isinstance(v, int) and v >= 1<<31
     57 
     58 def func(*args):
     59     return len(args)
     60 
     61 class AggrNoStep:
     62     def __init__(self):
     63         pass
     64 
     65     def finalize(self):
     66         return 1
     67 
     68 class AggrNoFinalize:
     69     def __init__(self):
     70         pass
     71 
     72     def step(self, x):
     73         pass
     74 
     75 class AggrExceptionInInit:
     76     def __init__(self):
     77         5/0
     78 
     79     def step(self, x):
     80         pass
     81 
     82     def finalize(self):
     83         pass
     84 
     85 class AggrExceptionInStep:
     86     def __init__(self):
     87         pass
     88 
     89     def step(self, x):
     90         5/0
     91 
     92     def finalize(self):
     93         return 42
     94 
     95 class AggrExceptionInFinalize:
     96     def __init__(self):
     97         pass
     98 
     99     def step(self, x):
    100         pass
    101 
    102     def finalize(self):
    103         5/0
    104 
    105 class AggrCheckType:
    106     def __init__(self):
    107         self.val = None
    108 
    109     def step(self, whichType, val):
    110         theType = {"str": str, "int": int, "float": float, "None": type(None),
    111                    "blob": bytes}
    112         self.val = int(theType[whichType] is type(val))
    113 
    114     def finalize(self):
    115         return self.val
    116 
    117 class AggrCheckTypes:
    118     def __init__(self):
    119         self.val = 0
    120 
    121     def step(self, whichType, *vals):
    122         theType = {"str": str, "int": int, "float": float, "None": type(None),
    123                    "blob": bytes}
    124         for val in vals:
    125             self.val += int(theType[whichType] is type(val))
    126 
    127     def finalize(self):
    128         return self.val
    129 
    130 class AggrSum:
    131     def __init__(self):
    132         self.val = 0.0
    133 
    134     def step(self, val):
    135         self.val += val
    136 
    137     def finalize(self):
    138         return self.val
    139 
    140 class FunctionTests(unittest.TestCase):
    141     def setUp(self):
    142         self.con = sqlite.connect(":memory:")
    143 
    144         self.con.create_function("returntext", 0, func_returntext)
    145         self.con.create_function("returnunicode", 0, func_returnunicode)
    146         self.con.create_function("returnint", 0, func_returnint)
    147         self.con.create_function("returnfloat", 0, func_returnfloat)
    148         self.con.create_function("returnnull", 0, func_returnnull)
    149         self.con.create_function("returnblob", 0, func_returnblob)
    150         self.con.create_function("returnlonglong", 0, func_returnlonglong)
    151         self.con.create_function("raiseexception", 0, func_raiseexception)
    152 
    153         self.con.create_function("isstring", 1, func_isstring)
    154         self.con.create_function("isint", 1, func_isint)
    155         self.con.create_function("isfloat", 1, func_isfloat)
    156         self.con.create_function("isnone", 1, func_isnone)
    157         self.con.create_function("isblob", 1, func_isblob)
    158         self.con.create_function("islonglong", 1, func_islonglong)
    159         self.con.create_function("spam", -1, func)
    160 
    161     def tearDown(self):
    162         self.con.close()
    163 
    164     def CheckFuncErrorOnCreate(self):
    165         with self.assertRaises(sqlite.OperationalError):
    166             self.con.create_function("bla", -100, lambda x: 2*x)
    167 
    168     def CheckFuncRefCount(self):
    169         def getfunc():
    170             def f():
    171                 return 1
    172             return f
    173         f = getfunc()
    174         globals()["foo"] = f
    175         # self.con.create_function("reftest", 0, getfunc())
    176         self.con.create_function("reftest", 0, f)
    177         cur = self.con.cursor()
    178         cur.execute("select reftest()")
    179 
    180     def CheckFuncReturnText(self):
    181         cur = self.con.cursor()
    182         cur.execute("select returntext()")
    183         val = cur.fetchone()[0]
    184         self.assertEqual(type(val), str)
    185         self.assertEqual(val, "foo")
    186 
    187     def CheckFuncReturnUnicode(self):
    188         cur = self.con.cursor()
    189         cur.execute("select returnunicode()")
    190         val = cur.fetchone()[0]
    191         self.assertEqual(type(val), str)
    192         self.assertEqual(val, "bar")
    193 
    194     def CheckFuncReturnInt(self):
    195         cur = self.con.cursor()
    196         cur.execute("select returnint()")
    197         val = cur.fetchone()[0]
    198         self.assertEqual(type(val), int)
    199         self.assertEqual(val, 42)
    200 
    201     def CheckFuncReturnFloat(self):
    202         cur = self.con.cursor()
    203         cur.execute("select returnfloat()")
    204         val = cur.fetchone()[0]
    205         self.assertEqual(type(val), float)
    206         if val < 3.139 or val > 3.141:
    207             self.fail("wrong value")
    208 
    209     def CheckFuncReturnNull(self):
    210         cur = self.con.cursor()
    211         cur.execute("select returnnull()")
    212         val = cur.fetchone()[0]
    213         self.assertEqual(type(val), type(None))
    214         self.assertEqual(val, None)
    215 
    216     def CheckFuncReturnBlob(self):
    217         cur = self.con.cursor()
    218         cur.execute("select returnblob()")
    219         val = cur.fetchone()[0]
    220         self.assertEqual(type(val), bytes)
    221         self.assertEqual(val, b"blob")
    222 
    223     def CheckFuncReturnLongLong(self):
    224         cur = self.con.cursor()
    225         cur.execute("select returnlonglong()")
    226         val = cur.fetchone()[0]
    227         self.assertEqual(val, 1<<31)
    228 
    229     def CheckFuncException(self):
    230         cur = self.con.cursor()
    231         with self.assertRaises(sqlite.OperationalError) as cm:
    232             cur.execute("select raiseexception()")
    233             cur.fetchone()
    234         self.assertEqual(str(cm.exception), 'user-defined function raised exception')
    235 
    236     def CheckParamString(self):
    237         cur = self.con.cursor()
    238         cur.execute("select isstring(?)", ("foo",))
    239         val = cur.fetchone()[0]
    240         self.assertEqual(val, 1)
    241 
    242     def CheckParamInt(self):
    243         cur = self.con.cursor()
    244         cur.execute("select isint(?)", (42,))
    245         val = cur.fetchone()[0]
    246         self.assertEqual(val, 1)
    247 
    248     def CheckParamFloat(self):
    249         cur = self.con.cursor()
    250         cur.execute("select isfloat(?)", (3.14,))
    251         val = cur.fetchone()[0]
    252         self.assertEqual(val, 1)
    253 
    254     def CheckParamNone(self):
    255         cur = self.con.cursor()
    256         cur.execute("select isnone(?)", (None,))
    257         val = cur.fetchone()[0]
    258         self.assertEqual(val, 1)
    259 
    260     def CheckParamBlob(self):
    261         cur = self.con.cursor()
    262         cur.execute("select isblob(?)", (memoryview(b"blob"),))
    263         val = cur.fetchone()[0]
    264         self.assertEqual(val, 1)
    265 
    266     def CheckParamLongLong(self):
    267         cur = self.con.cursor()
    268         cur.execute("select islonglong(?)", (1<<42,))
    269         val = cur.fetchone()[0]
    270         self.assertEqual(val, 1)
    271 
    272     def CheckAnyArguments(self):
    273         cur = self.con.cursor()
    274         cur.execute("select spam(?, ?)", (1, 2))
    275         val = cur.fetchone()[0]
    276         self.assertEqual(val, 2)
    277 
    278 
    279 class AggregateTests(unittest.TestCase):
    280     def setUp(self):
    281         self.con = sqlite.connect(":memory:")
    282         cur = self.con.cursor()
    283         cur.execute("""
    284             create table test(
    285                 t text,
    286                 i integer,
    287                 f float,
    288                 n,
    289                 b blob
    290                 )
    291             """)
    292         cur.execute("insert into test(t, i, f, n, b) values (?, ?, ?, ?, ?)",
    293             ("foo", 5, 3.14, None, memoryview(b"blob"),))
    294 
    295         self.con.create_aggregate("nostep", 1, AggrNoStep)
    296         self.con.create_aggregate("nofinalize", 1, AggrNoFinalize)
    297         self.con.create_aggregate("excInit", 1, AggrExceptionInInit)
    298         self.con.create_aggregate("excStep", 1, AggrExceptionInStep)
    299         self.con.create_aggregate("excFinalize", 1, AggrExceptionInFinalize)
    300         self.con.create_aggregate("checkType", 2, AggrCheckType)
    301         self.con.create_aggregate("checkTypes", -1, AggrCheckTypes)
    302         self.con.create_aggregate("mysum", 1, AggrSum)
    303 
    304     def tearDown(self):
    305         #self.cur.close()
    306         #self.con.close()
    307         pass
    308 
    309     def CheckAggrErrorOnCreate(self):
    310         with self.assertRaises(sqlite.OperationalError):
    311             self.con.create_function("bla", -100, AggrSum)
    312 
    313     def CheckAggrNoStep(self):
    314         cur = self.con.cursor()
    315         with self.assertRaises(AttributeError) as cm:
    316             cur.execute("select nostep(t) from test")
    317         self.assertEqual(str(cm.exception), "'AggrNoStep' object has no attribute 'step'")
    318 
    319     def CheckAggrNoFinalize(self):
    320         cur = self.con.cursor()
    321         with self.assertRaises(sqlite.OperationalError) as cm:
    322             cur.execute("select nofinalize(t) from test")
    323             val = cur.fetchone()[0]
    324         self.assertEqual(str(cm.exception), "user-defined aggregate's 'finalize' method raised error")
    325 
    326     def CheckAggrExceptionInInit(self):
    327         cur = self.con.cursor()
    328         with self.assertRaises(sqlite.OperationalError) as cm:
    329             cur.execute("select excInit(t) from test")
    330             val = cur.fetchone()[0]
    331         self.assertEqual(str(cm.exception), "user-defined aggregate's '__init__' method raised error")
    332 
    333     def CheckAggrExceptionInStep(self):
    334         cur = self.con.cursor()
    335         with self.assertRaises(sqlite.OperationalError) as cm:
    336             cur.execute("select excStep(t) from test")
    337             val = cur.fetchone()[0]
    338         self.assertEqual(str(cm.exception), "user-defined aggregate's 'step' method raised error")
    339 
    340     def CheckAggrExceptionInFinalize(self):
    341         cur = self.con.cursor()
    342         with self.assertRaises(sqlite.OperationalError) as cm:
    343             cur.execute("select excFinalize(t) from test")
    344             val = cur.fetchone()[0]
    345         self.assertEqual(str(cm.exception), "user-defined aggregate's 'finalize' method raised error")
    346 
    347     def CheckAggrCheckParamStr(self):
    348         cur = self.con.cursor()
    349         cur.execute("select checkType('str', ?)", ("foo",))
    350         val = cur.fetchone()[0]
    351         self.assertEqual(val, 1)
    352 
    353     def CheckAggrCheckParamInt(self):
    354         cur = self.con.cursor()
    355         cur.execute("select checkType('int', ?)", (42,))
    356         val = cur.fetchone()[0]
    357         self.assertEqual(val, 1)
    358 
    359     def CheckAggrCheckParamsInt(self):
    360         cur = self.con.cursor()
    361         cur.execute("select checkTypes('int', ?, ?)", (42, 24))
    362         val = cur.fetchone()[0]
    363         self.assertEqual(val, 2)
    364 
    365     def CheckAggrCheckParamFloat(self):
    366         cur = self.con.cursor()
    367         cur.execute("select checkType('float', ?)", (3.14,))
    368         val = cur.fetchone()[0]
    369         self.assertEqual(val, 1)
    370 
    371     def CheckAggrCheckParamNone(self):
    372         cur = self.con.cursor()
    373         cur.execute("select checkType('None', ?)", (None,))
    374         val = cur.fetchone()[0]
    375         self.assertEqual(val, 1)
    376 
    377     def CheckAggrCheckParamBlob(self):
    378         cur = self.con.cursor()
    379         cur.execute("select checkType('blob', ?)", (memoryview(b"blob"),))
    380         val = cur.fetchone()[0]
    381         self.assertEqual(val, 1)
    382 
    383     def CheckAggrCheckAggrSum(self):
    384         cur = self.con.cursor()
    385         cur.execute("delete from test")
    386         cur.executemany("insert into test(i) values (?)", [(10,), (20,), (30,)])
    387         cur.execute("select mysum(i) from test")
    388         val = cur.fetchone()[0]
    389         self.assertEqual(val, 60)
    390 
    391 class AuthorizerTests(unittest.TestCase):
    392     @staticmethod
    393     def authorizer_cb(action, arg1, arg2, dbname, source):
    394         if action != sqlite.SQLITE_SELECT:
    395             return sqlite.SQLITE_DENY
    396         if arg2 == 'c2' or arg1 == 't2':
    397             return sqlite.SQLITE_DENY
    398         return sqlite.SQLITE_OK
    399 
    400     def setUp(self):
    401         self.con = sqlite.connect(":memory:")
    402         self.con.executescript("""
    403             create table t1 (c1, c2);
    404             create table t2 (c1, c2);
    405             insert into t1 (c1, c2) values (1, 2);
    406             insert into t2 (c1, c2) values (4, 5);
    407             """)
    408 
    409         # For our security test:
    410         self.con.execute("select c2 from t2")
    411 
    412         self.con.set_authorizer(self.authorizer_cb)
    413 
    414     def tearDown(self):
    415         pass
    416 
    417     def test_table_access(self):
    418         with self.assertRaises(sqlite.DatabaseError) as cm:
    419             self.con.execute("select * from t2")
    420         self.assertIn('prohibited', str(cm.exception))
    421 
    422     def test_column_access(self):
    423         with self.assertRaises(sqlite.DatabaseError) as cm:
    424             self.con.execute("select c2 from t1")
    425         self.assertIn('prohibited', str(cm.exception))
    426 
    427 class AuthorizerRaiseExceptionTests(AuthorizerTests):
    428     @staticmethod
    429     def authorizer_cb(action, arg1, arg2, dbname, source):
    430         if action != sqlite.SQLITE_SELECT:
    431             raise ValueError
    432         if arg2 == 'c2' or arg1 == 't2':
    433             raise ValueError
    434         return sqlite.SQLITE_OK
    435 
    436 class AuthorizerIllegalTypeTests(AuthorizerTests):
    437     @staticmethod
    438     def authorizer_cb(action, arg1, arg2, dbname, source):
    439         if action != sqlite.SQLITE_SELECT:
    440             return 0.0
    441         if arg2 == 'c2' or arg1 == 't2':
    442             return 0.0
    443         return sqlite.SQLITE_OK
    444 
    445 class AuthorizerLargeIntegerTests(AuthorizerTests):
    446     @staticmethod
    447     def authorizer_cb(action, arg1, arg2, dbname, source):
    448         if action != sqlite.SQLITE_SELECT:
    449             return 2**32
    450         if arg2 == 'c2' or arg1 == 't2':
    451             return 2**32
    452         return sqlite.SQLITE_OK
    453 
    454 
    455 def suite():
    456     function_suite = unittest.makeSuite(FunctionTests, "Check")
    457     aggregate_suite = unittest.makeSuite(AggregateTests, "Check")
    458     authorizer_suite = unittest.makeSuite(AuthorizerTests)
    459     return unittest.TestSuite((
    460             function_suite,
    461             aggregate_suite,
    462             authorizer_suite,
    463             unittest.makeSuite(AuthorizerRaiseExceptionTests),
    464             unittest.makeSuite(AuthorizerIllegalTypeTests),
    465             unittest.makeSuite(AuthorizerLargeIntegerTests),
    466         ))
    467 
    468 def test():
    469     runner = unittest.TextTestRunner()
    470     runner.run(suite())
    471 
    472 if __name__ == "__main__":
    473     test()
    474