Home | History | Annotate | Download | only in database
      1 #!/usr/bin/python -u
      2 
      3 import os, sys, re, tempfile
      4 from optparse import OptionParser
      5 import common
      6 from autotest_lib.client.common_lib import utils
      7 from autotest_lib.database import database_connection
      8 
      9 MIGRATE_TABLE = 'migrate_info'
     10 
     11 _AUTODIR = os.path.join(os.path.dirname(__file__), '..')
     12 _MIGRATIONS_DIRS = {
     13     'AUTOTEST_WEB': os.path.join(_AUTODIR, 'frontend', 'migrations'),
     14     'TKO': os.path.join(_AUTODIR, 'tko', 'migrations'),
     15     'AUTOTEST_SERVER_DB': os.path.join(_AUTODIR, 'database',
     16                                       'server_db_migrations'),
     17 }
     18 _DEFAULT_MIGRATIONS_DIR = 'migrations' # use CWD
     19 
     20 class Migration(object):
     21     """Represents a database migration."""
     22     _UP_ATTRIBUTES = ('migrate_up', 'UP_SQL')
     23     _DOWN_ATTRIBUTES = ('migrate_down', 'DOWN_SQL')
     24 
     25     def __init__(self, name, version, module):
     26         self.name = name
     27         self.version = version
     28         self.module = module
     29         self._check_attributes(self._UP_ATTRIBUTES)
     30         self._check_attributes(self._DOWN_ATTRIBUTES)
     31 
     32 
     33     @classmethod
     34     def from_file(cls, filename):
     35         """Instantiates a Migration from a file.
     36 
     37         @param filename: Name of a migration file.
     38 
     39         @return An instantiated Migration object.
     40 
     41         """
     42         version = int(filename[:3])
     43         name = filename[:-3]
     44         module = __import__(name, globals(), locals(), [])
     45         return cls(name, version, module)
     46 
     47 
     48     def _check_attributes(self, attributes):
     49         method_name, sql_name = attributes
     50         assert (hasattr(self.module, method_name) or
     51                 hasattr(self.module, sql_name))
     52 
     53 
     54     def _execute_migration(self, attributes, manager):
     55         method_name, sql_name = attributes
     56         method = getattr(self.module, method_name, None)
     57         if method:
     58             assert callable(method)
     59             method(manager)
     60         else:
     61             sql = getattr(self.module, sql_name)
     62             assert isinstance(sql, basestring)
     63             manager.execute_script(sql)
     64 
     65 
     66     def migrate_up(self, manager):
     67         """Performs an up migration (to a newer version).
     68 
     69         @param manager: A MigrationManager object.
     70 
     71         """
     72         self._execute_migration(self._UP_ATTRIBUTES, manager)
     73 
     74 
     75     def migrate_down(self, manager):
     76         """Performs a down migration (to an older version).
     77 
     78         @param manager: A MigrationManager object.
     79 
     80         """
     81         self._execute_migration(self._DOWN_ATTRIBUTES, manager)
     82 
     83 
     84 class MigrationManager(object):
     85     """Managest database migrations."""
     86     connection = None
     87     cursor = None
     88     migrations_dir = None
     89 
     90     def __init__(self, database_connection, migrations_dir=None, force=False):
     91         self._database = database_connection
     92         self.force = force
     93         # A boolean, this will only be set to True if this migration should be
     94         # simulated rather than actually taken. For use with migrations that
     95         # may make destructive queries
     96         self.simulate = False
     97         self._set_migrations_dir(migrations_dir)
     98 
     99 
    100     def _set_migrations_dir(self, migrations_dir=None):
    101         config_section = self._config_section()
    102         if migrations_dir is None:
    103             migrations_dir = os.path.abspath(
    104                 _MIGRATIONS_DIRS.get(config_section, _DEFAULT_MIGRATIONS_DIR))
    105         self.migrations_dir = migrations_dir
    106         sys.path.append(migrations_dir)
    107         assert os.path.exists(migrations_dir), migrations_dir + " doesn't exist"
    108 
    109 
    110     def _config_section(self):
    111         return self._database.global_config_section
    112 
    113 
    114     def get_db_name(self):
    115         """Gets the database name."""
    116         return self._database.get_database_info()['db_name']
    117 
    118 
    119     def execute(self, query, *parameters):
    120         """Executes a database query.
    121 
    122         @param query: The query to execute.
    123         @param parameters: Associated parameters for the query.
    124 
    125         @return The result of the query.
    126 
    127         """
    128         return self._database.execute(query, parameters)
    129 
    130 
    131     def execute_script(self, script):
    132         """Executes a set of database queries.
    133 
    134         @param script: A string of semicolon-separated queries.
    135 
    136         """
    137         sql_statements = [statement.strip()
    138                           for statement in script.split(';')
    139                           if statement.strip()]
    140         for statement in sql_statements:
    141             self.execute(statement)
    142 
    143 
    144     def check_migrate_table_exists(self):
    145         """Checks whether the migration table exists."""
    146         try:
    147             self.execute("SELECT * FROM %s" % MIGRATE_TABLE)
    148             return True
    149         except self._database.DatabaseError, exc:
    150             # we can't check for more specifics due to differences between DB
    151             # backends (we can't even check for a subclass of DatabaseError)
    152             return False
    153 
    154 
    155     def create_migrate_table(self):
    156         """Creates the migration table."""
    157         if not self.check_migrate_table_exists():
    158             self.execute("CREATE TABLE %s (`version` integer)" %
    159                          MIGRATE_TABLE)
    160         else:
    161             self.execute("DELETE FROM %s" % MIGRATE_TABLE)
    162         self.execute("INSERT INTO %s VALUES (0)" % MIGRATE_TABLE)
    163         assert self._database.rowcount == 1
    164 
    165 
    166     def set_db_version(self, version):
    167         """Sets the database version.
    168 
    169         @param version: The version to which to set the database.
    170 
    171         """
    172         assert isinstance(version, int)
    173         self.execute("UPDATE %s SET version=%%s" % MIGRATE_TABLE,
    174                      version)
    175         assert self._database.rowcount == 1
    176 
    177 
    178     def get_db_version(self):
    179         """Gets the database version.
    180 
    181         @return The database version.
    182 
    183         """
    184         if not self.check_migrate_table_exists():
    185             return 0
    186         rows = self.execute("SELECT * FROM %s" % MIGRATE_TABLE)
    187         if len(rows) == 0:
    188             return 0
    189         assert len(rows) == 1 and len(rows[0]) == 1
    190         return rows[0][0]
    191 
    192 
    193     def get_migrations(self, minimum_version=None, maximum_version=None):
    194         """Gets the list of migrations to perform.
    195 
    196         @param minimum_version: The minimum database version.
    197         @param maximum_version: The maximum database version.
    198 
    199         @return A list of Migration objects.
    200 
    201         """
    202         migrate_files = [filename for filename
    203                          in os.listdir(self.migrations_dir)
    204                          if re.match(r'^\d\d\d_.*\.py$', filename)]
    205         migrate_files.sort()
    206         migrations = [Migration.from_file(filename)
    207                       for filename in migrate_files]
    208         if minimum_version is not None:
    209             migrations = [migration for migration in migrations
    210                           if migration.version >= minimum_version]
    211         if maximum_version is not None:
    212             migrations = [migration for migration in migrations
    213                           if migration.version <= maximum_version]
    214         return migrations
    215 
    216 
    217     def do_migration(self, migration, migrate_up=True):
    218         """Performs a migration.
    219 
    220         @param migration: The Migration to perform.
    221         @param migrate_up: Whether to migrate up (if not, then migrates down).
    222 
    223         """
    224         print 'Applying migration %s' % migration.name, # no newline
    225         if migrate_up:
    226             print 'up'
    227             assert self.get_db_version() == migration.version - 1
    228             migration.migrate_up(self)
    229             new_version = migration.version
    230         else:
    231             print 'down'
    232             assert self.get_db_version() == migration.version
    233             migration.migrate_down(self)
    234             new_version = migration.version - 1
    235         self.set_db_version(new_version)
    236 
    237 
    238     def migrate_to_version(self, version):
    239         """Performs a migration to a specified version.
    240 
    241         @param version: The version to which to migrate the database.
    242 
    243         """
    244         current_version = self.get_db_version()
    245         if current_version == 0 and self._config_section() == 'AUTOTEST_WEB':
    246             self._migrate_from_base()
    247             current_version = self.get_db_version()
    248 
    249         if current_version < version:
    250             lower, upper = current_version, version
    251             migrate_up = True
    252         else:
    253             lower, upper = version, current_version
    254             migrate_up = False
    255 
    256         migrations = self.get_migrations(lower + 1, upper)
    257         if not migrate_up:
    258             migrations.reverse()
    259         for migration in migrations:
    260             self.do_migration(migration, migrate_up)
    261 
    262         assert self.get_db_version() == version
    263         print 'At version', version
    264 
    265 
    266     def _migrate_from_base(self):
    267         """Initialize the AFE database.
    268         """
    269         self.confirm_initialization()
    270 
    271         migration_script = utils.read_file(
    272                 os.path.join(os.path.dirname(__file__), 'schema_051.sql'))
    273         migration_script = migration_script % (
    274                 dict(username=self._database.get_database_info()['username']))
    275         self.execute_script(migration_script)
    276 
    277         self.create_migrate_table()
    278         self.set_db_version(51)
    279 
    280 
    281     def confirm_initialization(self):
    282         """Confirms with the user that we should initialize the database.
    283 
    284         @raises Exception, if the user chooses to abort the migration.
    285 
    286         """
    287         if not self.force:
    288             response = raw_input(
    289                 'Your %s database does not appear to be initialized.  Do you '
    290                 'want to recreate it (this will result in loss of any existing '
    291                 'data) (yes/No)? ' % self.get_db_name())
    292             if response != 'yes':
    293                 raise Exception('User has chosen to abort migration')
    294 
    295 
    296     def get_latest_version(self):
    297         """Gets the latest database version."""
    298         migrations = self.get_migrations()
    299         return migrations[-1].version
    300 
    301 
    302     def migrate_to_latest(self):
    303         """Migrates the database to the latest version."""
    304         latest_version = self.get_latest_version()
    305         self.migrate_to_version(latest_version)
    306 
    307 
    308     def initialize_test_db(self):
    309         """Initializes a test database."""
    310         db_name = self.get_db_name()
    311         test_db_name = 'test_' + db_name
    312         # first, connect to no DB so we can create a test DB
    313         self._database.connect(db_name='')
    314         print 'Creating test DB', test_db_name
    315         self.execute('CREATE DATABASE ' + test_db_name)
    316         self._database.disconnect()
    317         # now connect to the test DB
    318         self._database.connect(db_name=test_db_name)
    319 
    320 
    321     def remove_test_db(self):
    322         """Removes a test database."""
    323         print 'Removing test DB'
    324         self.execute('DROP DATABASE ' + self.get_db_name())
    325         # reset connection back to real DB
    326         self._database.disconnect()
    327         self._database.connect()
    328 
    329 
    330     def get_mysql_args(self):
    331         """Returns the mysql arguments as a string."""
    332         return ('-u %(username)s -p%(password)s -h %(host)s %(db_name)s' %
    333                 self._database.get_database_info())
    334 
    335 
    336     def migrate_to_version_or_latest(self, version):
    337         """Migrates to either a specified version, or the latest version.
    338 
    339         @param version: The version to which to migrate the database,
    340             or None in order to migrate to the latest version.
    341 
    342         """
    343         if version is None:
    344             self.migrate_to_latest()
    345         else:
    346             self.migrate_to_version(version)
    347 
    348 
    349     def do_sync_db(self, version=None):
    350         """Migrates the database.
    351 
    352         @param version: The version to which to migrate the database.
    353 
    354         """
    355         print 'Migration starting for database', self.get_db_name()
    356         self.migrate_to_version_or_latest(version)
    357         print 'Migration complete'
    358 
    359 
    360     def test_sync_db(self, version=None):
    361         """Create a fresh database and run all migrations on it.
    362 
    363         @param version: The version to which to migrate the database.
    364 
    365         """
    366         self.initialize_test_db()
    367         try:
    368             print 'Starting migration test on DB', self.get_db_name()
    369             self.migrate_to_version_or_latest(version)
    370             # show schema to the user
    371             os.system('mysqldump %s --no-data=true '
    372                       '--add-drop-table=false' %
    373                       self.get_mysql_args())
    374         finally:
    375             self.remove_test_db()
    376         print 'Test finished successfully'
    377 
    378 
    379     def simulate_sync_db(self, version=None):
    380         """Creates a fresh DB, copies existing DB to it, then synchronizes it.
    381 
    382         @param version: The version to which to migrate the database.
    383 
    384         """
    385         db_version = self.get_db_version()
    386         # don't do anything if we're already at the latest version
    387         if db_version == self.get_latest_version():
    388             print 'Skipping simulation, already at latest version'
    389             return
    390         # get existing data
    391         self.initialize_and_fill_test_db()
    392         try:
    393             print 'Starting migration test on DB', self.get_db_name()
    394             self.migrate_to_version_or_latest(version)
    395         finally:
    396             self.remove_test_db()
    397         print 'Test finished successfully'
    398 
    399 
    400     def initialize_and_fill_test_db(self):
    401         """Initializes and fills up a test database."""
    402         print 'Dumping existing data'
    403         dump_fd, dump_file = tempfile.mkstemp('.migrate_dump')
    404         os.system('mysqldump %s >%s' %
    405                   (self.get_mysql_args(), dump_file))
    406         # fill in test DB
    407         self.initialize_test_db()
    408         print 'Filling in test DB'
    409         os.system('mysql %s <%s' % (self.get_mysql_args(), dump_file))
    410         os.close(dump_fd)
    411         os.remove(dump_file)
    412 
    413 
    414 USAGE = """\
    415 %s [options] sync|test|simulate|safesync [version]
    416 Options:
    417     -d --database   Which database to act on
    418     -f --force      Don't ask for confirmation
    419     --debug         Print all DB queries"""\
    420     % sys.argv[0]
    421 
    422 
    423 def main():
    424     """Main function for the migration script."""
    425     parser = OptionParser()
    426     parser.add_option("-d", "--database",
    427                       help="which database to act on",
    428                       dest="database",
    429                       default="AUTOTEST_WEB")
    430     parser.add_option("-f", "--force", help="don't ask for confirmation",
    431                       action="store_true")
    432     parser.add_option('--debug', help='print all DB queries',
    433                       action='store_true')
    434     (options, args) = parser.parse_args()
    435     manager = get_migration_manager(db_name=options.database,
    436                                     debug=options.debug, force=options.force)
    437 
    438     if len(args) > 0:
    439         if len(args) > 1:
    440             version = int(args[1])
    441         else:
    442             version = None
    443         if args[0] == 'sync':
    444             manager.do_sync_db(version)
    445         elif args[0] == 'test':
    446             manager.simulate=True
    447             manager.test_sync_db(version)
    448         elif args[0] == 'simulate':
    449             manager.simulate=True
    450             manager.simulate_sync_db(version)
    451         elif args[0] == 'safesync':
    452             print 'Simluating migration'
    453             manager.simulate=True
    454             manager.simulate_sync_db(version)
    455             print 'Performing real migration'
    456             manager.simulate=False
    457             manager.do_sync_db(version)
    458         else:
    459             print USAGE
    460         return
    461 
    462     print USAGE
    463 
    464 
    465 def get_migration_manager(db_name, debug, force):
    466     """Creates a MigrationManager object.
    467 
    468     @param db_name: The database name.
    469     @param debug: Whether to print debug messages.
    470     @param force: Whether to force migration without asking for confirmation.
    471 
    472     @return A created MigrationManager object.
    473 
    474     """
    475     database = database_connection.DatabaseConnection(db_name)
    476     database.debug = debug
    477     database.reconnect_enabled = False
    478     database.connect()
    479     return MigrationManager(database, force=force)
    480 
    481 
    482 if __name__ == '__main__':
    483     main()
    484