# pylint: disable-msg=C0111

"""Generic database connection wrapper with pluggable backends.

Supports MySQL (MySQLdb), SQLite (pysqlite2 / sqlite3) and the Django
connection, and can transparently reconnect when an OperationalError occurs.
"""

import re
import time
import traceback

import common
from autotest_lib.client.common_lib import global_config

# Sentinel for DatabaseConnection.max_reconnect_attempts meaning "no limit".
RECONNECT_FOREVER = object()

# DB-API exception class names mirrored from the backend module onto the
# wrapper objects, so callers can write e.g. "except db.OperationalError"
# regardless of which backend is in use.
_DB_EXCEPTIONS = ('DatabaseError', 'OperationalError', 'ProgrammingError')
# Option names whose global config key differs from the option name; any
# option not listed here is looked up in the global config under its own name.
_GLOBAL_CONFIG_NAMES = {
    'username' : 'user',
    'db_name' : 'database',
}


def _copy_exceptions(source, destination):
    """Copy the DB-API exception classes from source onto destination.

    @param source: object (module or backend) exposing the exception classes
            named in _DB_EXCEPTIONS.
    @param destination: object that receives them as attributes.
    """
    for exception_name in _DB_EXCEPTIONS:
        try:
            setattr(destination, exception_name,
                    getattr(source, exception_name))
        except AttributeError:
            # Under the django backend:
            # Django 1.3 does not have OperationalError and ProgrammingError.
            # Let's just mock these classes with the base DatabaseError.
            setattr(destination, exception_name,
                    getattr(source, 'DatabaseError'))


class _GenericBackend(object):
    """Base class wrapping a DB-API module's connection and cursor.

    Subclasses implement connect(); execute() and disconnect() are shared.
    """

    def __init__(self, database_module):
        """
        @param database_module: DB-API module (or module-like object) used to
                open connections and to source the exception classes.
        """
        self._database_module = database_module
        self._connection = None
        self._cursor = None
        # cursor.rowcount from the most recent execute().
        self.rowcount = None
        _copy_exceptions(database_module, self)


    def connect(self, host=None, username=None, password=None, db_name=None):
        """
        This is assumed to enable autocommit.
        """
        raise NotImplementedError


    def disconnect(self):
        """Close the connection, if any, and drop the cursor."""
        if self._connection:
            self._connection.close()
        self._connection = None
        self._cursor = None


    def execute(self, query, parameters=None):
        """Run query on the current cursor and return cursor.fetchall().

        @param query: SQL string using the backend's parameter style.
        @param parameters: query parameters; None means no parameters.
        """
        if parameters is None:
            parameters = ()
        self._cursor.execute(query, parameters)
        self.rowcount = self._cursor.rowcount
        return self._cursor.fetchall()


class _MySqlBackend(_GenericBackend):
    """Backend built on the MySQLdb module."""

    def __init__(self):
        import MySQLdb
        super(_MySqlBackend, self).__init__(MySQLdb)


    @staticmethod
    def convert_boolean(boolean, conversion_dict):
        'Convert booleans to integer strings'
        # conversion_dict is accepted to fit the converter call signature
        # but is not used.
        return str(int(boolean))


    def connect(self, host=None, username=None, password=None, db_name=None):
        """Open a MySQLdb connection with autocommit enabled."""
        import MySQLdb.converters
        # Work on a copy so registering the bool converter does not mutate
        # MySQLdb's module-global conversion table for every other user of
        # MySQLdb in this process.
        convert_dict = dict(MySQLdb.converters.conversions)
        convert_dict.setdefault(bool, self.convert_boolean)

        self._connection = self._database_module.connect(
                host=host, user=username, passwd=password, db=db_name,
                conv=convert_dict)
        self._connection.autocommit(True)
        self._cursor = self._connection.cursor()


class _SqliteBackend(_GenericBackend):
    """Backend built on pysqlite2 (preferred) or the stdlib sqlite3 module."""

    def __init__(self):
        try:
            from pysqlite2 import dbapi2
        except ImportError:
            from sqlite3 import dbapi2
        super(_SqliteBackend, self).__init__(dbapi2)
        self._last_insert_id_re = re.compile(r'\sLAST_INSERT_ID\(\)',
                                             re.IGNORECASE)


    def connect(self, host=None, username=None, password=None, db_name=None):
        """Open the SQLite database file db_name; host/credentials ignored."""
        self._connection = self._database_module.connect(db_name)
        self._connection.isolation_level = None # enable autocommit
        self._cursor = self._connection.cursor()


    def execute(self, query, parameters=None):
        """Translate a MySQL-style query to SQLite and execute it."""
        # pysqlite2 uses paramstyle=qmark
        # TODO: make this more sophisticated if necessary
        query = query.replace('%s', '?')
        # pysqlite2 can't handle parameters=None (it throws a nonsense
        # exception)
        if parameters is None:
            parameters = ()
        # sqlite3 doesn't support MySQL's LAST_INSERT_ID().  Instead it has
        # something similar called LAST_INSERT_ROWID() that will do enough of
        # what we want (for our non-concurrent unittest use case).
        query = self._last_insert_id_re.sub(' LAST_INSERT_ROWID()', query)
        return super(_SqliteBackend, self).execute(query, parameters)


class _DjangoBackend(_GenericBackend):
    """Backend that reuses Django's configured database connection."""

    def __init__(self):
        from django.db import connection, transaction
        import django.db as django_db
        super(_DjangoBackend, self).__init__(django_db)
        self._django_connection = connection
        self._django_transaction = transaction


    def connect(self, host=None, username=None, password=None, db_name=None):
        """Use Django's connection; all arguments are ignored."""
        self._connection = self._django_connection
        self._cursor = self._connection.cursor()


    def execute(self, query, parameters=None):
        """Execute query, then call commit_unless_managed() even on error."""
        try:
            return super(_DjangoBackend, self).execute(query,
                                                       parameters=parameters)
        finally:
            self._django_transaction.commit_unless_managed()


# db_type string -> backend class.
_BACKEND_MAP = {
    'mysql': _MySqlBackend,
    'sqlite': _SqliteBackend,
    'django': _DjangoBackend,
}


class DatabaseConnection(object):
    """
    Generic wrapper for a database connection.  Supports mysql, sqlite and
    django backends.

    Public attributes:
    * reconnect_enabled: if True, when an OperationalError occurs the class will
      try to reconnect to the database automatically.
    * reconnect_delay_sec: seconds to wait before reconnecting
    * max_reconnect_attempts: maximum number of times to try reconnecting
      before giving up.  Setting to RECONNECT_FOREVER removes the limit.
    * rowcount - will hold cursor.rowcount after each call to execute().
    * global_config_section - the section in which to find DB information. this
      should be passed to the constructor, not set later, and may be None, in
      which case information must be passed to connect().
    * debug - if set True, all queries will be printed before being executed
    """
    # Attributes reported by get_database_info().
    _DATABASE_ATTRIBUTES = ('db_type', 'host', 'username', 'password',
                            'db_name')

    def __init__(self, global_config_section=None, debug=False):
        """
        @param global_config_section: global config section holding the
                database name; None to rely on values passed to connect().
        @param debug: if True, print every query before executing it.
        """
        self.global_config_section = global_config_section
        self._backend = None
        self.rowcount = None
        self.debug = debug

        # reconnect defaults
        self.reconnect_enabled = True
        self.reconnect_delay_sec = 20
        self.max_reconnect_attempts = 10

        self._read_options()


    def _get_option(self, name, provided_value, use_afe_setting=False):
        """Get value of given option from global config.

        @param name: Name of the config.
        @param provided_value: Value being provided to override the one from
                global config.
        @param use_afe_setting: Force to use the settings in AFE, default is
                False.

        @return: provided_value if not None; otherwise the value from the
                selected global config section; otherwise (no section) the
                current value of the attribute `name`, or None.
        """
        # TODO(dshi): This function returns the option value depends on multiple
        # conditions. The value of `provided_value` has highest priority, then
        # the code checks if use_afe_setting is True, if that's the case, force
        # to use settings in AUTOTEST_WEB. At last the value is retrieved from
        # specified global config section.
        # The logic is too complicated for a generic function named like
        # _get_option. Ideally we want to make it clear from caller that it
        # wants to get database credential from one of the 3 ways:
        # 1. Use the credential from given config section
        # 2. Use the credential from AUTOTEST_WEB section
        # 3. Use the credential provided by caller.
        if provided_value is not None:
            return provided_value
        section = ('AUTOTEST_WEB' if use_afe_setting else
                   self.global_config_section)
        if section:
            global_config_name = _GLOBAL_CONFIG_NAMES.get(name, name)
            return global_config.global_config.get_config_value(
                    section, global_config_name)

        return getattr(self, name, None)


    def _read_options(self, db_type=None, host=None, username=None,
                      password=None, db_name=None):
        """Read database information from global config.

        Unless a value is given for at least one of db_type, host, username
        or password, the connection uses the database name from the given
        config section (self.global_config_section) and the database
        credentials from the AFE database settings (AUTOTEST_WEB).

        @param db_type: database type, default to None.
        @param host: database hostname, default to None.
        @param username: user name for database connection, default to None.
        @param password: database password, default to None.
        @param db_name: database name, default to None.
        """
        self.db_name = self._get_option('db_name', db_name)
        use_afe_setting = not bool(db_type or host or username or password)

        # Database credential can be provided by the caller, as passed in from
        # function connect.
        self.db_type = self._get_option('db_type', db_type, use_afe_setting)
        self.host = self._get_option('host', host, use_afe_setting)
        self.username = self._get_option('username', username, use_afe_setting)
        self.password = self._get_option('password', password, use_afe_setting)


    def _get_backend(self, db_type):
        """Instantiate the backend class registered for db_type.

        @raises ValueError: if db_type is not a key of _BACKEND_MAP.
        """
        if db_type not in _BACKEND_MAP:
            raise ValueError('Invalid database type: %s, should be one of %s' %
                             (db_type, ', '.join(sorted(_BACKEND_MAP))))
        backend_class = _BACKEND_MAP[db_type]
        return backend_class()


    def _reached_max_attempts(self, num_attempts):
        """True once num_attempts exceeds max_reconnect_attempts."""
        return (self.max_reconnect_attempts is not RECONNECT_FOREVER and
                num_attempts > self.max_reconnect_attempts)


    def _is_reconnect_enabled(self, supplied_param):
        """Return supplied_param if given, else self.reconnect_enabled."""
        if supplied_param is not None:
            return supplied_param
        return self.reconnect_enabled


    def _connect_backend(self, try_reconnecting=None):
        """Connect the current backend, retrying on OperationalError.

        Re-raises the OperationalError when reconnecting is disabled or the
        attempt limit is reached; otherwise sleeps reconnect_delay_sec
        between attempts.
        """
        num_attempts = 0
        while True:
            try:
                self._backend.connect(host=self.host, username=self.username,
                                      password=self.password,
                                      db_name=self.db_name)
                return
            except self._backend.OperationalError:
                num_attempts += 1
                if not self._is_reconnect_enabled(try_reconnecting):
                    raise
                if self._reached_max_attempts(num_attempts):
                    raise
                traceback.print_exc()
                print("Can't connect to database; reconnecting in %s sec" %
                      self.reconnect_delay_sec)
                time.sleep(self.reconnect_delay_sec)
                self.disconnect()


    def connect(self, db_type=None, host=None, username=None, password=None,
                db_name=None, try_reconnecting=None):
        """
        Parameters passed to this function will override defaults from global
        config.  try_reconnecting, if passed, will override
        self.reconnect_enabled.
        """
        self.disconnect()
        self._read_options(db_type, host, username, password, db_name)

        self._backend = self._get_backend(self.db_type)
        # Expose the backend's exception classes on this object too.
        _copy_exceptions(self._backend, self)
        self._connect_backend(try_reconnecting)


    def disconnect(self):
        """Disconnect the backend, if one has been created."""
        if self._backend:
            self._backend.disconnect()


    def execute(self, query, parameters=None, try_reconnecting=None):
        """
        Execute a query and return cursor.fetchall(). try_reconnecting, if
        passed, will override self.reconnect_enabled.
        """
        if self.debug:
            print('Executing %s, %s' % (query, parameters))
        # _connect_backend() contains a retry loop, so don't loop here
        try:
            results = self._backend.execute(query, parameters)
        except self._backend.OperationalError:
            if not self._is_reconnect_enabled(try_reconnecting):
                raise
            traceback.print_exc()
            print("MYSQL connection died; reconnecting")
            self.disconnect()
            self._connect_backend(try_reconnecting)
            results = self._backend.execute(query, parameters)

        self.rowcount = self._backend.rowcount
        return results


    def get_database_info(self):
        """Return a dict of the _DATABASE_ATTRIBUTES values."""
        return dict((attribute, getattr(self, attribute))
                    for attribute in self._DATABASE_ATTRIBUTES)


    @classmethod
    def get_test_database(cls, file_path=':memory:', **constructor_kwargs):
        """
        Factory method returning a DatabaseConnection for a SQLite test
        database (in-memory by default) with reconnecting disabled.

        @param file_path: SQLite database path; ':memory:' for in-memory.
        @param constructor_kwargs: forwarded to the class constructor.
        """
        database = cls(**constructor_kwargs)
        database.reconnect_enabled = False
        database.connect(db_type='sqlite', db_name=file_path)
        return database


class TranslatingDatabase(DatabaseConnection):
    """
    Database wrapper that applies arbitrary substitution regexps to each query
    string.  Useful for SQLite testing.
    """
    def __init__(self, translators, **constructor_kwargs):
        """
        @param translators: list of callables to apply to each query string
                (in order).  Each accepts a query string and returns a
                (possibly) modified query string.
        @param constructor_kwargs: optional DatabaseConnection constructor
                arguments (global_config_section, debug).
        """
        super(TranslatingDatabase, self).__init__(**constructor_kwargs)
        self._translators = translators


    def execute(self, query, parameters=None, try_reconnecting=None):
        """Apply every translator to query in order, then execute it."""
        for translator in self._translators:
            query = translator(query)
        return super(TranslatingDatabase, self).execute(
                query, parameters=parameters, try_reconnecting=try_reconnecting)


    @classmethod
    def make_regexp_translator(cls, search_re, replace_str):
        """
        Returns a translator that calls re.sub() on the query with the given
        search and replace arguments.
        """
        def translator(query):
            return re.sub(search_re, replace_str, query)
        return translator