Home | History | Annotate | Download | only in test
      1 #-*- coding: iso-8859-1 -*-
      2 # pysqlite2/test/hooks.py: tests for various SQLite-specific hooks
      3 #
      4 # Copyright (C) 2006-2007 Gerhard Häring <gh@ghaering.de>
      5 #
      6 # This file is part of pysqlite.
      7 #
      8 # This software is provided 'as-is', without any express or implied
      9 # warranty.  In no event will the authors be held liable for any damages
     10 # arising from the use of this software.
     11 #
     12 # Permission is granted to anyone to use this software for any purpose,
     13 # including commercial applications, and to alter it and redistribute it
     14 # freely, subject to the following restrictions:
     15 #
     16 # 1. The origin of this software must not be misrepresented; you must not
     17 #    claim that you wrote the original software. If you use this software
     18 #    in a product, an acknowledgment in the product documentation would be
     19 #    appreciated but is not required.
     20 # 2. Altered source versions must be plainly marked as such, and must not be
     21 #    misrepresented as being the original software.
     22 # 3. This notice may not be removed or altered from any source distribution.
     23 
     24 import unittest
     25 import sqlite3 as sqlite
     26 
     27 from test.support import TESTFN, unlink
     28 
     29 class CollationTests(unittest.TestCase):
     30     def CheckCreateCollationNotString(self):
     31         con = sqlite.connect(":memory:")
     32         with self.assertRaises(TypeError):
     33             con.create_collation(None, lambda x, y: (x > y) - (x < y))
     34 
     35     def CheckCreateCollationNotCallable(self):
     36         con = sqlite.connect(":memory:")
     37         with self.assertRaises(TypeError) as cm:
     38             con.create_collation("X", 42)
     39         self.assertEqual(str(cm.exception), 'parameter must be callable')
     40 
     41     def CheckCreateCollationNotAscii(self):
     42         con = sqlite.connect(":memory:")
     43         with self.assertRaises(sqlite.ProgrammingError):
     44             con.create_collation("coll", lambda x, y: (x > y) - (x < y))
     45 
     46     def CheckCreateCollationBadUpper(self):
     47         class BadUpperStr(str):
     48             def upper(self):
     49                 return None
     50         con = sqlite.connect(":memory:")
     51         mycoll = lambda x, y: -((x > y) - (x < y))
     52         con.create_collation(BadUpperStr("mycoll"), mycoll)
     53         result = con.execute("""
     54             select x from (
     55             select 'a' as x
     56             union
     57             select 'b' as x
     58             ) order by x collate mycoll
     59             """).fetchall()
     60         self.assertEqual(result[0][0], 'b')
     61         self.assertEqual(result[1][0], 'a')
     62 
     63     @unittest.skipIf(sqlite.sqlite_version_info < (3, 2, 1),
     64                      'old SQLite versions crash on this test')
     65     def CheckCollationIsUsed(self):
     66         def mycoll(x, y):
     67             # reverse order
     68             return -((x > y) - (x < y))
     69 
     70         con = sqlite.connect(":memory:")
     71         con.create_collation("mycoll", mycoll)
     72         sql = """
     73             select x from (
     74             select 'a' as x
     75             union
     76             select 'b' as x
     77             union
     78             select 'c' as x
     79             ) order by x collate mycoll
     80             """
     81         result = con.execute(sql).fetchall()
     82         self.assertEqual(result, [('c',), ('b',), ('a',)],
     83                          msg='the expected order was not returned')
     84 
     85         con.create_collation("mycoll", None)
     86         with self.assertRaises(sqlite.OperationalError) as cm:
     87             result = con.execute(sql).fetchall()
     88         self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll')
     89 
     90     def CheckCollationReturnsLargeInteger(self):
     91         def mycoll(x, y):
     92             # reverse order
     93             return -((x > y) - (x < y)) * 2**32
     94         con = sqlite.connect(":memory:")
     95         con.create_collation("mycoll", mycoll)
     96         sql = """
     97             select x from (
     98             select 'a' as x
     99             union
    100             select 'b' as x
    101             union
    102             select 'c' as x
    103             ) order by x collate mycoll
    104             """
    105         result = con.execute(sql).fetchall()
    106         self.assertEqual(result, [('c',), ('b',), ('a',)],
    107                          msg="the expected order was not returned")
    108 
    109     def CheckCollationRegisterTwice(self):
    110         """
    111         Register two different collation functions under the same name.
    112         Verify that the last one is actually used.
    113         """
    114         con = sqlite.connect(":memory:")
    115         con.create_collation("mycoll", lambda x, y: (x > y) - (x < y))
    116         con.create_collation("mycoll", lambda x, y: -((x > y) - (x < y)))
    117         result = con.execute("""
    118             select x from (select 'a' as x union select 'b' as x) order by x collate mycoll
    119             """).fetchall()
    120         self.assertEqual(result[0][0], 'b')
    121         self.assertEqual(result[1][0], 'a')
    122 
    123     def CheckDeregisterCollation(self):
    124         """
    125         Register a collation, then deregister it. Make sure an error is raised if we try
    126         to use it.
    127         """
    128         con = sqlite.connect(":memory:")
    129         con.create_collation("mycoll", lambda x, y: (x > y) - (x < y))
    130         con.create_collation("mycoll", None)
    131         with self.assertRaises(sqlite.OperationalError) as cm:
    132             con.execute("select 'a' as x union select 'b' as x order by x collate mycoll")
    133         self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll')
    134 
    135 class ProgressTests(unittest.TestCase):
    136     def CheckProgressHandlerUsed(self):
    137         """
    138         Test that the progress handler is invoked once it is set.
    139         """
    140         con = sqlite.connect(":memory:")
    141         progress_calls = []
    142         def progress():
    143             progress_calls.append(None)
    144             return 0
    145         con.set_progress_handler(progress, 1)
    146         con.execute("""
    147             create table foo(a, b)
    148             """)
    149         self.assertTrue(progress_calls)
    150 
    151 
    152     def CheckOpcodeCount(self):
    153         """
    154         Test that the opcode argument is respected.
    155         """
    156         con = sqlite.connect(":memory:")
    157         progress_calls = []
    158         def progress():
    159             progress_calls.append(None)
    160             return 0
    161         con.set_progress_handler(progress, 1)
    162         curs = con.cursor()
    163         curs.execute("""
    164             create table foo (a, b)
    165             """)
    166         first_count = len(progress_calls)
    167         progress_calls = []
    168         con.set_progress_handler(progress, 2)
    169         curs.execute("""
    170             create table bar (a, b)
    171             """)
    172         second_count = len(progress_calls)
    173         self.assertGreaterEqual(first_count, second_count)
    174 
    175     def CheckCancelOperation(self):
    176         """
    177         Test that returning a non-zero value stops the operation in progress.
    178         """
    179         con = sqlite.connect(":memory:")
    180         def progress():
    181             return 1
    182         con.set_progress_handler(progress, 1)
    183         curs = con.cursor()
    184         self.assertRaises(
    185             sqlite.OperationalError,
    186             curs.execute,
    187             "create table bar (a, b)")
    188 
    189     def CheckClearHandler(self):
    190         """
    191         Test that setting the progress handler to None clears the previously set handler.
    192         """
    193         con = sqlite.connect(":memory:")
    194         action = 0
    195         def progress():
    196             nonlocal action
    197             action = 1
    198             return 0
    199         con.set_progress_handler(progress, 1)
    200         con.set_progress_handler(None, 1)
    201         con.execute("select 1 union select 2 union select 3").fetchall()
    202         self.assertEqual(action, 0, "progress handler was not cleared")
    203 
    204 class TraceCallbackTests(unittest.TestCase):
    205     def CheckTraceCallbackUsed(self):
    206         """
    207         Test that the trace callback is invoked once it is set.
    208         """
    209         con = sqlite.connect(":memory:")
    210         traced_statements = []
    211         def trace(statement):
    212             traced_statements.append(statement)
    213         con.set_trace_callback(trace)
    214         con.execute("create table foo(a, b)")
    215         self.assertTrue(traced_statements)
    216         self.assertTrue(any("create table foo" in stmt for stmt in traced_statements))
    217 
    218     def CheckClearTraceCallback(self):
    219         """
    220         Test that setting the trace callback to None clears the previously set callback.
    221         """
    222         con = sqlite.connect(":memory:")
    223         traced_statements = []
    224         def trace(statement):
    225             traced_statements.append(statement)
    226         con.set_trace_callback(trace)
    227         con.set_trace_callback(None)
    228         con.execute("create table foo(a, b)")
    229         self.assertFalse(traced_statements, "trace callback was not cleared")
    230 
    231     def CheckUnicodeContent(self):
    232         """
    233         Test that the statement can contain unicode literals.
    234         """
    235         unicode_value = '\xf6\xe4\xfc\xd6\xc4\xdc\xdf\u20ac'
    236         con = sqlite.connect(":memory:")
    237         traced_statements = []
    238         def trace(statement):
    239             traced_statements.append(statement)
    240         con.set_trace_callback(trace)
    241         con.execute("create table foo(x)")
    242         # Can't execute bound parameters as their values don't appear
    243         # in traced statements before SQLite 3.6.21
    244         # (cf. http://www.sqlite.org/draft/releaselog/3_6_21.html)
    245         con.execute('insert into foo(x) values ("%s")' % unicode_value)
    246         con.commit()
    247         self.assertTrue(any(unicode_value in stmt for stmt in traced_statements),
    248                         "Unicode data %s garbled in trace callback: %s"
    249                         % (ascii(unicode_value), ', '.join(map(ascii, traced_statements))))
    250 
    251     @unittest.skipIf(sqlite.sqlite_version_info < (3, 3, 9), "sqlite3_prepare_v2 is not available")
    252     def CheckTraceCallbackContent(self):
    253         # set_trace_callback() shouldn't produce duplicate content (bpo-26187)
    254         traced_statements = []
    255         def trace(statement):
    256             traced_statements.append(statement)
    257 
    258         queries = ["create table foo(x)",
    259                    "insert into foo(x) values(1)"]
    260         self.addCleanup(unlink, TESTFN)
    261         con1 = sqlite.connect(TESTFN, isolation_level=None)
    262         con2 = sqlite.connect(TESTFN)
    263         con1.set_trace_callback(trace)
    264         cur = con1.cursor()
    265         cur.execute(queries[0])
    266         con2.execute("create table bar(x)")
    267         cur.execute(queries[1])
    268         self.assertEqual(traced_statements, queries)
    269 
    270 
    271 def suite():
    272     collation_suite = unittest.makeSuite(CollationTests, "Check")
    273     progress_suite = unittest.makeSuite(ProgressTests, "Check")
    274     trace_suite = unittest.makeSuite(TraceCallbackTests, "Check")
    275     return unittest.TestSuite((collation_suite, progress_suite, trace_suite))
    276 
    277 def test():
    278     runner = unittest.TextTestRunner()
    279     runner.run(suite())
    280 
if __name__ == "__main__":
    # Running this module directly executes its suite via test().
    test()
    283