Home | History | Annotate | Download | only in common
      1 // Copyright (c) 2010 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "chrome/common/sqlite_utils.h"
      6 
      7 #include <list>
      8 
      9 #include "base/file_path.h"
     10 #include "base/lazy_instance.h"
     11 #include "base/logging.h"
     12 #include "base/stl_util-inl.h"
     13 #include "base/string16.h"
     14 #include "base/synchronization/lock.h"
     15 
     16 // The vanilla error handler implements the common fucntionality for all the
     17 // error handlers. Specialized error handlers are expected to only override
     18 // the Handler() function.
     19 class VanillaSQLErrorHandler : public SQLErrorHandler {
     20  public:
     21   VanillaSQLErrorHandler() : error_(SQLITE_OK) {
     22   }
     23   virtual int GetLastError() const {
     24     return error_;
     25   }
     26  protected:
     27   int error_;
     28 };
     29 
     30 class DebugSQLErrorHandler: public VanillaSQLErrorHandler {
     31  public:
     32   virtual int HandleError(int error, sqlite3* db) {
     33     error_ = error;
     34     NOTREACHED() << "sqlite error " << error
     35                  << " db " << static_cast<void*>(db);
     36     return error;
     37   }
     38 };
     39 
     40 class ReleaseSQLErrorHandler : public VanillaSQLErrorHandler {
     41  public:
     42   virtual int HandleError(int error, sqlite3* db) {
     43     error_ = error;
     44     // Used to have a CHECK here. Got lots of crashes.
     45     return error;
     46   }
     47 };
     48 
     49 // The default error handler factory is also in charge of managing the
     50 // lifetime of the error objects. This object is multi-thread safe.
     51 class DefaultSQLErrorHandlerFactory : public SQLErrorHandlerFactory {
     52  public:
     53   ~DefaultSQLErrorHandlerFactory() {
     54     STLDeleteContainerPointers(errors_.begin(), errors_.end());
     55   }
     56 
     57   virtual SQLErrorHandler* Make() {
     58     SQLErrorHandler* handler;
     59 #ifndef NDEBUG
     60     handler = new DebugSQLErrorHandler;
     61 #else
     62     handler = new ReleaseSQLErrorHandler;
     63 #endif  // NDEBUG
     64     AddHandler(handler);
     65     return handler;
     66   }
     67 
     68  private:
     69   void AddHandler(SQLErrorHandler* handler) {
     70     base::AutoLock lock(lock_);
     71     errors_.push_back(handler);
     72   }
     73 
     74   typedef std::list<SQLErrorHandler*> ErrorList;
     75   ErrorList errors_;
     76   base::Lock lock_;
     77 };
     78 
     79 static base::LazyInstance<DefaultSQLErrorHandlerFactory>
     80     g_default_sql_error_handler_factory(base::LINKER_INITIALIZED);
     81 
     82 SQLErrorHandlerFactory* GetErrorHandlerFactory() {
     83   // TODO(cpu): Testing needs to override the error handler.
     84   // Destruction of DefaultSQLErrorHandlerFactory handled by at_exit manager.
     85   return g_default_sql_error_handler_factory.Pointer();
     86 }
     87 
     88 namespace sqlite_utils {
     89 
     90 int OpenSqliteDb(const FilePath& filepath, sqlite3** database) {
     91 #if defined(OS_WIN)
     92   // We want the default encoding to always be UTF-8, so we use the
     93   // 8-bit version of open().
     94   return sqlite3_open(WideToUTF8(filepath.value()).c_str(), database);
     95 #elif defined(OS_POSIX)
     96   return sqlite3_open(filepath.value().c_str(), database);
     97 #endif
     98 }
     99 
    100 bool DoesSqliteTableExist(sqlite3* db,
    101                           const char* db_name,
    102                           const char* table_name) {
    103   // sqlite doesn't allow binding parameters as table names, so we have to
    104   // manually construct the sql
    105   std::string sql("SELECT name FROM ");
    106   if (db_name && db_name[0]) {
    107     sql.append(db_name);
    108     sql.push_back('.');
    109   }
    110   sql.append("sqlite_master WHERE type='table' AND name=?");
    111 
    112   SQLStatement statement;
    113   if (statement.prepare(db, sql.c_str()) != SQLITE_OK)
    114     return false;
    115 
    116   if (statement.bind_text(0, table_name) != SQLITE_OK)
    117     return false;
    118 
    119   // we only care about if this matched a row, not the actual data
    120   return sqlite3_step(statement.get()) == SQLITE_ROW;
    121 }
    122 
    123 bool DoesSqliteColumnExist(sqlite3* db,
    124                            const char* database_name,
    125                            const char* table_name,
    126                            const char* column_name,
    127                            const char* column_type) {
    128   SQLStatement s;
    129   std::string sql;
    130   sql.append("PRAGMA ");
    131   if (database_name && database_name[0]) {
    132     // optional database name specified
    133     sql.append(database_name);
    134     sql.push_back('.');
    135   }
    136   sql.append("TABLE_INFO(");
    137   sql.append(table_name);
    138   sql.append(")");
    139 
    140   if (s.prepare(db, sql.c_str()) != SQLITE_OK)
    141     return false;
    142 
    143   while (s.step() == SQLITE_ROW) {
    144     if (!s.column_string(1).compare(column_name)) {
    145       if (column_type && column_type[0])
    146         return !s.column_string(2).compare(column_type);
    147       return true;
    148     }
    149   }
    150   return false;
    151 }
    152 
    153 bool DoesSqliteTableHaveRow(sqlite3* db, const char* table_name) {
    154   SQLStatement s;
    155   std::string b;
    156   b.append("SELECT * FROM ");
    157   b.append(table_name);
    158 
    159   if (s.prepare(db, b.c_str()) != SQLITE_OK)
    160     return false;
    161 
    162   return s.step() == SQLITE_ROW;
    163 }
    164 
    165 }  // namespace sqlite_utils
    166 
    167 SQLTransaction::SQLTransaction(sqlite3* db) : db_(db), began_(false) {
    168 }
    169 
    170 SQLTransaction::~SQLTransaction() {
    171   if (began_) {
    172     Rollback();
    173   }
    174 }
    175 
    176 int SQLTransaction::BeginCommand(const char* command) {
    177   int rv = SQLITE_ERROR;
    178   if (!began_ && db_) {
    179     rv = sqlite3_exec(db_, command, NULL, NULL, NULL);
    180     began_ = (rv == SQLITE_OK);
    181   }
    182   return rv;
    183 }
    184 
    185 int SQLTransaction::EndCommand(const char* command) {
    186   int rv = SQLITE_ERROR;
    187   if (began_ && db_) {
    188     rv = sqlite3_exec(db_, command, NULL, NULL, NULL);
    189     began_ = (rv != SQLITE_OK);
    190   }
    191   return rv;
    192 }
    193 
    194 SQLNestedTransactionSite::~SQLNestedTransactionSite() {
    195   DCHECK(!top_transaction_);
    196 }
    197 
    198 void SQLNestedTransactionSite::SetTopTransaction(SQLNestedTransaction* top) {
    199   DCHECK(!top || !top_transaction_);
    200   top_transaction_ = top;
    201 }
    202 
    203 SQLNestedTransaction::SQLNestedTransaction(SQLNestedTransactionSite* site)
    204   : SQLTransaction(site->GetSqlite3DB()),
    205     needs_rollback_(false),
    206     site_(site) {
    207   DCHECK(site);
    208   if (site->GetTopTransaction() == NULL) {
    209     site->SetTopTransaction(this);
    210   }
    211 }
    212 
    213 SQLNestedTransaction::~SQLNestedTransaction() {
    214   if (began_) {
    215     Rollback();
    216   }
    217   if (site_->GetTopTransaction() == this) {
    218     site_->SetTopTransaction(NULL);
    219   }
    220 }
    221 
    222 int SQLNestedTransaction::BeginCommand(const char* command) {
    223   DCHECK(db_);
    224   DCHECK(site_ && site_->GetTopTransaction());
    225   if (!db_ || began_) {
    226     return SQLITE_ERROR;
    227   }
    228   if (site_->GetTopTransaction() == this) {
    229     int rv = sqlite3_exec(db_, command, NULL, NULL, NULL);
    230     began_ = (rv == SQLITE_OK);
    231     if (began_) {
    232       site_->OnBegin();
    233     }
    234     return rv;
    235   } else {
    236     if (site_->GetTopTransaction()->needs_rollback_) {
    237       return SQLITE_ERROR;
    238     }
    239     began_ = true;
    240     return SQLITE_OK;
    241   }
    242 }
    243 
    244 int SQLNestedTransaction::EndCommand(const char* command) {
    245   DCHECK(db_);
    246   DCHECK(site_ && site_->GetTopTransaction());
    247   if (!db_ || !began_) {
    248     return SQLITE_ERROR;
    249   }
    250   if (site_->GetTopTransaction() == this) {
    251     if (needs_rollback_) {
    252       sqlite3_exec(db_, "ROLLBACK", NULL, NULL, NULL);
    253       began_ = false;  // reset so we don't try to rollback or call
    254                        // OnRollback() again
    255       site_->OnRollback();
    256       return SQLITE_ERROR;
    257     } else {
    258       int rv = sqlite3_exec(db_, command, NULL, NULL, NULL);
    259       began_ = (rv != SQLITE_OK);
    260       if (strcmp(command, "ROLLBACK") == 0) {
    261         began_ = false;  // reset so we don't try to rollbck or call
    262                          // OnRollback() again
    263         site_->OnRollback();
    264       } else {
    265         DCHECK(strcmp(command, "COMMIT") == 0);
    266         if (rv == SQLITE_OK) {
    267           site_->OnCommit();
    268         }
    269       }
    270       return rv;
    271     }
    272   } else {
    273     if (strcmp(command, "ROLLBACK") == 0) {
    274       site_->GetTopTransaction()->needs_rollback_ = true;
    275     }
    276     began_ = false;
    277     return SQLITE_OK;
    278   }
    279 }
    280 
    281 int SQLStatement::prepare(sqlite3* db, const char* sql, int sql_len) {
    282   DCHECK(!stmt_);
    283   int rv = sqlite3_prepare_v2(db, sql, sql_len, &stmt_, NULL);
    284   if (rv != SQLITE_OK) {
    285     SQLErrorHandler* error_handler = GetErrorHandlerFactory()->Make();
    286     return error_handler->HandleError(rv, db);
    287   }
    288   return rv;
    289 }
    290 
    291 int SQLStatement::step() {
    292   DCHECK(stmt_);
    293   int status = sqlite3_step(stmt_);
    294   if ((status == SQLITE_ROW) || (status == SQLITE_DONE))
    295     return status;
    296   // We got a problem.
    297   SQLErrorHandler* error_handler = GetErrorHandlerFactory()->Make();
    298   return error_handler->HandleError(status, db_handle());
    299 }
    300 
    301 int SQLStatement::reset() {
    302   DCHECK(stmt_);
    303   return sqlite3_reset(stmt_);
    304 }
    305 
    306 sqlite_int64 SQLStatement::last_insert_rowid() {
    307   DCHECK(stmt_);
    308   return sqlite3_last_insert_rowid(db_handle());
    309 }
    310 
    311 int SQLStatement::changes() {
    312   DCHECK(stmt_);
    313   return sqlite3_changes(db_handle());
    314 }
    315 
    316 sqlite3* SQLStatement::db_handle() {
    317   DCHECK(stmt_);
    318   return sqlite3_db_handle(stmt_);
    319 }
    320 
    321 int SQLStatement::bind_parameter_count() {
    322   DCHECK(stmt_);
    323   return sqlite3_bind_parameter_count(stmt_);
    324 }
    325 
    326 int SQLStatement::bind_blob(int index, std::vector<unsigned char>* blob) {
    327   if (blob) {
    328     const void* value = blob->empty() ? NULL : &(*blob)[0];
    329     int len = static_cast<int>(blob->size());
    330     return bind_blob(index, value, len);
    331   } else {
    332     return bind_null(index);
    333   }
    334 }
    335 
    336 int SQLStatement::bind_blob(int index, const void* value, int value_len) {
    337    return bind_blob(index, value, value_len, SQLITE_TRANSIENT);
    338 }
    339 
    340 int SQLStatement::bind_blob(int index, const void* value, int value_len,
    341                             Function dtor) {
    342   DCHECK(stmt_);
    343   return sqlite3_bind_blob(stmt_, index + 1, value, value_len, dtor);
    344 }
    345 
    346 int SQLStatement::bind_double(int index, double value) {
    347   DCHECK(stmt_);
    348   return sqlite3_bind_double(stmt_, index + 1, value);
    349 }
    350 
    351 int SQLStatement::bind_bool(int index, bool value) {
    352   DCHECK(stmt_);
    353   return sqlite3_bind_int(stmt_, index + 1, value);
    354 }
    355 
    356 int SQLStatement::bind_int(int index, int value) {
    357   DCHECK(stmt_);
    358   return sqlite3_bind_int(stmt_, index + 1, value);
    359 }
    360 
    361 int SQLStatement::bind_int64(int index, sqlite_int64 value) {
    362   DCHECK(stmt_);
    363   return sqlite3_bind_int64(stmt_, index + 1, value);
    364 }
    365 
    366 int SQLStatement::bind_null(int index) {
    367   DCHECK(stmt_);
    368   return sqlite3_bind_null(stmt_, index + 1);
    369 }
    370 
    371 int SQLStatement::bind_text(int index, const char* value, int value_len,
    372               Function dtor) {
    373   DCHECK(stmt_);
    374   return sqlite3_bind_text(stmt_, index + 1, value, value_len, dtor);
    375 }
    376 
    377 int SQLStatement::bind_text16(int index, const char16* value, int value_len,
    378                 Function dtor) {
    379   DCHECK(stmt_);
    380   value_len *= sizeof(char16);
    381   return sqlite3_bind_text16(stmt_, index + 1, value, value_len, dtor);
    382 }
    383 
    384 int SQLStatement::bind_value(int index, const sqlite3_value* value) {
    385   DCHECK(stmt_);
    386   return sqlite3_bind_value(stmt_, index + 1, value);
    387 }
    388 
    389 int SQLStatement::column_count() {
    390   DCHECK(stmt_);
    391   return sqlite3_column_count(stmt_);
    392 }
    393 
    394 int SQLStatement::column_type(int index) {
    395   DCHECK(stmt_);
    396   return sqlite3_column_type(stmt_, index);
    397 }
    398 
    399 const void* SQLStatement::column_blob(int index) {
    400   DCHECK(stmt_);
    401   return sqlite3_column_blob(stmt_, index);
    402 }
    403 
    404 bool SQLStatement::column_blob_as_vector(int index,
    405                                          std::vector<unsigned char>* blob) {
    406   DCHECK(stmt_);
    407   const void* p = column_blob(index);
    408   size_t len = column_bytes(index);
    409   blob->resize(len);
    410   if (blob->size() != len) {
    411     return false;
    412   }
    413   if (len > 0)
    414     memcpy(&(blob->front()), p, len);
    415   return true;
    416 }
    417 
    418 bool SQLStatement::column_blob_as_string(int index, std::string* blob) {
    419   DCHECK(stmt_);
    420   const void* p = column_blob(index);
    421   size_t len = column_bytes(index);
    422   blob->resize(len);
    423   if (blob->size() != len) {
    424     return false;
    425   }
    426   blob->assign(reinterpret_cast<const char*>(p), len);
    427   return true;
    428 }
    429 
    430 int SQLStatement::column_bytes(int index) {
    431   DCHECK(stmt_);
    432   return sqlite3_column_bytes(stmt_, index);
    433 }
    434 
    435 int SQLStatement::column_bytes16(int index) {
    436   DCHECK(stmt_);
    437   return sqlite3_column_bytes16(stmt_, index);
    438 }
    439 
    440 double SQLStatement::column_double(int index) {
    441   DCHECK(stmt_);
    442   return sqlite3_column_double(stmt_, index);
    443 }
    444 
    445 bool SQLStatement::column_bool(int index) {
    446   DCHECK(stmt_);
    447   return sqlite3_column_int(stmt_, index) ? true : false;
    448 }
    449 
    450 int SQLStatement::column_int(int index) {
    451   DCHECK(stmt_);
    452   return sqlite3_column_int(stmt_, index);
    453 }
    454 
    455 sqlite_int64 SQLStatement::column_int64(int index) {
    456   DCHECK(stmt_);
    457   return sqlite3_column_int64(stmt_, index);
    458 }
    459 
    460 const char* SQLStatement::column_text(int index) {
    461   DCHECK(stmt_);
    462   return reinterpret_cast<const char*>(sqlite3_column_text(stmt_, index));
    463 }
    464 
    465 bool SQLStatement::column_string(int index, std::string* str) {
    466   DCHECK(stmt_);
    467   DCHECK(str);
    468   const char* s = column_text(index);
    469   str->assign(s ? s : std::string());
    470   return s != NULL;
    471 }
    472 
    473 std::string SQLStatement::column_string(int index) {
    474   std::string str;
    475   column_string(index, &str);
    476   return str;
    477 }
    478 
    479 const char16* SQLStatement::column_text16(int index) {
    480   DCHECK(stmt_);
    481   return static_cast<const char16*>(sqlite3_column_text16(stmt_, index));
    482 }
    483 
    484 bool SQLStatement::column_string16(int index, string16* str) {
    485   DCHECK(stmt_);
    486   DCHECK(str);
    487   const char* s = column_text(index);
    488   str->assign(s ? UTF8ToUTF16(s) : string16());
    489   return (s != NULL);
    490 }
    491 
    492 string16 SQLStatement::column_string16(int index) {
    493   string16 str;
    494   column_string16(index, &str);
    495   return str;
    496 }
    497 
    498 bool SQLStatement::column_wstring(int index, std::wstring* str) {
    499   DCHECK(stmt_);
    500   DCHECK(str);
    501   const char* s = column_text(index);
    502   str->assign(s ? UTF8ToWide(s) : std::wstring());
    503   return (s != NULL);
    504 }
    505 
    506 std::wstring SQLStatement::column_wstring(int index) {
    507   std::wstring wstr;
    508   column_wstring(index, &wstr);
    509   return wstr;
    510 }
    511