// Copyright 2014 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/browser/chromeos/platform_keys/platform_keys.h"

#include <cryptohi.h>
#include <secitem.h>

#include "base/bind.h"
#include "base/bind_helpers.h"
#include "base/callback.h"
#include "base/compiler_specific.h"
#include "base/location.h"
#include "base/logging.h"
#include "base/macros.h"
#include "base/single_thread_task_runner.h"
#include "base/thread_task_runner_handle.h"
#include "base/threading/worker_pool.h"
#include "chrome/browser/extensions/api/enterprise_platform_keys/enterprise_platform_keys_api.h"
#include "chrome/browser/net/nss_context.h"
#include "content/public/browser/browser_context.h"
#include "content/public/browser/browser_thread.h"
#include "crypto/rsa_private_key.h"
#include "net/base/crypto_module.h"
#include "net/base/net_errors.h"
#include "net/cert/cert_database.h"
#include "net/cert/nss_cert_database.h"
#include "net/cert/x509_certificate.h"

using content::BrowserContext;
using content::BrowserThread;

namespace {
const char kErrorInternal[] = "Internal Error.";
const char kErrorKeyNotFound[] = "Key not found.";
const char kErrorCertificateNotFound[] = "Certificate could not be found.";
const char kErrorAlgorithmNotSupported[] = "Algorithm not supported.";

// The current maximal RSA modulus length that ChromeOS's TPM supports for key
// generation.
const unsigned int kMaxRSAModulusLengthBits = 2048;
}

namespace chromeos {

namespace platform_keys {

namespace {

// Base class to store state that is common to all NSS database operations and
// to provide convenience methods to call back.
// Keeps track of the originating task runner.
53 class NSSOperationState { 54 public: 55 NSSOperationState(); 56 virtual ~NSSOperationState() {} 57 58 // Called if an error occurred during the execution of the NSS operation 59 // described by this object. 60 virtual void OnError(const tracked_objects::Location& from, 61 const std::string& error_message) = 0; 62 63 crypto::ScopedPK11Slot slot_; 64 65 // The task runner on which the NSS operation was called. Any reply must be 66 // posted to this runner. 67 scoped_refptr<base::SingleThreadTaskRunner> origin_task_runner_; 68 69 private: 70 DISALLOW_COPY_AND_ASSIGN(NSSOperationState); 71 }; 72 73 typedef base::Callback<void(net::NSSCertDatabase* cert_db)> GetCertDBCallback; 74 75 // Called back with the NSSCertDatabase associated to the given |token_id|. 76 // Calls |callback| if the database was successfully retrieved. Used by 77 // GetCertDatabaseOnIOThread. 78 void DidGetCertDBOnIOThread(const GetCertDBCallback& callback, 79 NSSOperationState* state, 80 net::NSSCertDatabase* cert_db) { 81 DCHECK(BrowserThread::CurrentlyOn(BrowserThread::IO)); 82 if (!cert_db) { 83 LOG(ERROR) << "Couldn't get NSSCertDatabase."; 84 state->OnError(FROM_HERE, kErrorInternal); 85 return; 86 } 87 88 state->slot_ = cert_db->GetPrivateSlot(); 89 if (!state->slot_) { 90 LOG(ERROR) << "No private slot"; 91 state->OnError(FROM_HERE, kErrorInternal); 92 return; 93 } 94 95 callback.Run(cert_db); 96 } 97 98 // Retrieves the NSSCertDatabase from |context|. Must be called on the IO 99 // thread. 100 void GetCertDatabaseOnIOThread(content::ResourceContext* context, 101 const GetCertDBCallback& callback, 102 NSSOperationState* state) { 103 DCHECK(BrowserThread::CurrentlyOn(BrowserThread::IO)); 104 net::NSSCertDatabase* cert_db = GetNSSCertDatabaseForResourceContext( 105 context, base::Bind(&DidGetCertDBOnIOThread, callback, state)); 106 107 if (cert_db) 108 DidGetCertDBOnIOThread(callback, state, cert_db); 109 } 110 111 // Asynchronously fetches the NSSCertDatabase and PK11Slot for |token_id|. 
112 // Stores the slot in |state| and passes the database to |callback|. Will run 113 // |callback| on the IO thread. 114 void GetCertDatabase(const std::string& token_id, 115 const GetCertDBCallback& callback, 116 BrowserContext* browser_context, 117 NSSOperationState* state) { 118 // TODO(pneubeck): Decide which DB to retrieve depending on |token_id|. 119 BrowserThread::PostTask(BrowserThread::IO, 120 FROM_HERE, 121 base::Bind(&GetCertDatabaseOnIOThread, 122 browser_context->GetResourceContext(), 123 callback, 124 state)); 125 } 126 127 class GenerateRSAKeyState : public NSSOperationState { 128 public: 129 GenerateRSAKeyState(unsigned int modulus_length_bits, 130 const subtle::GenerateKeyCallback& callback); 131 virtual ~GenerateRSAKeyState() {} 132 133 virtual void OnError(const tracked_objects::Location& from, 134 const std::string& error_message) OVERRIDE { 135 CallBack(from, std::string() /* no public key */, error_message); 136 } 137 138 void CallBack(const tracked_objects::Location& from, 139 const std::string& public_key_spki_der, 140 const std::string& error_message) { 141 origin_task_runner_->PostTask( 142 from, base::Bind(callback_, public_key_spki_der, error_message)); 143 } 144 145 const unsigned int modulus_length_bits_; 146 147 private: 148 // Must be called on origin thread, use CallBack() therefore. 
149 subtle::GenerateKeyCallback callback_; 150 }; 151 152 class SignState : public NSSOperationState { 153 public: 154 SignState(const std::string& public_key, 155 HashAlgorithm hash_algorithm, 156 const std::string& data, 157 const subtle::SignCallback& callback); 158 virtual ~SignState() {} 159 160 virtual void OnError(const tracked_objects::Location& from, 161 const std::string& error_message) OVERRIDE { 162 CallBack(from, std::string() /* no signature */, error_message); 163 } 164 165 void CallBack(const tracked_objects::Location& from, 166 const std::string& signature, 167 const std::string& error_message) { 168 origin_task_runner_->PostTask( 169 from, base::Bind(callback_, signature, error_message)); 170 } 171 172 const std::string public_key_; 173 HashAlgorithm hash_algorithm_; 174 const std::string data_; 175 176 private: 177 // Must be called on origin thread, use CallBack() therefore. 178 subtle::SignCallback callback_; 179 }; 180 181 class GetCertificatesState : public NSSOperationState { 182 public: 183 explicit GetCertificatesState(const GetCertificatesCallback& callback); 184 virtual ~GetCertificatesState() {} 185 186 virtual void OnError(const tracked_objects::Location& from, 187 const std::string& error_message) OVERRIDE { 188 CallBack(from, 189 scoped_ptr<net::CertificateList>() /* no certificates */, 190 error_message); 191 } 192 193 void CallBack(const tracked_objects::Location& from, 194 scoped_ptr<net::CertificateList> certs, 195 const std::string& error_message) { 196 origin_task_runner_->PostTask( 197 from, base::Bind(callback_, base::Passed(&certs), error_message)); 198 } 199 200 scoped_ptr<net::CertificateList> certs_; 201 202 private: 203 // Must be called on origin thread, use CallBack() therefore. 
204 GetCertificatesCallback callback_; 205 }; 206 207 class ImportCertificateState : public NSSOperationState { 208 public: 209 ImportCertificateState(scoped_refptr<net::X509Certificate> certificate, 210 const ImportCertificateCallback& callback); 211 virtual ~ImportCertificateState() {} 212 213 virtual void OnError(const tracked_objects::Location& from, 214 const std::string& error_message) OVERRIDE { 215 CallBack(from, error_message); 216 } 217 218 void CallBack(const tracked_objects::Location& from, 219 const std::string& error_message) { 220 origin_task_runner_->PostTask(from, base::Bind(callback_, error_message)); 221 } 222 223 scoped_refptr<net::X509Certificate> certificate_; 224 225 private: 226 // Must be called on origin thread, use CallBack() therefore. 227 ImportCertificateCallback callback_; 228 }; 229 230 class RemoveCertificateState : public NSSOperationState { 231 public: 232 RemoveCertificateState(scoped_refptr<net::X509Certificate> certificate, 233 const RemoveCertificateCallback& callback); 234 virtual ~RemoveCertificateState() {} 235 236 virtual void OnError(const tracked_objects::Location& from, 237 const std::string& error_message) OVERRIDE { 238 CallBack(from, error_message); 239 } 240 241 void CallBack(const tracked_objects::Location& from, 242 const std::string& error_message) { 243 origin_task_runner_->PostTask(from, base::Bind(callback_, error_message)); 244 } 245 246 scoped_refptr<net::X509Certificate> certificate_; 247 248 private: 249 // Must be called on origin thread, use CallBack() therefore. 
250 RemoveCertificateCallback callback_; 251 }; 252 253 NSSOperationState::NSSOperationState() 254 : origin_task_runner_(base::ThreadTaskRunnerHandle::Get()) { 255 } 256 257 GenerateRSAKeyState::GenerateRSAKeyState( 258 unsigned int modulus_length_bits, 259 const subtle::GenerateKeyCallback& callback) 260 : modulus_length_bits_(modulus_length_bits), callback_(callback) { 261 } 262 263 SignState::SignState(const std::string& public_key, 264 HashAlgorithm hash_algorithm, 265 const std::string& data, 266 const subtle::SignCallback& callback) 267 : public_key_(public_key), 268 hash_algorithm_(hash_algorithm), 269 data_(data), 270 callback_(callback) { 271 } 272 273 GetCertificatesState::GetCertificatesState( 274 const GetCertificatesCallback& callback) 275 : callback_(callback) { 276 } 277 278 ImportCertificateState::ImportCertificateState( 279 scoped_refptr<net::X509Certificate> certificate, 280 const ImportCertificateCallback& callback) 281 : certificate_(certificate), callback_(callback) { 282 } 283 284 RemoveCertificateState::RemoveCertificateState( 285 scoped_refptr<net::X509Certificate> certificate, 286 const RemoveCertificateCallback& callback) 287 : certificate_(certificate), callback_(callback) { 288 } 289 290 // Does the actual key generation on a worker thread. Used by 291 // GenerateRSAKeyWithDB(). 292 void GenerateRSAKeyOnWorkerThread(scoped_ptr<GenerateRSAKeyState> state) { 293 scoped_ptr<crypto::RSAPrivateKey> rsa_key( 294 crypto::RSAPrivateKey::CreateSensitive(state->slot_.get(), 295 state->modulus_length_bits_)); 296 if (!rsa_key) { 297 LOG(ERROR) << "Couldn't create key."; 298 state->OnError(FROM_HERE, kErrorInternal); 299 return; 300 } 301 302 std::vector<uint8> public_key_spki_der; 303 if (!rsa_key->ExportPublicKey(&public_key_spki_der)) { 304 // TODO(pneubeck): Remove rsa_key from storage. 
305 LOG(ERROR) << "Couldn't export public key."; 306 state->OnError(FROM_HERE, kErrorInternal); 307 return; 308 } 309 state->CallBack( 310 FROM_HERE, 311 std::string(public_key_spki_der.begin(), public_key_spki_der.end()), 312 std::string() /* no error */); 313 } 314 315 // Continues generating a RSA key with the obtained NSSCertDatabase. Used by 316 // GenerateRSAKey(). 317 void GenerateRSAKeyWithDB(scoped_ptr<GenerateRSAKeyState> state, 318 net::NSSCertDatabase* cert_db) { 319 DCHECK(BrowserThread::CurrentlyOn(BrowserThread::IO)); 320 // Only the slot and not the NSSCertDatabase is required. Ignore |cert_db|. 321 base::WorkerPool::PostTask( 322 FROM_HERE, 323 base::Bind(&GenerateRSAKeyOnWorkerThread, base::Passed(&state)), 324 true /*task is slow*/); 325 } 326 327 // Does the actual signing on a worker thread. Used by RSASignWithDB(). 328 void RSASignOnWorkerThread(scoped_ptr<SignState> state) { 329 const uint8* public_key_uint8 = 330 reinterpret_cast<const uint8*>(state->public_key_.data()); 331 std::vector<uint8> public_key_vector( 332 public_key_uint8, public_key_uint8 + state->public_key_.size()); 333 334 // TODO(pneubeck): This searches all slots. Change to look only at |slot_|. 
335 scoped_ptr<crypto::RSAPrivateKey> rsa_key( 336 crypto::RSAPrivateKey::FindFromPublicKeyInfo(public_key_vector)); 337 if (!rsa_key || rsa_key->key()->pkcs11Slot != state->slot_) { 338 state->OnError(FROM_HERE, kErrorKeyNotFound); 339 return; 340 } 341 342 SECOidTag sign_alg_tag = SEC_OID_UNKNOWN; 343 switch (state->hash_algorithm_) { 344 case HASH_ALGORITHM_SHA1: 345 sign_alg_tag = SEC_OID_PKCS1_SHA1_WITH_RSA_ENCRYPTION; 346 break; 347 case HASH_ALGORITHM_SHA256: 348 sign_alg_tag = SEC_OID_PKCS1_SHA256_WITH_RSA_ENCRYPTION; 349 break; 350 case HASH_ALGORITHM_SHA384: 351 sign_alg_tag = SEC_OID_PKCS1_SHA384_WITH_RSA_ENCRYPTION; 352 break; 353 case HASH_ALGORITHM_SHA512: 354 sign_alg_tag = SEC_OID_PKCS1_SHA512_WITH_RSA_ENCRYPTION; 355 break; 356 } 357 358 SECItem sign_result = {siBuffer, NULL, 0}; 359 if (SEC_SignData(&sign_result, 360 reinterpret_cast<const unsigned char*>(state->data_.data()), 361 state->data_.size(), 362 rsa_key->key(), 363 sign_alg_tag) != SECSuccess) { 364 LOG(ERROR) << "Couldn't sign."; 365 state->OnError(FROM_HERE, kErrorInternal); 366 return; 367 } 368 369 std::string signature(reinterpret_cast<const char*>(sign_result.data), 370 sign_result.len); 371 state->CallBack(FROM_HERE, signature, std::string() /* no error */); 372 } 373 374 // Continues signing with the obtained NSSCertDatabase. Used by Sign(). 375 void RSASignWithDB(scoped_ptr<SignState> state, net::NSSCertDatabase* cert_db) { 376 DCHECK(BrowserThread::CurrentlyOn(BrowserThread::IO)); 377 // Only the slot and not the NSSCertDatabase is required. Ignore |cert_db|. 378 base::WorkerPool::PostTask( 379 FROM_HERE, 380 base::Bind(&RSASignOnWorkerThread, base::Passed(&state)), 381 true /*task is slow*/); 382 } 383 384 // Filters the obtained certificates on a worker thread. Used by 385 // DidGetCertificates(). 
386 void FilterCertificatesOnWorkerThread(scoped_ptr<GetCertificatesState> state) { 387 scoped_ptr<net::CertificateList> client_certs(new net::CertificateList); 388 for (net::CertificateList::const_iterator it = state->certs_->begin(); 389 it != state->certs_->end(); 390 ++it) { 391 net::X509Certificate::OSCertHandle cert_handle = (*it)->os_cert_handle(); 392 crypto::ScopedPK11Slot cert_slot(PK11_KeyForCertExists(cert_handle, 393 NULL, // keyPtr 394 NULL)); // wincx 395 396 // Keep only user certificates, i.e. certs for which the private key is 397 // present and stored in the queried slot. 398 if (cert_slot != state->slot_) 399 continue; 400 401 client_certs->push_back(*it); 402 } 403 404 state->CallBack(FROM_HERE, client_certs.Pass(), std::string() /* no error */); 405 } 406 407 // Passes the obtained certificates to the worker thread for filtering. Used by 408 // GetCertificatesWithDB(). 409 void DidGetCertificates(scoped_ptr<GetCertificatesState> state, 410 scoped_ptr<net::CertificateList> all_certs) { 411 DCHECK(BrowserThread::CurrentlyOn(BrowserThread::IO)); 412 state->certs_ = all_certs.Pass(); 413 base::WorkerPool::PostTask( 414 FROM_HERE, 415 base::Bind(&FilterCertificatesOnWorkerThread, base::Passed(&state)), 416 true /*task is slow*/); 417 } 418 419 // Continues getting certificates with the obtained NSSCertDatabase. Used by 420 // GetCertificates(). 421 void GetCertificatesWithDB(scoped_ptr<GetCertificatesState> state, 422 net::NSSCertDatabase* cert_db) { 423 DCHECK(BrowserThread::CurrentlyOn(BrowserThread::IO)); 424 // Get the pointer to slot before base::Passed releases |state|. 425 PK11SlotInfo* slot = state->slot_.get(); 426 cert_db->ListCertsInSlot( 427 base::Bind(&DidGetCertificates, base::Passed(&state)), slot); 428 } 429 430 // Does the actual certificate importing on the IO thread. Used by 431 // ImportCertificate(). 
432 void ImportCertificateWithDB(scoped_ptr<ImportCertificateState> state, 433 net::NSSCertDatabase* cert_db) { 434 DCHECK(BrowserThread::CurrentlyOn(BrowserThread::IO)); 435 // TODO(pneubeck): Use |state->slot_| to verify that we're really importing to 436 // the correct token. 437 // |cert_db| is not required, ignore it. 438 net::CertDatabase* db = net::CertDatabase::GetInstance(); 439 440 const net::Error cert_status = 441 static_cast<net::Error>(db->CheckUserCert(state->certificate_)); 442 if (cert_status == net::ERR_NO_PRIVATE_KEY_FOR_CERT) { 443 state->OnError(FROM_HERE, kErrorKeyNotFound); 444 return; 445 } else if (cert_status != net::OK) { 446 state->OnError(FROM_HERE, net::ErrorToString(cert_status)); 447 return; 448 } 449 450 const net::Error import_status = 451 static_cast<net::Error>(db->AddUserCert(state->certificate_.get())); 452 if (import_status != net::OK) { 453 LOG(ERROR) << "Could not import certificate."; 454 state->OnError(FROM_HERE, net::ErrorToString(import_status)); 455 return; 456 } 457 458 state->CallBack(FROM_HERE, std::string() /* no error */); 459 } 460 461 // Called on IO thread after the certificate removal is finished. 462 void DidRemoveCertificate(scoped_ptr<RemoveCertificateState> state, 463 bool certificate_found, 464 bool success) { 465 DCHECK(BrowserThread::CurrentlyOn(BrowserThread::IO)); 466 // CertificateNotFound error has precedence over an internal error. 467 if (!certificate_found) { 468 state->OnError(FROM_HERE, kErrorCertificateNotFound); 469 return; 470 } 471 if (!success) { 472 state->OnError(FROM_HERE, kErrorInternal); 473 return; 474 } 475 476 state->CallBack(FROM_HERE, std::string() /* no error */); 477 } 478 479 // Does the actual certificate removal on the IO thread. Used by 480 // RemoveCertificate(). 
481 void RemoveCertificateWithDB(scoped_ptr<RemoveCertificateState> state, 482 net::NSSCertDatabase* cert_db) { 483 DCHECK(BrowserThread::CurrentlyOn(BrowserThread::IO)); 484 // Get the pointer before base::Passed clears |state|. 485 scoped_refptr<net::X509Certificate> certificate = state->certificate_; 486 bool certificate_found = certificate->os_cert_handle()->isperm; 487 cert_db->DeleteCertAndKeyAsync( 488 certificate, 489 base::Bind( 490 &DidRemoveCertificate, base::Passed(&state), certificate_found)); 491 } 492 493 } // namespace 494 495 namespace subtle { 496 497 void GenerateRSAKey(const std::string& token_id, 498 unsigned int modulus_length_bits, 499 const GenerateKeyCallback& callback, 500 BrowserContext* browser_context) { 501 DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI)); 502 scoped_ptr<GenerateRSAKeyState> state( 503 new GenerateRSAKeyState(modulus_length_bits, callback)); 504 505 if (modulus_length_bits > kMaxRSAModulusLengthBits) { 506 state->OnError(FROM_HERE, kErrorAlgorithmNotSupported); 507 return; 508 } 509 510 // Get the pointer to |state| before base::Passed releases |state|. 511 NSSOperationState* state_ptr = state.get(); 512 GetCertDatabase(token_id, 513 base::Bind(&GenerateRSAKeyWithDB, base::Passed(&state)), 514 browser_context, 515 state_ptr); 516 } 517 518 void Sign(const std::string& token_id, 519 const std::string& public_key, 520 HashAlgorithm hash_algorithm, 521 const std::string& data, 522 const SignCallback& callback, 523 BrowserContext* browser_context) { 524 DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI)); 525 scoped_ptr<SignState> state( 526 new SignState(public_key, hash_algorithm, data, callback)); 527 // Get the pointer to |state| before base::Passed releases |state|. 528 NSSOperationState* state_ptr = state.get(); 529 530 // The NSSCertDatabase object is not required. 
But in case it's not available 531 // we would get more informative error messages and we can double check that 532 // we use a key of the correct token. 533 GetCertDatabase(token_id, 534 base::Bind(&RSASignWithDB, base::Passed(&state)), 535 browser_context, 536 state_ptr); 537 } 538 539 } // namespace subtle 540 541 void GetCertificates(const std::string& token_id, 542 const GetCertificatesCallback& callback, 543 BrowserContext* browser_context) { 544 DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI)); 545 scoped_ptr<GetCertificatesState> state(new GetCertificatesState(callback)); 546 // Get the pointer to |state| before base::Passed releases |state|. 547 NSSOperationState* state_ptr = state.get(); 548 GetCertDatabase(token_id, 549 base::Bind(&GetCertificatesWithDB, base::Passed(&state)), 550 browser_context, 551 state_ptr); 552 } 553 554 void ImportCertificate(const std::string& token_id, 555 scoped_refptr<net::X509Certificate> certificate, 556 const ImportCertificateCallback& callback, 557 BrowserContext* browser_context) { 558 DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI)); 559 scoped_ptr<ImportCertificateState> state( 560 new ImportCertificateState(certificate, callback)); 561 // Get the pointer to |state| before base::Passed releases |state|. 562 NSSOperationState* state_ptr = state.get(); 563 564 // The NSSCertDatabase object is not required. But in case it's not available 565 // we would get more informative error messages and we can double check that 566 // we use a key of the correct token. 
567 GetCertDatabase(token_id, 568 base::Bind(&ImportCertificateWithDB, base::Passed(&state)), 569 browser_context, 570 state_ptr); 571 } 572 573 void RemoveCertificate(const std::string& token_id, 574 scoped_refptr<net::X509Certificate> certificate, 575 const RemoveCertificateCallback& callback, 576 BrowserContext* browser_context) { 577 DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI)); 578 scoped_ptr<RemoveCertificateState> state( 579 new RemoveCertificateState(certificate, callback)); 580 // Get the pointer to |state| before base::Passed releases |state|. 581 NSSOperationState* state_ptr = state.get(); 582 583 // The NSSCertDatabase object is not required. But in case it's not available 584 // we would get more informative error messages. 585 GetCertDatabase(token_id, 586 base::Bind(&RemoveCertificateWithDB, base::Passed(&state)), 587 browser_context, 588 state_ptr); 589 } 590 591 } // namespace platform_keys 592 593 } // namespace chromeos 594