Home | History | Annotate | Download | only in test
      1 #-*- coding: ISO-8859-1 -*-
      2 # pysqlite2/test/userfunctions.py: tests for user-defined functions and
      3 #                                  aggregates.
      4 #
      5 # Copyright (C) 2005-2007 Gerhard Häring <gh (at] ghaering.de>
      6 #
      7 # This file is part of pysqlite.
      8 #
      9 # This software is provided 'as-is', without any express or implied
     10 # warranty.  In no event will the authors be held liable for any damages
     11 # arising from the use of this software.
     12 #
     13 # Permission is granted to anyone to use this software for any purpose,
     14 # including commercial applications, and to alter it and redistribute it
     15 # freely, subject to the following restrictions:
     16 #
     17 # 1. The origin of this software must not be misrepresented; you must not
     18 #    claim that you wrote the original software. If you use this software
     19 #    in a product, an acknowledgment in the product documentation would be
     20 #    appreciated but is not required.
     21 # 2. Altered source versions must be plainly marked as such, and must not be
     22 #    misrepresented as being the original software.
     23 # 3. This notice may not be removed or altered from any source distribution.
     24 
     25 import unittest
     26 import sqlite3 as sqlite
     27 from test import test_support
     28 
     29 def func_returntext():
     30     return "foo"
     31 def func_returnunicode():
     32     return u"bar"
     33 def func_returnint():
     34     return 42
     35 def func_returnfloat():
     36     return 3.14
     37 def func_returnnull():
     38     return None
     39 def func_returnblob():
     40     with test_support.check_py3k_warnings():
     41         return buffer("blob")
     42 def func_returnlonglong():
     43     return 1<<31
     44 def func_raiseexception():
     45     5 // 0
     46 
     47 def func_isstring(v):
     48     return type(v) is unicode
     49 def func_isint(v):
     50     return type(v) is int
     51 def func_isfloat(v):
     52     return type(v) is float
     53 def func_isnone(v):
     54     return type(v) is type(None)
     55 def func_isblob(v):
     56     return type(v) is buffer
     57 def func_islonglong(v):
     58     return isinstance(v, (int, long)) and v >= 1<<31
     59 
     60 class AggrNoStep:
     61     def __init__(self):
     62         pass
     63 
     64     def finalize(self):
     65         return 1
     66 
     67 class AggrNoFinalize:
     68     def __init__(self):
     69         pass
     70 
     71     def step(self, x):
     72         pass
     73 
     74 class AggrExceptionInInit:
     75     def __init__(self):
     76         5 // 0
     77 
     78     def step(self, x):
     79         pass
     80 
     81     def finalize(self):
     82         pass
     83 
     84 class AggrExceptionInStep:
     85     def __init__(self):
     86         pass
     87 
     88     def step(self, x):
     89         5 // 0
     90 
     91     def finalize(self):
     92         return 42
     93 
     94 class AggrExceptionInFinalize:
     95     def __init__(self):
     96         pass
     97 
     98     def step(self, x):
     99         pass
    100 
    101     def finalize(self):
    102         5 // 0
    103 
    104 class AggrCheckType:
    105     def __init__(self):
    106         self.val = None
    107 
    108     def step(self, whichType, val):
    109         theType = {"str": unicode, "int": int, "float": float, "None": type(None), "blob": buffer}
    110         self.val = int(theType[whichType] is type(val))
    111 
    112     def finalize(self):
    113         return self.val
    114 
    115 class AggrSum:
    116     def __init__(self):
    117         self.val = 0.0
    118 
    119     def step(self, val):
    120         self.val += val
    121 
    122     def finalize(self):
    123         return self.val
    124 
    125 class FunctionTests(unittest.TestCase):
    126     def setUp(self):
    127         self.con = sqlite.connect(":memory:")
    128 
    129         self.con.create_function("returntext", 0, func_returntext)
    130         self.con.create_function("returnunicode", 0, func_returnunicode)
    131         self.con.create_function("returnint", 0, func_returnint)
    132         self.con.create_function("returnfloat", 0, func_returnfloat)
    133         self.con.create_function("returnnull", 0, func_returnnull)
    134         self.con.create_function("returnblob", 0, func_returnblob)
    135         self.con.create_function("returnlonglong", 0, func_returnlonglong)
    136         self.con.create_function("raiseexception", 0, func_raiseexception)
    137 
    138         self.con.create_function("isstring", 1, func_isstring)
    139         self.con.create_function("isint", 1, func_isint)
    140         self.con.create_function("isfloat", 1, func_isfloat)
    141         self.con.create_function("isnone", 1, func_isnone)
    142         self.con.create_function("isblob", 1, func_isblob)
    143         self.con.create_function("islonglong", 1, func_islonglong)
    144 
    145     def tearDown(self):
    146         self.con.close()
    147 
    148     def CheckFuncErrorOnCreate(self):
    149         try:
    150             self.con.create_function("bla", -100, lambda x: 2*x)
    151             self.fail("should have raised an OperationalError")
    152         except sqlite.OperationalError:
    153             pass
    154 
    155     def CheckFuncRefCount(self):
    156         def getfunc():
    157             def f():
    158                 return 1
    159             return f
    160         f = getfunc()
    161         globals()["foo"] = f
    162         # self.con.create_function("reftest", 0, getfunc())
    163         self.con.create_function("reftest", 0, f)
    164         cur = self.con.cursor()
    165         cur.execute("select reftest()")
    166 
    167     def CheckFuncReturnText(self):
    168         cur = self.con.cursor()
    169         cur.execute("select returntext()")
    170         val = cur.fetchone()[0]
    171         self.assertEqual(type(val), unicode)
    172         self.assertEqual(val, "foo")
    173 
    174     def CheckFuncReturnUnicode(self):
    175         cur = self.con.cursor()
    176         cur.execute("select returnunicode()")
    177         val = cur.fetchone()[0]
    178         self.assertEqual(type(val), unicode)
    179         self.assertEqual(val, u"bar")
    180 
    181     def CheckFuncReturnInt(self):
    182         cur = self.con.cursor()
    183         cur.execute("select returnint()")
    184         val = cur.fetchone()[0]
    185         self.assertEqual(type(val), int)
    186         self.assertEqual(val, 42)
    187 
    188     def CheckFuncReturnFloat(self):
    189         cur = self.con.cursor()
    190         cur.execute("select returnfloat()")
    191         val = cur.fetchone()[0]
    192         self.assertEqual(type(val), float)
    193         if val < 3.139 or val > 3.141:
    194             self.fail("wrong value")
    195 
    196     def CheckFuncReturnNull(self):
    197         cur = self.con.cursor()
    198         cur.execute("select returnnull()")
    199         val = cur.fetchone()[0]
    200         self.assertEqual(type(val), type(None))
    201         self.assertEqual(val, None)
    202 
    203     def CheckFuncReturnBlob(self):
    204         cur = self.con.cursor()
    205         cur.execute("select returnblob()")
    206         val = cur.fetchone()[0]
    207         with test_support.check_py3k_warnings():
    208             self.assertEqual(type(val), buffer)
    209             self.assertEqual(val, buffer("blob"))
    210 
    211     def CheckFuncReturnLongLong(self):
    212         cur = self.con.cursor()
    213         cur.execute("select returnlonglong()")
    214         val = cur.fetchone()[0]
    215         self.assertEqual(val, 1<<31)
    216 
    217     def CheckFuncException(self):
    218         cur = self.con.cursor()
    219         try:
    220             cur.execute("select raiseexception()")
    221             cur.fetchone()
    222             self.fail("should have raised OperationalError")
    223         except sqlite.OperationalError, e:
    224             self.assertEqual(e.args[0], 'user-defined function raised exception')
    225 
    226     def CheckParamString(self):
    227         cur = self.con.cursor()
    228         cur.execute("select isstring(?)", ("foo",))
    229         val = cur.fetchone()[0]
    230         self.assertEqual(val, 1)
    231 
    232     def CheckParamInt(self):
    233         cur = self.con.cursor()
    234         cur.execute("select isint(?)", (42,))
    235         val = cur.fetchone()[0]
    236         self.assertEqual(val, 1)
    237 
    238     def CheckParamFloat(self):
    239         cur = self.con.cursor()
    240         cur.execute("select isfloat(?)", (3.14,))
    241         val = cur.fetchone()[0]
    242         self.assertEqual(val, 1)
    243 
    244     def CheckParamNone(self):
    245         cur = self.con.cursor()
    246         cur.execute("select isnone(?)", (None,))
    247         val = cur.fetchone()[0]
    248         self.assertEqual(val, 1)
    249 
    250     def CheckParamBlob(self):
    251         cur = self.con.cursor()
    252         with test_support.check_py3k_warnings():
    253             cur.execute("select isblob(?)", (buffer("blob"),))
    254         val = cur.fetchone()[0]
    255         self.assertEqual(val, 1)
    256 
    257     def CheckParamLongLong(self):
    258         cur = self.con.cursor()
    259         cur.execute("select islonglong(?)", (1<<42,))
    260         val = cur.fetchone()[0]
    261         self.assertEqual(val, 1)
    262 
    263 class AggregateTests(unittest.TestCase):
    264     def setUp(self):
    265         self.con = sqlite.connect(":memory:")
    266         cur = self.con.cursor()
    267         cur.execute("""
    268             create table test(
    269                 t text,
    270                 i integer,
    271                 f float,
    272                 n,
    273                 b blob
    274                 )
    275             """)
    276         with test_support.check_py3k_warnings():
    277             cur.execute("insert into test(t, i, f, n, b) values (?, ?, ?, ?, ?)",
    278                 ("foo", 5, 3.14, None, buffer("blob"),))
    279 
    280         self.con.create_aggregate("nostep", 1, AggrNoStep)
    281         self.con.create_aggregate("nofinalize", 1, AggrNoFinalize)
    282         self.con.create_aggregate("excInit", 1, AggrExceptionInInit)
    283         self.con.create_aggregate("excStep", 1, AggrExceptionInStep)
    284         self.con.create_aggregate("excFinalize", 1, AggrExceptionInFinalize)
    285         self.con.create_aggregate("checkType", 2, AggrCheckType)
    286         self.con.create_aggregate("mysum", 1, AggrSum)
    287 
    288     def tearDown(self):
    289         #self.cur.close()
    290         #self.con.close()
    291         pass
    292 
    293     def CheckAggrErrorOnCreate(self):
    294         try:
    295             self.con.create_function("bla", -100, AggrSum)
    296             self.fail("should have raised an OperationalError")
    297         except sqlite.OperationalError:
    298             pass
    299 
    300     def CheckAggrNoStep(self):
    301         cur = self.con.cursor()
    302         try:
    303             cur.execute("select nostep(t) from test")
    304             self.fail("should have raised an AttributeError")
    305         except AttributeError, e:
    306             self.assertEqual(e.args[0], "AggrNoStep instance has no attribute 'step'")
    307 
    308     def CheckAggrNoFinalize(self):
    309         cur = self.con.cursor()
    310         try:
    311             cur.execute("select nofinalize(t) from test")
    312             val = cur.fetchone()[0]
    313             self.fail("should have raised an OperationalError")
    314         except sqlite.OperationalError, e:
    315             self.assertEqual(e.args[0], "user-defined aggregate's 'finalize' method raised error")
    316 
    317     def CheckAggrExceptionInInit(self):
    318         cur = self.con.cursor()
    319         try:
    320             cur.execute("select excInit(t) from test")
    321             val = cur.fetchone()[0]
    322             self.fail("should have raised an OperationalError")
    323         except sqlite.OperationalError, e:
    324             self.assertEqual(e.args[0], "user-defined aggregate's '__init__' method raised error")
    325 
    326     def CheckAggrExceptionInStep(self):
    327         cur = self.con.cursor()
    328         try:
    329             cur.execute("select excStep(t) from test")
    330             val = cur.fetchone()[0]
    331             self.fail("should have raised an OperationalError")
    332         except sqlite.OperationalError, e:
    333             self.assertEqual(e.args[0], "user-defined aggregate's 'step' method raised error")
    334 
    335     def CheckAggrExceptionInFinalize(self):
    336         cur = self.con.cursor()
    337         try:
    338             cur.execute("select excFinalize(t) from test")
    339             val = cur.fetchone()[0]
    340             self.fail("should have raised an OperationalError")
    341         except sqlite.OperationalError, e:
    342             self.assertEqual(e.args[0], "user-defined aggregate's 'finalize' method raised error")
    343 
    344     def CheckAggrCheckParamStr(self):
    345         cur = self.con.cursor()
    346         cur.execute("select checkType('str', ?)", ("foo",))
    347         val = cur.fetchone()[0]
    348         self.assertEqual(val, 1)
    349 
    350     def CheckAggrCheckParamInt(self):
    351         cur = self.con.cursor()
    352         cur.execute("select checkType('int', ?)", (42,))
    353         val = cur.fetchone()[0]
    354         self.assertEqual(val, 1)
    355 
    356     def CheckAggrCheckParamFloat(self):
    357         cur = self.con.cursor()
    358         cur.execute("select checkType('float', ?)", (3.14,))
    359         val = cur.fetchone()[0]
    360         self.assertEqual(val, 1)
    361 
    362     def CheckAggrCheckParamNone(self):
    363         cur = self.con.cursor()
    364         cur.execute("select checkType('None', ?)", (None,))
    365         val = cur.fetchone()[0]
    366         self.assertEqual(val, 1)
    367 
    368     def CheckAggrCheckParamBlob(self):
    369         cur = self.con.cursor()
    370         with test_support.check_py3k_warnings():
    371             cur.execute("select checkType('blob', ?)", (buffer("blob"),))
    372         val = cur.fetchone()[0]
    373         self.assertEqual(val, 1)
    374 
    375     def CheckAggrCheckAggrSum(self):
    376         cur = self.con.cursor()
    377         cur.execute("delete from test")
    378         cur.executemany("insert into test(i) values (?)", [(10,), (20,), (30,)])
    379         cur.execute("select mysum(i) from test")
    380         val = cur.fetchone()[0]
    381         self.assertEqual(val, 60)
    382 
    383 class AuthorizerTests(unittest.TestCase):
    384     @staticmethod
    385     def authorizer_cb(action, arg1, arg2, dbname, source):
    386         if action != sqlite.SQLITE_SELECT:
    387             return sqlite.SQLITE_DENY
    388         if arg2 == 'c2' or arg1 == 't2':
    389             return sqlite.SQLITE_DENY
    390         return sqlite.SQLITE_OK
    391 
    392     def setUp(self):
    393         self.con = sqlite.connect(":memory:")
    394         self.con.executescript("""
    395             create table t1 (c1, c2);
    396             create table t2 (c1, c2);
    397             insert into t1 (c1, c2) values (1, 2);
    398             insert into t2 (c1, c2) values (4, 5);
    399             """)
    400 
    401         # For our security test:
    402         self.con.execute("select c2 from t2")
    403 
    404         self.con.set_authorizer(self.authorizer_cb)
    405 
    406     def tearDown(self):
    407         pass
    408 
    409     def test_table_access(self):
    410         try:
    411             self.con.execute("select * from t2")
    412         except sqlite.DatabaseError, e:
    413             if not e.args[0].endswith("prohibited"):
    414                 self.fail("wrong exception text: %s" % e.args[0])
    415             return
    416         self.fail("should have raised an exception due to missing privileges")
    417 
    418     def test_column_access(self):
    419         try:
    420             self.con.execute("select c2 from t1")
    421         except sqlite.DatabaseError, e:
    422             if not e.args[0].endswith("prohibited"):
    423                 self.fail("wrong exception text: %s" % e.args[0])
    424             return
    425         self.fail("should have raised an exception due to missing privileges")
    426 
    427 class AuthorizerRaiseExceptionTests(AuthorizerTests):
    428     @staticmethod
    429     def authorizer_cb(action, arg1, arg2, dbname, source):
    430         if action != sqlite.SQLITE_SELECT:
    431             raise ValueError
    432         if arg2 == 'c2' or arg1 == 't2':
    433             raise ValueError
    434         return sqlite.SQLITE_OK
    435 
    436 class AuthorizerIllegalTypeTests(AuthorizerTests):
    437     @staticmethod
    438     def authorizer_cb(action, arg1, arg2, dbname, source):
    439         if action != sqlite.SQLITE_SELECT:
    440             return 0.0
    441         if arg2 == 'c2' or arg1 == 't2':
    442             return 0.0
    443         return sqlite.SQLITE_OK
    444 
    445 class AuthorizerLargeIntegerTests(AuthorizerTests):
    446     @staticmethod
    447     def authorizer_cb(action, arg1, arg2, dbname, source):
    448         if action != sqlite.SQLITE_SELECT:
    449             return 2**32
    450         if arg2 == 'c2' or arg1 == 't2':
    451             return 2**32
    452         return sqlite.SQLITE_OK
    453 
    454 
    455 def suite():
    456     function_suite = unittest.makeSuite(FunctionTests, "Check")
    457     aggregate_suite = unittest.makeSuite(AggregateTests, "Check")
    458     authorizer_suite = unittest.makeSuite(AuthorizerTests)
    459     return unittest.TestSuite((
    460             function_suite,
    461             aggregate_suite,
    462             authorizer_suite,
    463             unittest.makeSuite(AuthorizerRaiseExceptionTests),
    464             unittest.makeSuite(AuthorizerIllegalTypeTests),
    465             unittest.makeSuite(AuthorizerLargeIntegerTests),
    466         ))
    467 
    468 def test():
    469     runner = unittest.TextTestRunner()
    470     runner.run(suite())
    471 
    472 if __name__ == "__main__":
    473     test()
    474