#!/usr/bin/python

import unittest, tempfile, os
import common
import MySQLdb
from autotest_lib.client.common_lib import global_config
from autotest_lib.database import database_connection, migrate

# Which section of the global config to pull info from. We won't actually use
# that DB, we'll use the corresponding test DB (test_<db name>).
# NOTE(review): CONFIG_DB is not referenced anywhere else in this file.
CONFIG_DB = 'AUTOTEST_WEB'

# Number of dummy migrations served by TestableMigrationManager.
NUM_MIGRATIONS = 3

class DummyMigration(object):
    """\
    Dummy migration class that records all migrations done in a class
    variable.

    Every instance appends (version, direction) to the shared class-level
    list migrations_done, so tests can inspect the exact order in which the
    manager applied migrations.
    """

    # Shared by all instances; reset with clear_migrations_done().
    migrations_done = []

    def __init__(self, version):
        self.version = version
        self.name = '%03d_test' % version


    @classmethod
    def get_migrations_done(cls):
        """Return the list of (version, direction) tuples recorded so far."""
        return cls.migrations_done


    @classmethod
    def clear_migrations_done(cls):
        """Forget all recorded migrations by rebinding to a fresh list."""
        cls.migrations_done = []


    @classmethod
    def do_migration(cls, version, direction):
        """Record that migration `version` was applied in `direction`."""
        cls.migrations_done.append((version, direction))


    def migrate_up(self, manager):
        """Record an 'up' migration; version 1 also creates the migrate table."""
        self.do_migration(self.version, 'up')
        # Mirrors a real first migration: presumably the manager needs this
        # table to record the current DB version (TODO confirm in migrate.py).
        if self.version == 1:
            manager.create_migrate_table()


    def migrate_down(self, manager):
        """Record a 'down' migration; nothing is changed in the database."""
        self.do_migration(self.version, 'down')


# Dummy migrations for versions 1..NUM_MIGRATIONS, in ascending order.
MIGRATIONS = [DummyMigration(n) for n in xrange(1, NUM_MIGRATIONS + 1)]


class TestableMigrationManager(migrate.MigrationManager):
    """\
    MigrationManager that serves the in-memory MIGRATIONS list instead of
    looking migrations up in a migrations directory.
    """

    def _set_migrations_dir(self, migrations_dir=None):
        # Deliberately a no-op: get_migrations() below never uses a directory.
        pass


    def get_migrations(self, minimum_version=None, maximum_version=None):
        """\
        Return the dummy migrations whose version v satisfies
        minimum_version <= v <= maximum_version, in ascending version order.

        A bound of None means "unbounded". Bounds are compared with
        `is not None`, so an explicit 0 is honoured as a real bound instead
        of being treated like None.

        A new list is always returned, so a caller that reorders the result
        in place cannot mutate the module-level MIGRATIONS list.
        """
        migrations = list(MIGRATIONS)
        if minimum_version is not None:
            migrations = [migration for migration in migrations
                          if migration.version >= minimum_version]
        if maximum_version is not None:
            migrations = [migration for migration in migrations
                          if migration.version <= maximum_version]
        return migrations


class MigrateManagerTest(unittest.TestCase):
    """Tests MigrationManager.do_sync_db() against the dummy migrations."""

    def setUp(self):
        self._database = (
            database_connection.DatabaseConnection.get_test_database())
        self._database.connect()
        self.manager = TestableMigrationManager(self._database)
        # The migration log is class-level state; start every test clean.
        DummyMigration.clear_migrations_done()


    def tearDown(self):
        self._database.disconnect()


    def test_sync(self):
        all_versions = range(1, NUM_MIGRATIONS + 1)

        # Syncing with no target version applies every migration in order...
        self.manager.do_sync_db()
        self.assertEqual(self.manager.get_db_version(), NUM_MIGRATIONS)
        self.assertEqual(DummyMigration.get_migrations_done(),
                         [(version, 'up') for version in all_versions])

        # ...and syncing to version 0 undoes them in reverse order.
        DummyMigration.clear_migrations_done()
        self.manager.do_sync_db(0)
        self.assertEqual(self.manager.get_db_version(), 0)
        self.assertEqual(DummyMigration.get_migrations_done(),
                         [(version, 'down')
                          for version in reversed(all_versions)])


    def test_sync_one_by_one(self):
        # Step up one version at a time; after each step the most recently
        # applied migration is the one for the target version.
        for version in xrange(1, NUM_MIGRATIONS + 1):
            self.manager.do_sync_db(version)
            self.assertEqual(self.manager.get_db_version(),
                             version)
            self.assertEqual(
                DummyMigration.get_migrations_done()[-1],
                (version, 'up'))

        # Step back down to 0; going to `version` undoes migration
        # `version + 1`.
        for version in xrange(NUM_MIGRATIONS - 1, -1, -1):
            self.manager.do_sync_db(version)
            self.assertEqual(self.manager.get_db_version(),
                             version)
            self.assertEqual(
                DummyMigration.get_migrations_done()[-1],
                (version + 1, 'down'))


    def test_null_sync(self):
        # Once at the latest version, a second sync applies nothing.
        self.manager.do_sync_db()
        DummyMigration.clear_migrations_done()
        self.manager.do_sync_db()
        self.assertEqual(DummyMigration.get_migrations_done(), [])


class DummyMigrationManager(object):
    """Stand-in manager that records every script passed to execute_script."""

    def __init__(self):
        self.calls = []


    def execute_script(self, script):
        self.calls.append(script)


class MigrationTest(unittest.TestCase):
    """Tests that migrate.Migration runs a migration module's up/down steps."""

    def setUp(self):
        self.manager = DummyMigrationManager()


    def _do_migration(self, migration_module):
        """\
        Wrap migration_module in a migrate.Migration, run it up then down,
        and check that the manager received 'foo' followed by 'bar'.
        """
        migration = migrate.Migration('name', 1, migration_module)
        migration.migrate_up(self.manager)
        migration.migrate_down(self.manager)

        self.assertEqual(self.manager.calls, ['foo', 'bar'])


    def test_migration_with_methods(self):
        # Module-like object that provides migrate_up/migrate_down callables.
        class DummyMigration(object):
            @staticmethod
            def migrate_up(manager):
                manager.execute_script('foo')


            @staticmethod
            def migrate_down(manager):
                manager.execute_script('bar')

        self._do_migration(DummyMigration)


    def test_migration_with_strings(self):
        # Module-like object that provides only UP_SQL/DOWN_SQL strings; the
        # test expects Migration to hand them to manager.execute_script().
        class DummyMigration(object):
            UP_SQL = 'foo'
            DOWN_SQL = 'bar'

        self._do_migration(DummyMigration)


if __name__ == '__main__':
    unittest.main()