Home | History | Annotate | Download | only in database
      1 #!/usr/bin/python
      2 
      3 import unittest, tempfile, os
      4 import common
      5 import MySQLdb
      6 from autotest_lib.client.common_lib import global_config
      7 from autotest_lib.database import database_connection, migrate
      8 
# Which section of the global config to pull info from.  We won't actually use
# that DB, we'll use the corresponding test DB (test_<db name>).
# NOTE(review): CONFIG_DB is not referenced by any code visible in this file;
# confirm it is still needed.
CONFIG_DB = 'AUTOTEST_WEB'

# Number of DummyMigration instances placed in MIGRATIONS; test_sync expects a
# full sync to leave the database at this version.
NUM_MIGRATIONS = 3
     14 
     15 class DummyMigration(object):
     16     """\
     17     Dummy migration class that records all migrations done in a class
     18     varaible.
     19     """
     20 
     21     migrations_done = []
     22 
     23     def __init__(self, version):
     24         self.version = version
     25         self.name = '%03d_test' % version
     26 
     27 
     28     @classmethod
     29     def get_migrations_done(cls):
     30         return cls.migrations_done
     31 
     32 
     33     @classmethod
     34     def clear_migrations_done(cls):
     35         cls.migrations_done = []
     36 
     37 
     38     @classmethod
     39     def do_migration(cls, version, direction):
     40         cls.migrations_done.append((version, direction))
     41 
     42 
     43     def migrate_up(self, manager):
     44         self.do_migration(self.version, 'up')
     45         if self.version == 1:
     46             manager.create_migrate_table()
     47 
     48 
     49     def migrate_down(self, manager):
     50         self.do_migration(self.version, 'down')
     51 
     52 
     53 MIGRATIONS = [DummyMigration(n) for n in xrange(1, NUM_MIGRATIONS + 1)]
     54 
     55 
     56 class TestableMigrationManager(migrate.MigrationManager):
     57     def _set_migrations_dir(self, migrations_dir=None):
     58         pass
     59 
     60 
     61     def get_migrations(self, minimum_version=None, maximum_version=None):
     62         minimum_version = minimum_version or 1
     63         maximum_version = maximum_version or len(MIGRATIONS)
     64         return MIGRATIONS[minimum_version-1:maximum_version]
     65 
     66 
     67 class MigrateManagerTest(unittest.TestCase):
     68     def setUp(self):
     69         self._database = (
     70             database_connection.DatabaseConnection.get_test_database())
     71         self._database.connect()
     72         self.manager = TestableMigrationManager(self._database)
     73         DummyMigration.clear_migrations_done()
     74 
     75 
     76     def tearDown(self):
     77         self._database.disconnect()
     78 
     79 
     80     def test_sync(self):
     81         self.manager.do_sync_db()
     82         self.assertEquals(self.manager.get_db_version(), NUM_MIGRATIONS)
     83         self.assertEquals(DummyMigration.get_migrations_done(),
     84                           [(1, 'up'), (2, 'up'), (3, 'up')])
     85 
     86         DummyMigration.clear_migrations_done()
     87         self.manager.do_sync_db(0)
     88         self.assertEquals(self.manager.get_db_version(), 0)
     89         self.assertEquals(DummyMigration.get_migrations_done(),
     90                           [(3, 'down'), (2, 'down'), (1, 'down')])
     91 
     92 
     93     def test_sync_one_by_one(self):
     94         for version in xrange(1, NUM_MIGRATIONS + 1):
     95             self.manager.do_sync_db(version)
     96             self.assertEquals(self.manager.get_db_version(),
     97                               version)
     98             self.assertEquals(
     99                 DummyMigration.get_migrations_done()[-1],
    100                 (version, 'up'))
    101 
    102         for version in xrange(NUM_MIGRATIONS - 1, -1, -1):
    103             self.manager.do_sync_db(version)
    104             self.assertEquals(self.manager.get_db_version(),
    105                               version)
    106             self.assertEquals(
    107                 DummyMigration.get_migrations_done()[-1],
    108                 (version + 1, 'down'))
    109 
    110 
    111     def test_null_sync(self):
    112         self.manager.do_sync_db()
    113         DummyMigration.clear_migrations_done()
    114         self.manager.do_sync_db()
    115         self.assertEquals(DummyMigration.get_migrations_done(), [])
    116 
    117 
    118 class DummyMigrationManager(object):
    119     def __init__(self):
    120         self.calls = []
    121 
    122 
    123     def execute_script(self, script):
    124         self.calls.append(script)
    125 
    126 
    127 class MigrationTest(unittest.TestCase):
    128     def setUp(self):
    129         self.manager = DummyMigrationManager()
    130 
    131 
    132     def _do_migration(self, migration_module):
    133         migration = migrate.Migration('name', 1, migration_module)
    134         migration.migrate_up(self.manager)
    135         migration.migrate_down(self.manager)
    136 
    137         self.assertEquals(self.manager.calls, ['foo', 'bar'])
    138 
    139 
    140     def test_migration_with_methods(self):
    141         class DummyMigration(object):
    142             @staticmethod
    143             def migrate_up(manager):
    144                 manager.execute_script('foo')
    145 
    146 
    147             @staticmethod
    148             def migrate_down(manager):
    149                 manager.execute_script('bar')
    150 
    151         self._do_migration(DummyMigration)
    152 
    153 
    154     def test_migration_with_strings(self):
    155         class DummyMigration(object):
    156             UP_SQL = 'foo'
    157             DOWN_SQL = 'bar'
    158 
    159         self._do_migration(DummyMigration)
    160 
    161 
# Run every TestCase in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
    164