1 #-*- coding: iso-8859-1 -*- 2 # pysqlite2/test/hooks.py: tests for various SQLite-specific hooks 3 # 4 # Copyright (C) 2006-2007 Gerhard Häring <gh (at] ghaering.de> 5 # 6 # This file is part of pysqlite. 7 # 8 # This software is provided 'as-is', without any express or implied 9 # warranty. In no event will the authors be held liable for any damages 10 # arising from the use of this software. 11 # 12 # Permission is granted to anyone to use this software for any purpose, 13 # including commercial applications, and to alter it and redistribute it 14 # freely, subject to the following restrictions: 15 # 16 # 1. The origin of this software must not be misrepresented; you must not 17 # claim that you wrote the original software. If you use this software 18 # in a product, an acknowledgment in the product documentation would be 19 # appreciated but is not required. 20 # 2. Altered source versions must be plainly marked as such, and must not be 21 # misrepresented as being the original software. 22 # 3. This notice may not be removed or altered from any source distribution. 
import unittest
import sqlite3 as sqlite

try:
    # Newer CPython layouts keep the file helpers in test.support.os_helper.
    from test.support.os_helper import TESTFN, unlink
except ImportError:
    # Older layouts export them from test.support directly.
    from test.support import TESTFN, unlink


class CollationTests(unittest.TestCase):
    """Tests for Connection.create_collation()."""

    def CheckCreateCollationNotString(self):
        """A collation name that is not a string raises TypeError."""
        con = sqlite.connect(":memory:")
        with self.assertRaises(TypeError):
            con.create_collation(None, lambda x, y: (x > y) - (x < y))

    def CheckCreateCollationNotCallable(self):
        """A non-callable collation function raises TypeError."""
        con = sqlite.connect(":memory:")
        with self.assertRaises(TypeError) as cm:
            con.create_collation("X", 42)
        self.assertEqual(str(cm.exception), 'parameter must be callable')

    def CheckCreateCollationNotAscii(self):
        """A collation name containing a non-ASCII character is expected
        to raise ProgrammingError.
        """
        con = sqlite.connect(":memory:")
        with self.assertRaises(sqlite.ProgrammingError):
            # "\xe4" is U+00E4 (a with diaeresis), written as an escape so the
            # name survives re-encoding of this source file.  A plain-ASCII
            # name such as "coll" registers fine and would never raise.
            con.create_collation("coll\xe4", lambda x, y: (x > y) - (x < y))

    def CheckCreateCollationBadUpper(self):
        """Registering under a str subclass whose upper() returns None
        still works, and the collation is applied to ORDER BY.
        """
        class BadUpperStr(str):
            def upper(self):
                return None
        con = sqlite.connect(":memory:")
        mycoll = lambda x, y: -((x > y) - (x < y))
        con.create_collation(BadUpperStr("mycoll"), mycoll)
        result = con.execute("""
            select x from (
            select 'a' as x
            union
            select 'b' as x
            ) order by x collate mycoll
            """).fetchall()
        # mycoll reverses the natural order.
        self.assertEqual(result[0][0], 'b')
        self.assertEqual(result[1][0], 'a')

    @unittest.skipIf(sqlite.sqlite_version_info < (3, 2, 1),
                     'old SQLite versions crash on this test')
    def CheckCollationIsUsed(self):
        """A registered collation controls ordering; unregistering it (by
        passing None) makes later uses fail with OperationalError.
        """
        def mycoll(x, y):
            # reverse order
            return -((x > y) - (x < y))

        con = sqlite.connect(":memory:")
        con.create_collation("mycoll", mycoll)
        sql = """
            select x from (
            select 'a' as x
            union
            select 'b' as x
            union
            select 'c' as x
            ) order by x collate mycoll
            """
        result = con.execute(sql).fetchall()
        self.assertEqual(result, [('c',), ('b',), ('a',)],
                         msg='the expected order was not returned')

        con.create_collation("mycoll", None)
        with self.assertRaises(sqlite.OperationalError) as cm:
            con.execute(sql).fetchall()
        self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll')

    def CheckCollationReturnsLargeInteger(self):
        """Collation results that do not fit in 32 bits still produce the
        expected order.
        """
        def mycoll(x, y):
            # reverse order
            return -((x > y) - (x < y)) * 2**32
        con = sqlite.connect(":memory:")
        con.create_collation("mycoll", mycoll)
        sql = """
            select x from (
            select 'a' as x
            union
            select 'b' as x
            union
            select 'c' as x
            ) order by x collate mycoll
            """
        result = con.execute(sql).fetchall()
        self.assertEqual(result, [('c',), ('b',), ('a',)],
                         msg="the expected order was not returned")

    def CheckCollationRegisterTwice(self):
        """
        Register two different collation functions under the same name.
        Verify that the last one is actually used.
        """
        con = sqlite.connect(":memory:")
        con.create_collation("mycoll", lambda x, y: (x > y) - (x < y))
        con.create_collation("mycoll", lambda x, y: -((x > y) - (x < y)))
        result = con.execute("""
            select x from (select 'a' as x union select 'b' as x) order by x collate mycoll
            """).fetchall()
        self.assertEqual(result[0][0], 'b')
        self.assertEqual(result[1][0], 'a')

    def CheckDeregisterCollation(self):
        """
        Register a collation, then deregister it. Make sure an error is raised if we try
        to use it.
        """
        con = sqlite.connect(":memory:")
        con.create_collation("mycoll", lambda x, y: (x > y) - (x < y))
        con.create_collation("mycoll", None)
        with self.assertRaises(sqlite.OperationalError) as cm:
            con.execute("select 'a' as x union select 'b' as x order by x collate mycoll")
        self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll')


class ProgressTests(unittest.TestCase):
    """Tests for Connection.set_progress_handler()."""

    def CheckProgressHandlerUsed(self):
        """
        Test that the progress handler is invoked once it is set.
        """
        con = sqlite.connect(":memory:")
        progress_calls = []
        def progress():
            progress_calls.append(None)
            return 0
        con.set_progress_handler(progress, 1)
        con.execute("""
            create table foo(a, b)
            """)
        self.assertTrue(progress_calls)

    def CheckOpcodeCount(self):
        """
        Test that the opcode argument is respected.
        """
        con = sqlite.connect(":memory:")
        progress_calls = []
        def progress():
            progress_calls.append(None)
            return 0
        con.set_progress_handler(progress, 1)
        curs = con.cursor()
        curs.execute("""
            create table foo (a, b)
            """)
        first_count = len(progress_calls)
        # Reset the same list the handler appends to.
        progress_calls.clear()
        con.set_progress_handler(progress, 2)
        curs.execute("""
            create table bar (a, b)
            """)
        second_count = len(progress_calls)
        # A larger opcode interval means the handler fires no more often.
        self.assertGreaterEqual(first_count, second_count)

    def CheckCancelOperation(self):
        """
        Test that returning a non-zero value stops the operation in progress.
        """
        con = sqlite.connect(":memory:")
        def progress():
            return 1
        con.set_progress_handler(progress, 1)
        curs = con.cursor()
        self.assertRaises(
            sqlite.OperationalError,
            curs.execute,
            "create table bar (a, b)")

    def CheckClearHandler(self):
        """
        Test that setting the progress handler to None clears the previously set handler.
        """
        con = sqlite.connect(":memory:")
        action = 0
        def progress():
            nonlocal action
            action = 1
            return 0
        con.set_progress_handler(progress, 1)
        con.set_progress_handler(None, 1)
        con.execute("select 1 union select 2 union select 3").fetchall()
        self.assertEqual(action, 0, "progress handler was not cleared")


class TraceCallbackTests(unittest.TestCase):
    """Tests for Connection.set_trace_callback()."""

    def CheckTraceCallbackUsed(self):
        """
        Test that the trace callback is invoked once it is set.
        """
        con = sqlite.connect(":memory:")
        traced_statements = []
        def trace(statement):
            traced_statements.append(statement)
        con.set_trace_callback(trace)
        con.execute("create table foo(a, b)")
        self.assertTrue(traced_statements)
        self.assertTrue(any("create table foo" in stmt for stmt in traced_statements))

    def CheckClearTraceCallback(self):
        """
        Test that setting the trace callback to None clears the previously set callback.
        """
        con = sqlite.connect(":memory:")
        traced_statements = []
        def trace(statement):
            traced_statements.append(statement)
        con.set_trace_callback(trace)
        con.set_trace_callback(None)
        con.execute("create table foo(a, b)")
        self.assertFalse(traced_statements, "trace callback was not cleared")

    def CheckUnicodeContent(self):
        """
        Test that the statement can contain unicode literals.
        """
        unicode_value = '\xf6\xe4\xfc\xd6\xc4\xdc\xdf\u20ac'
        con = sqlite.connect(":memory:")
        traced_statements = []
        def trace(statement):
            traced_statements.append(statement)
        con.set_trace_callback(trace)
        con.execute("create table foo(x)")
        # Can't execute bound parameters as their values don't appear
        # in traced statements before SQLite 3.6.21
        # (cf. http://www.sqlite.org/draft/releaselog/3_6_21.html)
        # Use a single-quoted SQL string literal: a double-quoted token is an
        # identifier in standard SQL, and SQLite builds compiled without
        # double-quoted-string-literal support reject it.
        con.execute("insert into foo(x) values ('%s')" % unicode_value)
        con.commit()
        self.assertTrue(any(unicode_value in stmt for stmt in traced_statements),
                        "Unicode data %s garbled in trace callback: %s"
                        % (ascii(unicode_value), ', '.join(map(ascii, traced_statements))))

    @unittest.skipIf(sqlite.sqlite_version_info < (3, 3, 9), "sqlite3_prepare_v2 is not available")
    def CheckTraceCallbackContent(self):
        # set_trace_callback() shouldn't produce duplicate content (bpo-26187)
        traced_statements = []
        def trace(statement):
            traced_statements.append(statement)

        queries = ["create table foo(x)",
                   "insert into foo(x) values(1)"]
        # Cleanups run last-in, first-out, so both connections are closed
        # before the database file is unlinked; deleting a file that is still
        # open fails on some platforms.
        self.addCleanup(unlink, TESTFN)
        con1 = sqlite.connect(TESTFN, isolation_level=None)
        self.addCleanup(con1.close)
        con2 = sqlite.connect(TESTFN)
        self.addCleanup(con2.close)
        con1.set_trace_callback(trace)
        cur = con1.cursor()
        cur.execute(queries[0])
        con2.execute("create table bar(x)")
        cur.execute(queries[1])
        self.assertEqual(traced_statements, queries)


def suite():
    """Return a TestSuite containing every Check* test in this module."""
    loader = unittest.TestLoader()
    # The tests use a "Check" prefix instead of unittest's default "test".
    # This is equivalent to the deprecated unittest.makeSuite(cls, "Check").
    loader.testMethodPrefix = "Check"
    return unittest.TestSuite(
        loader.loadTestsFromTestCase(cls)
        for cls in (CollationTests, ProgressTests, TraceCallbackTests)
    )


def test():
    """Run suite() with a text runner that reports to stderr."""
    runner = unittest.TextTestRunner()
    runner.run(suite())


if __name__ == "__main__":
    test()