Home | History | Annotate | Download | only in test
      1 #-*- coding: ISO-8859-1 -*-
      2 # pysqlite2/test/userfunctions.py: tests for user-defined functions and
      3 #                                  aggregates.
      4 #
# Copyright (C) 2005-2007 Gerhard Haering <gh@ghaering.de>
      6 #
      7 # This file is part of pysqlite.
      8 #
      9 # This software is provided 'as-is', without any express or implied
     10 # warranty.  In no event will the authors be held liable for any damages
     11 # arising from the use of this software.
     12 #
     13 # Permission is granted to anyone to use this software for any purpose,
     14 # including commercial applications, and to alter it and redistribute it
     15 # freely, subject to the following restrictions:
     16 #
     17 # 1. The origin of this software must not be misrepresented; you must not
     18 #    claim that you wrote the original software. If you use this software
     19 #    in a product, an acknowledgment in the product documentation would be
     20 #    appreciated but is not required.
     21 # 2. Altered source versions must be plainly marked as such, and must not be
     22 #    misrepresented as being the original software.
     23 # 3. This notice may not be removed or altered from any source distribution.
     24 
     25 import unittest
     26 import sqlite3 as sqlite
     27 
     28 def func_returntext():
     29     return "foo"
     30 def func_returnunicode():
     31     return u"bar"
     32 def func_returnint():
     33     return 42
     34 def func_returnfloat():
     35     return 3.14
     36 def func_returnnull():
     37     return None
     38 def func_returnblob():
     39     return buffer("blob")
     40 def func_returnlonglong():
     41     return 1<<31
     42 def func_raiseexception():
     43     5 // 0
     44 
     45 def func_isstring(v):
     46     return type(v) is unicode
     47 def func_isint(v):
     48     return type(v) is int
     49 def func_isfloat(v):
     50     return type(v) is float
     51 def func_isnone(v):
     52     return type(v) is type(None)
     53 def func_isblob(v):
     54     return type(v) is buffer
     55 def func_islonglong(v):
     56     return isinstance(v, (int, long)) and v >= 1<<31
     57 
     58 class AggrNoStep:
     59     def __init__(self):
     60         pass
     61 
     62     def finalize(self):
     63         return 1
     64 
     65 class AggrNoFinalize:
     66     def __init__(self):
     67         pass
     68 
     69     def step(self, x):
     70         pass
     71 
     72 class AggrExceptionInInit:
     73     def __init__(self):
     74         5 // 0
     75 
     76     def step(self, x):
     77         pass
     78 
     79     def finalize(self):
     80         pass
     81 
     82 class AggrExceptionInStep:
     83     def __init__(self):
     84         pass
     85 
     86     def step(self, x):
     87         5 // 0
     88 
     89     def finalize(self):
     90         return 42
     91 
     92 class AggrExceptionInFinalize:
     93     def __init__(self):
     94         pass
     95 
     96     def step(self, x):
     97         pass
     98 
     99     def finalize(self):
    100         5 // 0
    101 
    102 class AggrCheckType:
    103     def __init__(self):
    104         self.val = None
    105 
    106     def step(self, whichType, val):
    107         theType = {"str": unicode, "int": int, "float": float, "None": type(None), "blob": buffer}
    108         self.val = int(theType[whichType] is type(val))
    109 
    110     def finalize(self):
    111         return self.val
    112 
    113 class AggrSum:
    114     def __init__(self):
    115         self.val = 0.0
    116 
    117     def step(self, val):
    118         self.val += val
    119 
    120     def finalize(self):
    121         return self.val
    122 
    123 class FunctionTests(unittest.TestCase):
    124     def setUp(self):
    125         self.con = sqlite.connect(":memory:")
    126 
    127         self.con.create_function("returntext", 0, func_returntext)
    128         self.con.create_function("returnunicode", 0, func_returnunicode)
    129         self.con.create_function("returnint", 0, func_returnint)
    130         self.con.create_function("returnfloat", 0, func_returnfloat)
    131         self.con.create_function("returnnull", 0, func_returnnull)
    132         self.con.create_function("returnblob", 0, func_returnblob)
    133         self.con.create_function("returnlonglong", 0, func_returnlonglong)
    134         self.con.create_function("raiseexception", 0, func_raiseexception)
    135 
    136         self.con.create_function("isstring", 1, func_isstring)
    137         self.con.create_function("isint", 1, func_isint)
    138         self.con.create_function("isfloat", 1, func_isfloat)
    139         self.con.create_function("isnone", 1, func_isnone)
    140         self.con.create_function("isblob", 1, func_isblob)
    141         self.con.create_function("islonglong", 1, func_islonglong)
    142 
    143     def tearDown(self):
    144         self.con.close()
    145 
    146     def CheckFuncErrorOnCreate(self):
    147         try:
    148             self.con.create_function("bla", -100, lambda x: 2*x)
    149             self.fail("should have raised an OperationalError")
    150         except sqlite.OperationalError:
    151             pass
    152 
    153     def CheckFuncRefCount(self):
    154         def getfunc():
    155             def f():
    156                 return 1
    157             return f
    158         f = getfunc()
    159         globals()["foo"] = f
    160         # self.con.create_function("reftest", 0, getfunc())
    161         self.con.create_function("reftest", 0, f)
    162         cur = self.con.cursor()
    163         cur.execute("select reftest()")
    164 
    165     def CheckFuncReturnText(self):
    166         cur = self.con.cursor()
    167         cur.execute("select returntext()")
    168         val = cur.fetchone()[0]
    169         self.assertEqual(type(val), unicode)
    170         self.assertEqual(val, "foo")
    171 
    172     def CheckFuncReturnUnicode(self):
    173         cur = self.con.cursor()
    174         cur.execute("select returnunicode()")
    175         val = cur.fetchone()[0]
    176         self.assertEqual(type(val), unicode)
    177         self.assertEqual(val, u"bar")
    178 
    179     def CheckFuncReturnInt(self):
    180         cur = self.con.cursor()
    181         cur.execute("select returnint()")
    182         val = cur.fetchone()[0]
    183         self.assertEqual(type(val), int)
    184         self.assertEqual(val, 42)
    185 
    186     def CheckFuncReturnFloat(self):
    187         cur = self.con.cursor()
    188         cur.execute("select returnfloat()")
    189         val = cur.fetchone()[0]
    190         self.assertEqual(type(val), float)
    191         if val < 3.139 or val > 3.141:
    192             self.fail("wrong value")
    193 
    194     def CheckFuncReturnNull(self):
    195         cur = self.con.cursor()
    196         cur.execute("select returnnull()")
    197         val = cur.fetchone()[0]
    198         self.assertEqual(type(val), type(None))
    199         self.assertEqual(val, None)
    200 
    201     def CheckFuncReturnBlob(self):
    202         cur = self.con.cursor()
    203         cur.execute("select returnblob()")
    204         val = cur.fetchone()[0]
    205         self.assertEqual(type(val), buffer)
    206         self.assertEqual(val, buffer("blob"))
    207 
    208     def CheckFuncReturnLongLong(self):
    209         cur = self.con.cursor()
    210         cur.execute("select returnlonglong()")
    211         val = cur.fetchone()[0]
    212         self.assertEqual(val, 1<<31)
    213 
    214     def CheckFuncException(self):
    215         cur = self.con.cursor()
    216         try:
    217             cur.execute("select raiseexception()")
    218             cur.fetchone()
    219             self.fail("should have raised OperationalError")
    220         except sqlite.OperationalError, e:
    221             self.assertEqual(e.args[0], 'user-defined function raised exception')
    222 
    223     def CheckParamString(self):
    224         cur = self.con.cursor()
    225         cur.execute("select isstring(?)", ("foo",))
    226         val = cur.fetchone()[0]
    227         self.assertEqual(val, 1)
    228 
    229     def CheckParamInt(self):
    230         cur = self.con.cursor()
    231         cur.execute("select isint(?)", (42,))
    232         val = cur.fetchone()[0]
    233         self.assertEqual(val, 1)
    234 
    235     def CheckParamFloat(self):
    236         cur = self.con.cursor()
    237         cur.execute("select isfloat(?)", (3.14,))
    238         val = cur.fetchone()[0]
    239         self.assertEqual(val, 1)
    240 
    241     def CheckParamNone(self):
    242         cur = self.con.cursor()
    243         cur.execute("select isnone(?)", (None,))
    244         val = cur.fetchone()[0]
    245         self.assertEqual(val, 1)
    246 
    247     def CheckParamBlob(self):
    248         cur = self.con.cursor()
    249         cur.execute("select isblob(?)", (buffer("blob"),))
    250         val = cur.fetchone()[0]
    251         self.assertEqual(val, 1)
    252 
    253     def CheckParamLongLong(self):
    254         cur = self.con.cursor()
    255         cur.execute("select islonglong(?)", (1<<42,))
    256         val = cur.fetchone()[0]
    257         self.assertEqual(val, 1)
    258 
    259 class AggregateTests(unittest.TestCase):
    260     def setUp(self):
    261         self.con = sqlite.connect(":memory:")
    262         cur = self.con.cursor()
    263         cur.execute("""
    264             create table test(
    265                 t text,
    266                 i integer,
    267                 f float,
    268                 n,
    269                 b blob
    270                 )
    271             """)
    272         cur.execute("insert into test(t, i, f, n, b) values (?, ?, ?, ?, ?)",
    273             ("foo", 5, 3.14, None, buffer("blob"),))
    274 
    275         self.con.create_aggregate("nostep", 1, AggrNoStep)
    276         self.con.create_aggregate("nofinalize", 1, AggrNoFinalize)
    277         self.con.create_aggregate("excInit", 1, AggrExceptionInInit)
    278         self.con.create_aggregate("excStep", 1, AggrExceptionInStep)
    279         self.con.create_aggregate("excFinalize", 1, AggrExceptionInFinalize)
    280         self.con.create_aggregate("checkType", 2, AggrCheckType)
    281         self.con.create_aggregate("mysum", 1, AggrSum)
    282 
    283     def tearDown(self):
    284         #self.cur.close()
    285         #self.con.close()
    286         pass
    287 
    288     def CheckAggrErrorOnCreate(self):
    289         try:
    290             self.con.create_function("bla", -100, AggrSum)
    291             self.fail("should have raised an OperationalError")
    292         except sqlite.OperationalError:
    293             pass
    294 
    295     def CheckAggrNoStep(self):
    296         cur = self.con.cursor()
    297         try:
    298             cur.execute("select nostep(t) from test")
    299             self.fail("should have raised an AttributeError")
    300         except AttributeError, e:
    301             self.assertEqual(e.args[0], "AggrNoStep instance has no attribute 'step'")
    302 
    303     def CheckAggrNoFinalize(self):
    304         cur = self.con.cursor()
    305         try:
    306             cur.execute("select nofinalize(t) from test")
    307             val = cur.fetchone()[0]
    308             self.fail("should have raised an OperationalError")
    309         except sqlite.OperationalError, e:
    310             self.assertEqual(e.args[0], "user-defined aggregate's 'finalize' method raised error")
    311 
    312     def CheckAggrExceptionInInit(self):
    313         cur = self.con.cursor()
    314         try:
    315             cur.execute("select excInit(t) from test")
    316             val = cur.fetchone()[0]
    317             self.fail("should have raised an OperationalError")
    318         except sqlite.OperationalError, e:
    319             self.assertEqual(e.args[0], "user-defined aggregate's '__init__' method raised error")
    320 
    321     def CheckAggrExceptionInStep(self):
    322         cur = self.con.cursor()
    323         try:
    324             cur.execute("select excStep(t) from test")
    325             val = cur.fetchone()[0]
    326             self.fail("should have raised an OperationalError")
    327         except sqlite.OperationalError, e:
    328             self.assertEqual(e.args[0], "user-defined aggregate's 'step' method raised error")
    329 
    330     def CheckAggrExceptionInFinalize(self):
    331         cur = self.con.cursor()
    332         try:
    333             cur.execute("select excFinalize(t) from test")
    334             val = cur.fetchone()[0]
    335             self.fail("should have raised an OperationalError")
    336         except sqlite.OperationalError, e:
    337             self.assertEqual(e.args[0], "user-defined aggregate's 'finalize' method raised error")
    338 
    339     def CheckAggrCheckParamStr(self):
    340         cur = self.con.cursor()
    341         cur.execute("select checkType('str', ?)", ("foo",))
    342         val = cur.fetchone()[0]
    343         self.assertEqual(val, 1)
    344 
    345     def CheckAggrCheckParamInt(self):
    346         cur = self.con.cursor()
    347         cur.execute("select checkType('int', ?)", (42,))
    348         val = cur.fetchone()[0]
    349         self.assertEqual(val, 1)
    350 
    351     def CheckAggrCheckParamFloat(self):
    352         cur = self.con.cursor()
    353         cur.execute("select checkType('float', ?)", (3.14,))
    354         val = cur.fetchone()[0]
    355         self.assertEqual(val, 1)
    356 
    357     def CheckAggrCheckParamNone(self):
    358         cur = self.con.cursor()
    359         cur.execute("select checkType('None', ?)", (None,))
    360         val = cur.fetchone()[0]
    361         self.assertEqual(val, 1)
    362 
    363     def CheckAggrCheckParamBlob(self):
    364         cur = self.con.cursor()
    365         cur.execute("select checkType('blob', ?)", (buffer("blob"),))
    366         val = cur.fetchone()[0]
    367         self.assertEqual(val, 1)
    368 
    369     def CheckAggrCheckAggrSum(self):
    370         cur = self.con.cursor()
    371         cur.execute("delete from test")
    372         cur.executemany("insert into test(i) values (?)", [(10,), (20,), (30,)])
    373         cur.execute("select mysum(i) from test")
    374         val = cur.fetchone()[0]
    375         self.assertEqual(val, 60)
    376 
    377 class AuthorizerTests(unittest.TestCase):
    378     @staticmethod
    379     def authorizer_cb(action, arg1, arg2, dbname, source):
    380         if action != sqlite.SQLITE_SELECT:
    381             return sqlite.SQLITE_DENY
    382         if arg2 == 'c2' or arg1 == 't2':
    383             return sqlite.SQLITE_DENY
    384         return sqlite.SQLITE_OK
    385 
    386     def setUp(self):
    387         self.con = sqlite.connect(":memory:")
    388         self.con.executescript("""
    389             create table t1 (c1, c2);
    390             create table t2 (c1, c2);
    391             insert into t1 (c1, c2) values (1, 2);
    392             insert into t2 (c1, c2) values (4, 5);
    393             """)
    394 
    395         # For our security test:
    396         self.con.execute("select c2 from t2")
    397 
    398         self.con.set_authorizer(self.authorizer_cb)
    399 
    400     def tearDown(self):
    401         pass
    402 
    403     def test_table_access(self):
    404         try:
    405             self.con.execute("select * from t2")
    406         except sqlite.DatabaseError, e:
    407             if not e.args[0].endswith("prohibited"):
    408                 self.fail("wrong exception text: %s" % e.args[0])
    409             return
    410         self.fail("should have raised an exception due to missing privileges")
    411 
    412     def test_column_access(self):
    413         try:
    414             self.con.execute("select c2 from t1")
    415         except sqlite.DatabaseError, e:
    416             if not e.args[0].endswith("prohibited"):
    417                 self.fail("wrong exception text: %s" % e.args[0])
    418             return
    419         self.fail("should have raised an exception due to missing privileges")
    420 
    421 class AuthorizerRaiseExceptionTests(AuthorizerTests):
    422     @staticmethod
    423     def authorizer_cb(action, arg1, arg2, dbname, source):
    424         if action != sqlite.SQLITE_SELECT:
    425             raise ValueError
    426         if arg2 == 'c2' or arg1 == 't2':
    427             raise ValueError
    428         return sqlite.SQLITE_OK
    429 
    430 class AuthorizerIllegalTypeTests(AuthorizerTests):
    431     @staticmethod
    432     def authorizer_cb(action, arg1, arg2, dbname, source):
    433         if action != sqlite.SQLITE_SELECT:
    434             return 0.0
    435         if arg2 == 'c2' or arg1 == 't2':
    436             return 0.0
    437         return sqlite.SQLITE_OK
    438 
    439 class AuthorizerLargeIntegerTests(AuthorizerTests):
    440     @staticmethod
    441     def authorizer_cb(action, arg1, arg2, dbname, source):
    442         if action != sqlite.SQLITE_SELECT:
    443             return 2**32
    444         if arg2 == 'c2' or arg1 == 't2':
    445             return 2**32
    446         return sqlite.SQLITE_OK
    447 
    448 
    449 def suite():
    450     function_suite = unittest.makeSuite(FunctionTests, "Check")
    451     aggregate_suite = unittest.makeSuite(AggregateTests, "Check")
    452     authorizer_suite = unittest.makeSuite(AuthorizerTests)
    453     return unittest.TestSuite((
    454             function_suite,
    455             aggregate_suite,
    456             authorizer_suite,
    457             unittest.makeSuite(AuthorizerRaiseExceptionTests),
    458             unittest.makeSuite(AuthorizerIllegalTypeTests),
    459             unittest.makeSuite(AuthorizerLargeIntegerTests),
    460         ))
    461 
    462 def test():
    463     runner = unittest.TextTestRunner()
    464     runner.run(suite())
    465 
    466 if __name__ == "__main__":
    467     test()
    468