Home | History | Annotate | Download | only in database
      1 #!/usr/bin/python
      2 
      3 import unittest, time
      4 import common
      5 from autotest_lib.client.common_lib import global_config
      6 from autotest_lib.client.common_lib.test_utils import mock
      7 from autotest_lib.database import database_connection
      8 
      9 _CONFIG_SECTION = 'AUTOTEST_WEB'
     10 _HOST = 'myhost'
     11 _USER = 'myuser'
     12 _PASS = 'mypass'
     13 _DB_NAME = 'mydb'
     14 _DB_TYPE = 'mydbtype'
     15 
     16 _CONNECT_KWARGS = dict(host=_HOST, username=_USER, password=_PASS,
     17                        db_name=_DB_NAME)
     18 _RECONNECT_DELAY = 10
     19 
     20 class FakeDatabaseError(Exception):
     21     pass
     22 
     23 
     24 class DatabaseConnectionTest(unittest.TestCase):
     25     def setUp(self):
     26         self.god = mock.mock_god()
     27         self.god.stub_function(time, 'sleep')
     28 
     29 
     30     def tearDown(self):
     31         global_config.global_config.reset_config_values()
     32         self.god.unstub_all()
     33 
     34 
     35     def _get_database_connection(self, config_section=_CONFIG_SECTION):
     36         if config_section == _CONFIG_SECTION:
     37             self._override_config()
     38         db = database_connection.DatabaseConnection(config_section)
     39 
     40         self._fake_backend = self.god.create_mock_class(
     41             database_connection._GenericBackend, 'fake_backend')
     42         for exception in database_connection._DB_EXCEPTIONS:
     43             setattr(self._fake_backend, exception, FakeDatabaseError)
     44         self._fake_backend.rowcount = 0
     45 
     46         def get_fake_backend(db_type):
     47             self._db_type = db_type
     48             return self._fake_backend
     49         self.god.stub_with(db, '_get_backend', get_fake_backend)
     50 
     51         db.reconnect_delay_sec = _RECONNECT_DELAY
     52         return db
     53 
     54 
     55     def _override_config(self):
     56         c = global_config.global_config
     57         c.override_config_value(_CONFIG_SECTION, 'host', _HOST)
     58         c.override_config_value(_CONFIG_SECTION, 'user', _USER)
     59         c.override_config_value(_CONFIG_SECTION, 'password', _PASS)
     60         c.override_config_value(_CONFIG_SECTION, 'database', _DB_NAME)
     61         c.override_config_value(_CONFIG_SECTION, 'db_type', _DB_TYPE)
     62 
     63 
     64     def test_connect(self):
     65         db = self._get_database_connection(config_section=None)
     66         self._fake_backend.connect.expect_call(**_CONNECT_KWARGS)
     67 
     68         db.connect(db_type=_DB_TYPE, host=_HOST, username=_USER,
     69                    password=_PASS, db_name=_DB_NAME)
     70 
     71         self.assertEquals(self._db_type, _DB_TYPE)
     72         self.god.check_playback()
     73 
     74 
     75     def test_global_config(self):
     76         db = self._get_database_connection()
     77         self._fake_backend.connect.expect_call(**_CONNECT_KWARGS)
     78 
     79         db.connect()
     80 
     81         self.assertEquals(self._db_type, _DB_TYPE)
     82         self.god.check_playback()
     83 
     84 
     85     def _expect_reconnect(self, fail=False):
     86         self._fake_backend.disconnect.expect_call()
     87         call = self._fake_backend.connect.expect_call(**_CONNECT_KWARGS)
     88         if fail:
     89             call.and_raises(FakeDatabaseError())
     90 
     91 
     92     def _expect_fail_and_reconnect(self, num_reconnects, fail_last=False):
     93         self._fake_backend.connect.expect_call(**_CONNECT_KWARGS).and_raises(
     94             FakeDatabaseError())
     95         for i in xrange(num_reconnects):
     96             time.sleep.expect_call(_RECONNECT_DELAY)
     97             if i < num_reconnects - 1:
     98                 self._expect_reconnect(fail=True)
     99             else:
    100                 self._expect_reconnect(fail=fail_last)
    101 
    102 
    103     def test_connect_retry(self):
    104         db = self._get_database_connection()
    105         self._expect_fail_and_reconnect(1)
    106 
    107         db.connect()
    108         self.god.check_playback()
    109 
    110         self._fake_backend.disconnect.expect_call()
    111         self._expect_fail_and_reconnect(0)
    112         self.assertRaises(FakeDatabaseError, db.connect,
    113                           try_reconnecting=False)
    114         self.god.check_playback()
    115 
    116         db.reconnect_enabled = False
    117         self._fake_backend.disconnect.expect_call()
    118         self._expect_fail_and_reconnect(0)
    119         self.assertRaises(FakeDatabaseError, db.connect)
    120         self.god.check_playback()
    121 
    122 
    123     def test_max_reconnect(self):
    124         db = self._get_database_connection()
    125         db.max_reconnect_attempts = 5
    126         self._expect_fail_and_reconnect(5, fail_last=True)
    127 
    128         self.assertRaises(FakeDatabaseError, db.connect)
    129         self.god.check_playback()
    130 
    131 
    132     def test_reconnect_forever(self):
    133         db = self._get_database_connection()
    134         db.max_reconnect_attempts = database_connection.RECONNECT_FOREVER
    135         self._expect_fail_and_reconnect(30)
    136 
    137         db.connect()
    138         self.god.check_playback()
    139 
    140 
    141     def _simple_connect(self, db):
    142         self._fake_backend.connect.expect_call(**_CONNECT_KWARGS)
    143         db.connect()
    144         self.god.check_playback()
    145 
    146 
    147     def test_disconnect(self):
    148         db = self._get_database_connection()
    149         self._simple_connect(db)
    150         self._fake_backend.disconnect.expect_call()
    151 
    152         db.disconnect()
    153         self.god.check_playback()
    154 
    155 
    156     def test_execute(self):
    157         db = self._get_database_connection()
    158         self._simple_connect(db)
    159         params = object()
    160         self._fake_backend.execute.expect_call('query', params)
    161 
    162         db.execute('query', params)
    163         self.god.check_playback()
    164 
    165 
    166     def test_execute_retry(self):
    167         db = self._get_database_connection()
    168         self._simple_connect(db)
    169         self._fake_backend.execute.expect_call('query', None).and_raises(
    170             FakeDatabaseError())
    171         self._expect_reconnect()
    172         self._fake_backend.execute.expect_call('query', None)
    173 
    174         db.execute('query')
    175         self.god.check_playback()
    176 
    177         self._fake_backend.execute.expect_call('query', None).and_raises(
    178             FakeDatabaseError())
    179         self.assertRaises(FakeDatabaseError, db.execute, 'query',
    180                           try_reconnecting=False)
    181 
    182 
    183 if __name__ == '__main__':
    184     unittest.main()
    185