Home | History | Annotate | Download | only in test
      1 #-*- coding: iso-8859-1 -*-
      2 # pysqlite2/test/factory.py: tests for the various factories in pysqlite
      3 #
      4 # Copyright (C) 2005-2007 Gerhard Häring <gh@ghaering.de>
      5 #
      6 # This file is part of pysqlite.
      7 #
      8 # This software is provided 'as-is', without any express or implied
      9 # warranty.  In no event will the authors be held liable for any damages
     10 # arising from the use of this software.
     11 #
     12 # Permission is granted to anyone to use this software for any purpose,
     13 # including commercial applications, and to alter it and redistribute it
     14 # freely, subject to the following restrictions:
     15 #
     16 # 1. The origin of this software must not be misrepresented; you must not
     17 #    claim that you wrote the original software. If you use this software
     18 #    in a product, an acknowledgment in the product documentation would be
     19 #    appreciated but is not required.
     20 # 2. Altered source versions must be plainly marked as such, and must not be
     21 #    misrepresented as being the original software.
     22 # 3. This notice may not be removed or altered from any source distribution.
     23 
     24 import unittest
     25 import sqlite3 as sqlite
     26 from collections.abc import Sequence
     27 
     28 class MyConnection(sqlite.Connection):
     29     def __init__(self, *args, **kwargs):
     30         sqlite.Connection.__init__(self, *args, **kwargs)
     31 
     32 def dict_factory(cursor, row):
     33     d = {}
     34     for idx, col in enumerate(cursor.description):
     35         d[col[0]] = row[idx]
     36     return d
     37 
     38 class MyCursor(sqlite.Cursor):
     39     def __init__(self, *args, **kwargs):
     40         sqlite.Cursor.__init__(self, *args, **kwargs)
     41         self.row_factory = dict_factory
     42 
     43 class ConnectionFactoryTests(unittest.TestCase):
     44     def setUp(self):
     45         self.con = sqlite.connect(":memory:", factory=MyConnection)
     46 
     47     def tearDown(self):
     48         self.con.close()
     49 
     50     def CheckIsInstance(self):
     51         self.assertIsInstance(self.con, MyConnection)
     52 
     53 class CursorFactoryTests(unittest.TestCase):
     54     def setUp(self):
     55         self.con = sqlite.connect(":memory:")
     56 
     57     def tearDown(self):
     58         self.con.close()
     59 
     60     def CheckIsInstance(self):
     61         cur = self.con.cursor()
     62         self.assertIsInstance(cur, sqlite.Cursor)
     63         cur = self.con.cursor(MyCursor)
     64         self.assertIsInstance(cur, MyCursor)
     65         cur = self.con.cursor(factory=lambda con: MyCursor(con))
     66         self.assertIsInstance(cur, MyCursor)
     67 
     68     def CheckInvalidFactory(self):
     69         # not a callable at all
     70         self.assertRaises(TypeError, self.con.cursor, None)
     71         # invalid callable with not exact one argument
     72         self.assertRaises(TypeError, self.con.cursor, lambda: None)
     73         # invalid callable returning non-cursor
     74         self.assertRaises(TypeError, self.con.cursor, lambda con: None)
     75 
     76 class RowFactoryTestsBackwardsCompat(unittest.TestCase):
     77     def setUp(self):
     78         self.con = sqlite.connect(":memory:")
     79 
     80     def CheckIsProducedByFactory(self):
     81         cur = self.con.cursor(factory=MyCursor)
     82         cur.execute("select 4+5 as foo")
     83         row = cur.fetchone()
     84         self.assertIsInstance(row, dict)
     85         cur.close()
     86 
     87     def tearDown(self):
     88         self.con.close()
     89 
     90 class RowFactoryTests(unittest.TestCase):
     91     def setUp(self):
     92         self.con = sqlite.connect(":memory:")
     93 
     94     def CheckCustomFactory(self):
     95         self.con.row_factory = lambda cur, row: list(row)
     96         row = self.con.execute("select 1, 2").fetchone()
     97         self.assertIsInstance(row, list)
     98 
     99     def CheckSqliteRowIndex(self):
    100         self.con.row_factory = sqlite.Row
    101         row = self.con.execute("select 1 as a, 2 as b").fetchone()
    102         self.assertIsInstance(row, sqlite.Row)
    103 
    104         col1, col2 = row["a"], row["b"]
    105         self.assertEqual(col1, 1, "by name: wrong result for column 'a'")
    106         self.assertEqual(col2, 2, "by name: wrong result for column 'a'")
    107 
    108         col1, col2 = row["A"], row["B"]
    109         self.assertEqual(col1, 1, "by name: wrong result for column 'A'")
    110         self.assertEqual(col2, 2, "by name: wrong result for column 'B'")
    111 
    112         self.assertEqual(row[0], 1, "by index: wrong result for column 0")
    113         self.assertEqual(row[1], 2, "by index: wrong result for column 1")
    114         self.assertEqual(row[-1], 2, "by index: wrong result for column -1")
    115         self.assertEqual(row[-2], 1, "by index: wrong result for column -2")
    116 
    117         with self.assertRaises(IndexError):
    118             row['c']
    119         with self.assertRaises(IndexError):
    120             row[2]
    121         with self.assertRaises(IndexError):
    122             row[-3]
    123         with self.assertRaises(IndexError):
    124             row[2**1000]
    125 
    126     def CheckSqliteRowSlice(self):
    127         # A sqlite.Row can be sliced like a list.
    128         self.con.row_factory = sqlite.Row
    129         row = self.con.execute("select 1, 2, 3, 4").fetchone()
    130         self.assertEqual(row[0:0], ())
    131         self.assertEqual(row[0:1], (1,))
    132         self.assertEqual(row[1:3], (2, 3))
    133         self.assertEqual(row[3:1], ())
    134         # Explicit bounds are optional.
    135         self.assertEqual(row[1:], (2, 3, 4))
    136         self.assertEqual(row[:3], (1, 2, 3))
    137         # Slices can use negative indices.
    138         self.assertEqual(row[-2:-1], (3,))
    139         self.assertEqual(row[-2:], (3, 4))
    140         # Slicing supports steps.
    141         self.assertEqual(row[0:4:2], (1, 3))
    142         self.assertEqual(row[3:0:-2], (4, 2))
    143 
    144     def CheckSqliteRowIter(self):
    145         """Checks if the row object is iterable"""
    146         self.con.row_factory = sqlite.Row
    147         row = self.con.execute("select 1 as a, 2 as b").fetchone()
    148         for col in row:
    149             pass
    150 
    151     def CheckSqliteRowAsTuple(self):
    152         """Checks if the row object can be converted to a tuple"""
    153         self.con.row_factory = sqlite.Row
    154         row = self.con.execute("select 1 as a, 2 as b").fetchone()
    155         t = tuple(row)
    156         self.assertEqual(t, (row['a'], row['b']))
    157 
    158     def CheckSqliteRowAsDict(self):
    159         """Checks if the row object can be correctly converted to a dictionary"""
    160         self.con.row_factory = sqlite.Row
    161         row = self.con.execute("select 1 as a, 2 as b").fetchone()
    162         d = dict(row)
    163         self.assertEqual(d["a"], row["a"])
    164         self.assertEqual(d["b"], row["b"])
    165 
    166     def CheckSqliteRowHashCmp(self):
    167         """Checks if the row object compares and hashes correctly"""
    168         self.con.row_factory = sqlite.Row
    169         row_1 = self.con.execute("select 1 as a, 2 as b").fetchone()
    170         row_2 = self.con.execute("select 1 as a, 2 as b").fetchone()
    171         row_3 = self.con.execute("select 1 as a, 3 as b").fetchone()
    172 
    173         self.assertEqual(row_1, row_1)
    174         self.assertEqual(row_1, row_2)
    175         self.assertTrue(row_2 != row_3)
    176 
    177         self.assertFalse(row_1 != row_1)
    178         self.assertFalse(row_1 != row_2)
    179         self.assertFalse(row_2 == row_3)
    180 
    181         self.assertEqual(row_1, row_2)
    182         self.assertEqual(hash(row_1), hash(row_2))
    183         self.assertNotEqual(row_1, row_3)
    184         self.assertNotEqual(hash(row_1), hash(row_3))
    185 
    186     def CheckSqliteRowAsSequence(self):
    187         """ Checks if the row object can act like a sequence """
    188         self.con.row_factory = sqlite.Row
    189         row = self.con.execute("select 1 as a, 2 as b").fetchone()
    190 
    191         as_tuple = tuple(row)
    192         self.assertEqual(list(reversed(row)), list(reversed(as_tuple)))
    193         self.assertIsInstance(row, Sequence)
    194 
    195     def CheckFakeCursorClass(self):
    196         # Issue #24257: Incorrect use of PyObject_IsInstance() caused
    197         # segmentation fault.
    198         # Issue #27861: Also applies for cursor factory.
    199         class FakeCursor(str):
    200             __class__ = sqlite.Cursor
    201         self.con.row_factory = sqlite.Row
    202         self.assertRaises(TypeError, self.con.cursor, FakeCursor)
    203         self.assertRaises(TypeError, sqlite.Row, FakeCursor(), ())
    204 
    205     def tearDown(self):
    206         self.con.close()
    207 
    208 class TextFactoryTests(unittest.TestCase):
    209     def setUp(self):
    210         self.con = sqlite.connect(":memory:")
    211 
    212     def CheckUnicode(self):
    213         austria = "sterreich"
    214         row = self.con.execute("select ?", (austria,)).fetchone()
    215         self.assertEqual(type(row[0]), str, "type of row[0] must be unicode")
    216 
    217     def CheckString(self):
    218         self.con.text_factory = bytes
    219         austria = "sterreich"
    220         row = self.con.execute("select ?", (austria,)).fetchone()
    221         self.assertEqual(type(row[0]), bytes, "type of row[0] must be bytes")
    222         self.assertEqual(row[0], austria.encode("utf-8"), "column must equal original data in UTF-8")
    223 
    224     def CheckCustom(self):
    225         self.con.text_factory = lambda x: str(x, "utf-8", "ignore")
    226         austria = "sterreich"
    227         row = self.con.execute("select ?", (austria,)).fetchone()
    228         self.assertEqual(type(row[0]), str, "type of row[0] must be unicode")
    229         self.assertTrue(row[0].endswith("reich"), "column must contain original data")
    230 
    231     def CheckOptimizedUnicode(self):
    232         # In py3k, str objects are always returned when text_factory
    233         # is OptimizedUnicode
    234         self.con.text_factory = sqlite.OptimizedUnicode
    235         austria = "sterreich"
    236         germany = "Deutchland"
    237         a_row = self.con.execute("select ?", (austria,)).fetchone()
    238         d_row = self.con.execute("select ?", (germany,)).fetchone()
    239         self.assertEqual(type(a_row[0]), str, "type of non-ASCII row must be str")
    240         self.assertEqual(type(d_row[0]), str, "type of ASCII-only row must be str")
    241 
    242     def tearDown(self):
    243         self.con.close()
    244 
    245 class TextFactoryTestsWithEmbeddedZeroBytes(unittest.TestCase):
    246     def setUp(self):
    247         self.con = sqlite.connect(":memory:")
    248         self.con.execute("create table test (value text)")
    249         self.con.execute("insert into test (value) values (?)", ("a\x00b",))
    250 
    251     def CheckString(self):
    252         # text_factory defaults to str
    253         row = self.con.execute("select value from test").fetchone()
    254         self.assertIs(type(row[0]), str)
    255         self.assertEqual(row[0], "a\x00b")
    256 
    257     def CheckBytes(self):
    258         self.con.text_factory = bytes
    259         row = self.con.execute("select value from test").fetchone()
    260         self.assertIs(type(row[0]), bytes)
    261         self.assertEqual(row[0], b"a\x00b")
    262 
    263     def CheckBytearray(self):
    264         self.con.text_factory = bytearray
    265         row = self.con.execute("select value from test").fetchone()
    266         self.assertIs(type(row[0]), bytearray)
    267         self.assertEqual(row[0], b"a\x00b")
    268 
    269     def CheckCustom(self):
    270         # A custom factory should receive a bytes argument
    271         self.con.text_factory = lambda x: x
    272         row = self.con.execute("select value from test").fetchone()
    273         self.assertIs(type(row[0]), bytes)
    274         self.assertEqual(row[0], b"a\x00b")
    275 
    276     def tearDown(self):
    277         self.con.close()
    278 
    279 def suite():
    280     connection_suite = unittest.makeSuite(ConnectionFactoryTests, "Check")
    281     cursor_suite = unittest.makeSuite(CursorFactoryTests, "Check")
    282     row_suite_compat = unittest.makeSuite(RowFactoryTestsBackwardsCompat, "Check")
    283     row_suite = unittest.makeSuite(RowFactoryTests, "Check")
    284     text_suite = unittest.makeSuite(TextFactoryTests, "Check")
    285     text_zero_bytes_suite = unittest.makeSuite(TextFactoryTestsWithEmbeddedZeroBytes, "Check")
    286     return unittest.TestSuite((connection_suite, cursor_suite, row_suite_compat, row_suite, text_suite, text_zero_bytes_suite))
    287 
    288 def test():
    289     runner = unittest.TextTestRunner()
    290     runner.run(suite())
    291 
    292 if __name__ == "__main__":
    293     test()
    294