#-*- coding: ISO-8859-1 -*-
# pysqlite2/test/userfunctions.py: tests for user-defined functions and
# aggregates.
#
# Copyright (C) 2005-2007 Gerhard Haering <gh@ghaering.de>
#
# This file is part of pysqlite.
#
# This software is provided 'as-is', without any express or implied
# warranty.  In no event will the authors be held liable for any damages
# arising from the use of this software.
#
# Permission is granted to anyone to use this software for any purpose,
# including commercial applications, and to alter it and redistribute it
# freely, subject to the following restrictions:
#
# 1. The origin of this software must not be misrepresented; you must not
#    claim that you wrote the original software. If you use this software
#    in a product, an acknowledgment in the product documentation would be
#    appreciated but is not required.
# 2. Altered source versions must be plainly marked as such, and must not be
#    misrepresented as being the original software.
# 3. This notice may not be removed or altered from any source distribution.

import unittest
import sqlite3 as sqlite


# Zero-argument functions registered by FunctionTests.setUp.  Each returns a
# value of one specific Python type so that its round trip through SQLite can
# be checked.
def func_returntext():
    return "foo"


def func_returnunicode():
    return u"bar"


def func_returnint():
    return 42


def func_returnfloat():
    return 3.14


def func_returnnull():
    return None


def func_returnblob():
    return buffer("blob")


def func_returnlonglong():
    # 2**31 is one more than the largest signed 32-bit integer.
    return 1 << 31


def func_raiseexception():
    # Deliberately raises ZeroDivisionError.
    5 // 0


# One-argument predicates registered by FunctionTests.setUp.  Each reports
# whether the argument reached Python as the expected type.
def func_isstring(v):
    return type(v) is unicode


def func_isint(v):
    return type(v) is int


def func_isfloat(v):
    return type(v) is float


def func_isnone(v):
    return v is None


def func_isblob(v):
    return type(v) is buffer


def func_islonglong(v):
    return isinstance(v, (int, long)) and v >= 1 << 31


# Aggregate classes registered by AggregateTests.setUp.  These are
# deliberately old-style classes: CheckAggrNoStep asserts on the exact
# AttributeError text produced for an old-style instance.
class AggrNoStep:
    """Aggregate that is missing a ``step`` method."""

    def __init__(self):
        pass

    def finalize(self):
        return 1


class AggrNoFinalize:
    """Aggregate that is missing a ``finalize`` method."""

    def __init__(self):
        pass

    def step(self, x):
        pass


class AggrExceptionInInit:
    """Aggregate whose constructor raises ZeroDivisionError."""

    def __init__(self):
        5 // 0

    def step(self, x):
        pass

    def finalize(self):
        pass


class AggrExceptionInStep:
    """Aggregate whose ``step`` raises ZeroDivisionError."""

    def __init__(self):
        pass

    def step(self, x):
        5 // 0

    def finalize(self):
        return 42


class AggrExceptionInFinalize:
    """Aggregate whose ``finalize`` raises ZeroDivisionError."""

    def __init__(self):
        pass

    def step(self, x):
        pass

    def finalize(self):
        5 // 0


class AggrCheckType:
    """Aggregate that checks the Python type of its second argument.

    ``finalize`` returns 1 if the type of the last ``val`` passed to ``step``
    is exactly the type named by ``whichType`` (one of "str", "int",
    "float", "None", "blob"), 0 if it is not, and None if ``step`` was never
    called.
    """

    def __init__(self):
        self.val = None

    def step(self, whichType, val):
        theType = {"str": unicode, "int": int, "float": float,
                   "None": type(None), "blob": buffer}
        self.val = int(theType[whichType] is type(val))

    def finalize(self):
        return self.val


class AggrSum:
    """Aggregate that sums its argument, starting from 0.0."""

    def __init__(self):
        self.val = 0.0

    def step(self, val):
        self.val += val

    def finalize(self):
        return self.val


class FunctionTests(unittest.TestCase):
    """Tests for scalar user-defined functions.

    Test methods use the ``Check`` prefix and are collected by ``suite()``.
    """

    def setUp(self):
        self.con = sqlite.connect(":memory:")

        self.con.create_function("returntext", 0, func_returntext)
        self.con.create_function("returnunicode", 0, func_returnunicode)
        self.con.create_function("returnint", 0, func_returnint)
        self.con.create_function("returnfloat", 0, func_returnfloat)
        self.con.create_function("returnnull", 0, func_returnnull)
        self.con.create_function("returnblob", 0, func_returnblob)
        self.con.create_function("returnlonglong", 0, func_returnlonglong)
        self.con.create_function("raiseexception", 0, func_raiseexception)

        self.con.create_function("isstring", 1, func_isstring)
        self.con.create_function("isint", 1, func_isint)
        self.con.create_function("isfloat", 1, func_isfloat)
        self.con.create_function("isnone", 1, func_isnone)
        self.con.create_function("isblob", 1, func_isblob)
        self.con.create_function("islonglong", 1, func_islonglong)

    def tearDown(self):
        self.con.close()

    def CheckFuncErrorOnCreate(self):
        # Registering with an argument count of -100 must raise
        # OperationalError.
        try:
            self.con.create_function("bla", -100, lambda x: 2*x)
            self.fail("should have raised an OperationalError")
        except sqlite.OperationalError:
            pass

    def CheckFuncRefCount(self):
        def getfunc():
            def f():
                return 1
            return f
        # Register a freshly created function without keeping any other
        # reference to it.  The call below only works if the connection
        # itself holds a reference to the callable.
        self.con.create_function("reftest", 0, getfunc())
        cur = self.con.cursor()
        cur.execute("select reftest()")
        self.assertEqual(cur.fetchone()[0], 1)

    def CheckFuncReturnText(self):
        cur = self.con.cursor()
        cur.execute("select returntext()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), unicode)
        self.assertEqual(val, "foo")

    def CheckFuncReturnUnicode(self):
        cur = self.con.cursor()
        cur.execute("select returnunicode()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), unicode)
        self.assertEqual(val, u"bar")

    def CheckFuncReturnInt(self):
        cur = self.con.cursor()
        cur.execute("select returnint()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), int)
        self.assertEqual(val, 42)

    def CheckFuncReturnFloat(self):
        cur = self.con.cursor()
        cur.execute("select returnfloat()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), float)
        if val < 3.139 or val > 3.141:
            self.fail("wrong value")

    def CheckFuncReturnNull(self):
        cur = self.con.cursor()
        cur.execute("select returnnull()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), type(None))
        self.assertEqual(val, None)

    def CheckFuncReturnBlob(self):
        cur = self.con.cursor()
        cur.execute("select returnblob()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), buffer)
        self.assertEqual(val, buffer("blob"))

    def CheckFuncReturnLongLong(self):
        cur = self.con.cursor()
        cur.execute("select returnlonglong()")
        val = cur.fetchone()[0]
        self.assertEqual(val, 1 << 31)

    def CheckFuncException(self):
        cur = self.con.cursor()
        try:
            cur.execute("select raiseexception()")
            cur.fetchone()
            self.fail("should have raised OperationalError")
        except sqlite.OperationalError as e:
            self.assertEqual(e.args[0], 'user-defined function raised exception')

    def CheckParamString(self):
        # A bound str parameter is expected to reach the function as unicode.
        cur = self.con.cursor()
        cur.execute("select isstring(?)", ("foo",))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckParamInt(self):
        cur = self.con.cursor()
        cur.execute("select isint(?)", (42,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckParamFloat(self):
        cur = self.con.cursor()
        cur.execute("select isfloat(?)", (3.14,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckParamNone(self):
        cur = self.con.cursor()
        cur.execute("select isnone(?)", (None,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckParamBlob(self):
        cur = self.con.cursor()
        cur.execute("select isblob(?)", (buffer("blob"),))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckParamLongLong(self):
        cur = self.con.cursor()
        cur.execute("select islonglong(?)", (1 << 42,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)


class AggregateTests(unittest.TestCase):
    """Tests for user-defined aggregates.

    Test methods use the ``Check`` prefix and are collected by ``suite()``.
    """

    def setUp(self):
        self.con = sqlite.connect(":memory:")
        cur = self.con.cursor()
        cur.execute("""
            create table test(
                t text,
                i integer,
                f float,
                n,
                b blob
                )
            """)
        cur.execute("insert into test(t, i, f, n, b) values (?, ?, ?, ?, ?)",
            ("foo", 5, 3.14, None, buffer("blob"),))

        self.con.create_aggregate("nostep", 1, AggrNoStep)
        self.con.create_aggregate("nofinalize", 1, AggrNoFinalize)
        self.con.create_aggregate("excInit", 1, AggrExceptionInInit)
        self.con.create_aggregate("excStep", 1, AggrExceptionInStep)
        self.con.create_aggregate("excFinalize", 1, AggrExceptionInFinalize)
        self.con.create_aggregate("checkType", 2, AggrCheckType)
        self.con.create_aggregate("mysum", 1, AggrSum)

    def tearDown(self):
        # Close the connection opened in setUp so no test leaks a database.
        self.con.close()

    def CheckAggrErrorOnCreate(self):
        # Registering an aggregate with an argument count of -100 must raise
        # OperationalError.
        try:
            self.con.create_aggregate("bla", -100, AggrSum)
            self.fail("should have raised an OperationalError")
        except sqlite.OperationalError:
            pass

    def CheckAggrNoStep(self):
        cur = self.con.cursor()
        try:
            cur.execute("select nostep(t) from test")
            self.fail("should have raised an AttributeError")
        except AttributeError as e:
            self.assertEqual(e.args[0], "AggrNoStep instance has no attribute 'step'")

    def CheckAggrNoFinalize(self):
        cur = self.con.cursor()
        try:
            cur.execute("select nofinalize(t) from test")
            cur.fetchone()
            self.fail("should have raised an OperationalError")
        except sqlite.OperationalError as e:
            self.assertEqual(e.args[0], "user-defined aggregate's 'finalize' method raised error")

    def CheckAggrExceptionInInit(self):
        cur = self.con.cursor()
        try:
            cur.execute("select excInit(t) from test")
            cur.fetchone()
            self.fail("should have raised an OperationalError")
        except sqlite.OperationalError as e:
            self.assertEqual(e.args[0], "user-defined aggregate's '__init__' method raised error")

    def CheckAggrExceptionInStep(self):
        cur = self.con.cursor()
        try:
            cur.execute("select excStep(t) from test")
            cur.fetchone()
            self.fail("should have raised an OperationalError")
        except sqlite.OperationalError as e:
            self.assertEqual(e.args[0], "user-defined aggregate's 'step' method raised error")

    def CheckAggrExceptionInFinalize(self):
        cur = self.con.cursor()
        try:
            cur.execute("select excFinalize(t) from test")
            cur.fetchone()
            self.fail("should have raised an OperationalError")
        except sqlite.OperationalError as e:
            self.assertEqual(e.args[0], "user-defined aggregate's 'finalize' method raised error")

    def CheckAggrCheckParamStr(self):
        cur = self.con.cursor()
        cur.execute("select checkType('str', ?)", ("foo",))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckAggrCheckParamInt(self):
        cur = self.con.cursor()
        cur.execute("select checkType('int', ?)", (42,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckAggrCheckParamFloat(self):
        cur = self.con.cursor()
        cur.execute("select checkType('float', ?)", (3.14,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckAggrCheckParamNone(self):
        cur = self.con.cursor()
        cur.execute("select checkType('None', ?)", (None,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckAggrCheckParamBlob(self):
        cur = self.con.cursor()
        cur.execute("select checkType('blob', ?)", (buffer("blob"),))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckAggrCheckAggrSum(self):
        cur = self.con.cursor()
        cur.execute("delete from test")
        cur.executemany("insert into test(i) values (?)", [(10,), (20,), (30,)])
        cur.execute("select mysum(i) from test")
        val = cur.fetchone()[0]
        self.assertEqual(val, 60)


class AuthorizerTests(unittest.TestCase):
    """Tests for ``Connection.set_authorizer``.

    Subclasses override ``authorizer_cb`` to deny access in different ways.
    The same two tests must still report a "prohibited" DatabaseError.
    """

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        # Deny every action except SQLITE_SELECT, and deny anything touching
        # column c2 or table t2.
        if action != sqlite.SQLITE_SELECT:
            return sqlite.SQLITE_DENY
        if arg2 == 'c2' or arg1 == 't2':
            return sqlite.SQLITE_DENY
        return sqlite.SQLITE_OK

    def setUp(self):
        self.con = sqlite.connect(":memory:")
        self.con.executescript("""
            create table t1 (c1, c2);
            create table t2 (c1, c2);
            insert into t1 (c1, c2) values (1, 2);
            insert into t2 (c1, c2) values (4, 5);
            """)

        # For our security test: this query runs (and may be cached) before
        # the authorizer is installed.
        self.con.execute("select c2 from t2")

        self.con.set_authorizer(self.authorizer_cb)

    def tearDown(self):
        # Close the connection opened in setUp.
        self.con.close()

    def test_table_access(self):
        try:
            self.con.execute("select * from t2")
        except sqlite.DatabaseError as e:
            if not e.args[0].endswith("prohibited"):
                self.fail("wrong exception text: %s" % e.args[0])
        else:
            self.fail("should have raised an exception due to missing privileges")

    def test_column_access(self):
        try:
            self.con.execute("select c2 from t1")
        except sqlite.DatabaseError as e:
            if not e.args[0].endswith("prohibited"):
                self.fail("wrong exception text: %s" % e.args[0])
        else:
            self.fail("should have raised an exception due to missing privileges")


class AuthorizerRaiseExceptionTests(AuthorizerTests):
    """Authorizer that raises ValueError instead of returning SQLITE_DENY."""

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        if action != sqlite.SQLITE_SELECT:
            raise ValueError
        if arg2 == 'c2' or arg1 == 't2':
            raise ValueError
        return sqlite.SQLITE_OK


class AuthorizerIllegalTypeTests(AuthorizerTests):
    """Authorizer that returns a float instead of SQLITE_DENY."""

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        if action != sqlite.SQLITE_SELECT:
            return 0.0
        if arg2 == 'c2' or arg1 == 't2':
            return 0.0
        return sqlite.SQLITE_OK


class AuthorizerLargeIntegerTests(AuthorizerTests):
    """Authorizer that returns 2**32 instead of SQLITE_DENY."""

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        if action != sqlite.SQLITE_SELECT:
            return 2**32
        if arg2 == 'c2' or arg1 == 't2':
            return 2**32
        return sqlite.SQLITE_OK


def _make_suite(test_case_class, prefix="test"):
    """Load the methods of *test_case_class* whose names start with *prefix*.

    This is equivalent to the deprecated ``unittest.makeSuite`` helper.
    """
    loader = unittest.TestLoader()
    loader.testMethodPrefix = prefix
    return loader.loadTestsFromTestCase(test_case_class)


def suite():
    """Return a TestSuite with every test in this module."""
    function_suite = _make_suite(FunctionTests, "Check")
    aggregate_suite = _make_suite(AggregateTests, "Check")
    authorizer_suite = _make_suite(AuthorizerTests)
    return unittest.TestSuite((
        function_suite,
        aggregate_suite,
        authorizer_suite,
        _make_suite(AuthorizerRaiseExceptionTests),
        _make_suite(AuthorizerIllegalTypeTests),
        _make_suite(AuthorizerLargeIntegerTests),
    ))


def test():
    """Run ``suite()`` with a text runner."""
    runner = unittest.TextTestRunner()
    runner.run(suite())


if __name__ == "__main__":
    test()