Home | History | Annotate | Download | only in database
      1 # pylint: disable-msg=C0111
      2 
      3 import re, time, traceback
      4 import common
      5 from autotest_lib.client.common_lib import global_config
      6 
      7 RECONNECT_FOREVER = object()
      8 
      9 _DB_EXCEPTIONS = ('DatabaseError', 'OperationalError', 'ProgrammingError')
     10 _GLOBAL_CONFIG_NAMES = {
     11     'username' : 'user',
     12     'db_name' : 'database',
     13 }
     14 
     15 def _copy_exceptions(source, destination):
     16     for exception_name in _DB_EXCEPTIONS:
     17         try:
     18             setattr(destination, exception_name,
     19                     getattr(source, exception_name))
     20         except AttributeError:
     21             # Under the django backend:
     22             # Django 1.3 does not have OperationalError and ProgrammingError.
     23             # Let's just mock these classes with the base DatabaseError.
     24             setattr(destination, exception_name,
     25                     getattr(source, 'DatabaseError'))
     26 
     27 
     28 class _GenericBackend(object):
     29     def __init__(self, database_module):
     30         self._database_module = database_module
     31         self._connection = None
     32         self._cursor = None
     33         self.rowcount = None
     34         _copy_exceptions(database_module, self)
     35 
     36 
     37     def connect(self, host=None, username=None, password=None, db_name=None):
     38         """
     39         This is assumed to enable autocommit.
     40         """
     41         raise NotImplementedError
     42 
     43 
     44     def disconnect(self):
     45         if self._connection:
     46             self._connection.close()
     47         self._connection = None
     48         self._cursor = None
     49 
     50 
     51     def execute(self, query, parameters=None):
     52         if parameters is None:
     53             parameters = ()
     54         self._cursor.execute(query, parameters)
     55         self.rowcount = self._cursor.rowcount
     56         return self._cursor.fetchall()
     57 
     58 
     59 class _MySqlBackend(_GenericBackend):
     60     def __init__(self):
     61         import MySQLdb
     62         super(_MySqlBackend, self).__init__(MySQLdb)
     63 
     64 
     65     @staticmethod
     66     def convert_boolean(boolean, conversion_dict):
     67         'Convert booleans to integer strings'
     68         return str(int(boolean))
     69 
     70 
     71     def connect(self, host=None, username=None, password=None, db_name=None):
     72         import MySQLdb.converters
     73         convert_dict = MySQLdb.converters.conversions
     74         convert_dict.setdefault(bool, self.convert_boolean)
     75 
     76         self._connection = self._database_module.connect(
     77             host=host, user=username, passwd=password, db=db_name,
     78             conv=convert_dict)
     79         self._connection.autocommit(True)
     80         self._cursor = self._connection.cursor()
     81 
     82 
     83 class _SqliteBackend(_GenericBackend):
     84     def __init__(self):
     85         try:
     86             from pysqlite2 import dbapi2
     87         except ImportError:
     88             from sqlite3 import dbapi2
     89         super(_SqliteBackend, self).__init__(dbapi2)
     90         self._last_insert_id_re = re.compile(r'\sLAST_INSERT_ID\(\)',
     91                                              re.IGNORECASE)
     92 
     93 
     94     def connect(self, host=None, username=None, password=None, db_name=None):
     95         self._connection = self._database_module.connect(db_name)
     96         self._connection.isolation_level = None # enable autocommit
     97         self._cursor = self._connection.cursor()
     98 
     99 
    100     def execute(self, query, parameters=None):
    101         # pysqlite2 uses paramstyle=qmark
    102         # TODO: make this more sophisticated if necessary
    103         query = query.replace('%s', '?')
    104         # pysqlite2 can't handle parameters=None (it throws a nonsense
    105         # exception)
    106         if parameters is None:
    107             parameters = ()
    108         # sqlite3 doesn't support MySQL's LAST_INSERT_ID().  Instead it has
    109         # something similar called LAST_INSERT_ROWID() that will do enough of
    110         # what we want (for our non-concurrent unittest use case).
    111         query = self._last_insert_id_re.sub(' LAST_INSERT_ROWID()', query)
    112         return super(_SqliteBackend, self).execute(query, parameters)
    113 
    114 
    115 class _DjangoBackend(_GenericBackend):
    116     def __init__(self):
    117         from django.db import backend, connection, transaction
    118         import django.db as django_db
    119         super(_DjangoBackend, self).__init__(django_db)
    120         self._django_connection = connection
    121         self._django_transaction = transaction
    122 
    123 
    124     def connect(self, host=None, username=None, password=None, db_name=None):
    125         self._connection = self._django_connection
    126         self._cursor = self._connection.cursor()
    127 
    128 
    129     def execute(self, query, parameters=None):
    130         try:
    131             return super(_DjangoBackend, self).execute(query,
    132                                                        parameters=parameters)
    133         finally:
    134             self._django_transaction.commit_unless_managed()
    135 
    136 
    137 _BACKEND_MAP = {
    138     'mysql': _MySqlBackend,
    139     'sqlite': _SqliteBackend,
    140     'django': _DjangoBackend,
    141 }
    142 
    143 
    144 class DatabaseConnection(object):
    145     """
    146     Generic wrapper for a database connection.  Supports both mysql and sqlite
    147     backends.
    148 
    149     Public attributes:
    150     * reconnect_enabled: if True, when an OperationalError occurs the class will
    151       try to reconnect to the database automatically.
    152     * reconnect_delay_sec: seconds to wait before reconnecting
    153     * max_reconnect_attempts: maximum number of time to try reconnecting before
    154       giving up.  Setting to RECONNECT_FOREVER removes the limit.
    155     * rowcount - will hold cursor.rowcount after each call to execute().
    156     * global_config_section - the section in which to find DB information. this
    157       should be passed to the constructor, not set later, and may be None, in
    158       which case information must be passed to connect().
    159     * debug - if set True, all queries will be printed before being executed
    160     """
    161     _DATABASE_ATTRIBUTES = ('db_type', 'host', 'username', 'password',
    162                             'db_name')
    163 
    164     def __init__(self, global_config_section=None, debug=False):
    165         self.global_config_section = global_config_section
    166         self._backend = None
    167         self.rowcount = None
    168         self.debug = debug
    169 
    170         # reconnect defaults
    171         self.reconnect_enabled = True
    172         self.reconnect_delay_sec = 20
    173         self.max_reconnect_attempts = 10
    174 
    175         self._read_options()
    176 
    177 
    178     def _get_option(self, name, provided_value, use_afe_setting=False):
    179         """Get value of given option from global config.
    180 
    181         @param name: Name of the config.
    182         @param provided_value: Value being provided to override the one from
    183                                global config.
    184         @param use_afe_setting: Force to use the settings in AFE, default is
    185                                 False.
    186         """
    187         # TODO(dshi): This function returns the option value depends on multiple
    188         # conditions. The value of `provided_value` has highest priority, then
    189         # the code checks if use_afe_setting is True, if that's the case, force
    190         # to use settings in AUTOTEST_WEB. At last the value is retrieved from
    191         # specified global config section.
    192         # The logic is too complicated for a generic function named like
    193         # _get_option. Ideally we want to make it clear from caller that it
    194         # wants to get database credential from one of the 3 ways:
    195         # 1. Use the credential from given config section
    196         # 2. Use the credential from AUTOTEST_WEB section
    197         # 3. Use the credential provided by caller.
    198         if provided_value is not None:
    199             return provided_value
    200         section = ('AUTOTEST_WEB' if use_afe_setting else
    201                    self.global_config_section)
    202         if section:
    203             global_config_name = _GLOBAL_CONFIG_NAMES.get(name, name)
    204             return global_config.global_config.get_config_value(
    205                     section, global_config_name)
    206 
    207         return getattr(self, name, None)
    208 
    209 
    210     def _read_options(self, db_type=None, host=None, username=None,
    211                       password=None, db_name=None):
    212         """Read database information from global config.
    213 
    214         Unless any parameter is specified a value, the connection will use
    215         database name from given configure section (self.global_config_section),
    216         and database credential from AFE database settings (AUTOTEST_WEB).
    217 
    218         @param db_type: database type, default to None.
    219         @param host: database hostname, default to None.
    220         @param username: user name for database connection, default to None.
    221         @param password: database password, default to None.
    222         @param db_name: database name, default to None.
    223         """
    224         self.db_name = self._get_option('db_name', db_name)
    225         use_afe_setting = not bool(db_type or host or username or password)
    226 
    227         # Database credential can be provided by the caller, as passed in from
    228         # function connect.
    229         self.db_type = self._get_option('db_type', db_type, use_afe_setting)
    230         self.host = self._get_option('host', host, use_afe_setting)
    231         self.username = self._get_option('username', username, use_afe_setting)
    232         self.password = self._get_option('password', password, use_afe_setting)
    233 
    234 
    235     def _get_backend(self, db_type):
    236         if db_type not in _BACKEND_MAP:
    237             raise ValueError('Invalid database type: %s, should be one of %s' %
    238                              (db_type, ', '.join(_BACKEND_MAP.keys())))
    239         backend_class = _BACKEND_MAP[db_type]
    240         return backend_class()
    241 
    242 
    243     def _reached_max_attempts(self, num_attempts):
    244         return (self.max_reconnect_attempts is not RECONNECT_FOREVER and
    245                 num_attempts > self.max_reconnect_attempts)
    246 
    247 
    248     def _is_reconnect_enabled(self, supplied_param):
    249         if supplied_param is not None:
    250             return supplied_param
    251         return self.reconnect_enabled
    252 
    253 
    254     def _connect_backend(self, try_reconnecting=None):
    255         num_attempts = 0
    256         while True:
    257             try:
    258                 self._backend.connect(host=self.host, username=self.username,
    259                                       password=self.password,
    260                                       db_name=self.db_name)
    261                 return
    262             except self._backend.OperationalError:
    263                 num_attempts += 1
    264                 if not self._is_reconnect_enabled(try_reconnecting):
    265                     raise
    266                 if self._reached_max_attempts(num_attempts):
    267                     raise
    268                 traceback.print_exc()
    269                 print ("Can't connect to database; reconnecting in %s sec" %
    270                        self.reconnect_delay_sec)
    271                 time.sleep(self.reconnect_delay_sec)
    272                 self.disconnect()
    273 
    274 
    275     def connect(self, db_type=None, host=None, username=None, password=None,
    276                 db_name=None, try_reconnecting=None):
    277         """
    278         Parameters passed to this function will override defaults from global
    279         config.  try_reconnecting, if passed, will override
    280         self.reconnect_enabled.
    281         """
    282         self.disconnect()
    283         self._read_options(db_type, host, username, password, db_name)
    284 
    285         self._backend = self._get_backend(self.db_type)
    286         _copy_exceptions(self._backend, self)
    287         self._connect_backend(try_reconnecting)
    288 
    289 
    290     def disconnect(self):
    291         if self._backend:
    292             self._backend.disconnect()
    293 
    294 
    295     def execute(self, query, parameters=None, try_reconnecting=None):
    296         """
    297         Execute a query and return cursor.fetchall(). try_reconnecting, if
    298         passed, will override self.reconnect_enabled.
    299         """
    300         if self.debug:
    301             print 'Executing %s, %s' % (query, parameters)
    302         # _connect_backend() contains a retry loop, so don't loop here
    303         try:
    304             results = self._backend.execute(query, parameters)
    305         except self._backend.OperationalError:
    306             if not self._is_reconnect_enabled(try_reconnecting):
    307                 raise
    308             traceback.print_exc()
    309             print ("MYSQL connection died; reconnecting")
    310             self.disconnect()
    311             self._connect_backend(try_reconnecting)
    312             results = self._backend.execute(query, parameters)
    313 
    314         self.rowcount = self._backend.rowcount
    315         return results
    316 
    317 
    318     def get_database_info(self):
    319         return dict((attribute, getattr(self, attribute))
    320                     for attribute in self._DATABASE_ATTRIBUTES)
    321 
    322 
    323     @classmethod
    324     def get_test_database(cls, file_path=':memory:', **constructor_kwargs):
    325         """
    326         Factory method returning a DatabaseConnection for a temporary in-memory
    327         database.
    328         """
    329         database = cls(**constructor_kwargs)
    330         database.reconnect_enabled = False
    331         database.connect(db_type='sqlite', db_name=file_path)
    332         return database
    333 
    334 
    335 class TranslatingDatabase(DatabaseConnection):
    336     """
    337     Database wrapper than applies arbitrary substitution regexps to each query
    338     string.  Useful for SQLite testing.
    339     """
    340     def __init__(self, translators):
    341         """
    342         @param translation_regexps: list of callables to apply to each query
    343                 string (in order).  Each accepts a query string and returns a
    344                 (possibly) modified query string.
    345         """
    346         super(TranslatingDatabase, self).__init__()
    347         self._translators = translators
    348 
    349 
    350     def execute(self, query, parameters=None, try_reconnecting=None):
    351         for translator in self._translators:
    352             query = translator(query)
    353         return super(TranslatingDatabase, self).execute(
    354                 query, parameters=parameters, try_reconnecting=try_reconnecting)
    355 
    356 
    357     @classmethod
    358     def make_regexp_translator(cls, search_re, replace_str):
    359         """
    360         Returns a translator that calls re.sub() on the query with the given
    361         search and replace arguments.
    362         """
    363         def translator(query):
    364             return re.sub(search_re, replace_str, query)
    365         return translator
    366