1 // Copyright (c) 2013 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "net/ssl/openssl_client_key_store.h" 6 7 #include <openssl/evp.h> 8 #include <openssl/x509.h> 9 10 #include "base/memory/scoped_ptr.h" 11 #include "base/memory/singleton.h" 12 #include "net/cert/x509_certificate.h" 13 14 namespace net { 15 16 namespace { 17 18 typedef OpenSSLClientKeyStore::ScopedEVP_PKEY ScopedEVP_PKEY; 19 20 // Increment the reference count of a given EVP_PKEY. This function 21 // is similar to EVP_PKEY_dup which is not available from the OpenSSL 22 // version used by Chromium at the moment. Its name is distinct to 23 // avoid compiler warnings about ambiguous function calls at caller 24 // sites. 25 EVP_PKEY* CopyEVP_PKEY(EVP_PKEY* key) { 26 if (key) 27 CRYPTO_add(&key->references, 1, CRYPTO_LOCK_EVP_PKEY); 28 return key; 29 } 30 31 // Return the EVP_PKEY holding the public key of a given certificate. 32 // |cert| is a certificate. 33 // Returns a scoped EVP_PKEY for it. 34 ScopedEVP_PKEY GetOpenSSLPublicKey(const X509Certificate* cert) { 35 // X509_PUBKEY_get() increments the reference count of its result. 36 // Unlike X509_get_X509_PUBKEY() which simply returns a direct pointer. 37 EVP_PKEY* pkey = 38 X509_PUBKEY_get(X509_get_X509_PUBKEY(cert->os_cert_handle())); 39 if (!pkey) 40 LOG(ERROR) << "Can't extract private key from certificate!"; 41 return ScopedEVP_PKEY(pkey); 42 } 43 44 } // namespace 45 46 OpenSSLClientKeyStore::OpenSSLClientKeyStore() { 47 } 48 49 OpenSSLClientKeyStore::~OpenSSLClientKeyStore() { 50 } 51 52 OpenSSLClientKeyStore::KeyPair::KeyPair(EVP_PKEY* pub_key, 53 EVP_PKEY* priv_key) { 54 public_key = CopyEVP_PKEY(pub_key); 55 private_key = CopyEVP_PKEY(priv_key); 56 } 57 58 OpenSSLClientKeyStore::KeyPair::~KeyPair() { 59 EVP_PKEY_free(public_key); 60 EVP_PKEY_free(private_key); 61 } 62 63 OpenSSLClientKeyStore::KeyPair::KeyPair(const KeyPair& other) { 64 public_key = CopyEVP_PKEY(other.public_key); 65 private_key = CopyEVP_PKEY(other.private_key); 66 } 67 68 void OpenSSLClientKeyStore::KeyPair::operator=(const KeyPair& other) { 69 EVP_PKEY* old_public_key = public_key; 70 EVP_PKEY* old_private_key = private_key; 71 public_key = CopyEVP_PKEY(other.public_key); 72 private_key = CopyEVP_PKEY(other.private_key); 73 EVP_PKEY_free(old_private_key); 74 EVP_PKEY_free(old_public_key); 75 } 76 77 int OpenSSLClientKeyStore::FindKeyPairIndex(EVP_PKEY* public_key) { 78 if (!public_key) 79 return -1; 80 for (size_t n = 0; n < pairs_.size(); ++n) { 81 if (EVP_PKEY_cmp(pairs_[n].public_key, public_key) == 1) 82 return static_cast<int>(n); 83 } 84 return -1; 85 } 86 87 void OpenSSLClientKeyStore::AddKeyPair(EVP_PKEY* pub_key, 88 EVP_PKEY* private_key) { 89 int index = FindKeyPairIndex(pub_key); 90 if (index < 0) 91 pairs_.push_back(KeyPair(pub_key, private_key)); 92 } 93 94 // Common code for OpenSSLClientKeyStore. Shared by all OpenSSL-based 95 // builds. 96 bool OpenSSLClientKeyStore::RecordClientCertPrivateKey( 97 const X509Certificate* client_cert, 98 EVP_PKEY* private_key) { 99 // Sanity check. 100 if (!client_cert || !private_key) 101 return false; 102 103 // Get public key from certificate. 104 ScopedEVP_PKEY pub_key(GetOpenSSLPublicKey(client_cert)); 105 if (!pub_key.get()) 106 return false; 107 108 AddKeyPair(pub_key.get(), private_key); 109 return true; 110 } 111 112 bool OpenSSLClientKeyStore::FetchClientCertPrivateKey( 113 const X509Certificate* client_cert, 114 ScopedEVP_PKEY* private_key) { 115 if (!client_cert) 116 return false; 117 118 ScopedEVP_PKEY pub_key(GetOpenSSLPublicKey(client_cert)); 119 if (!pub_key.get()) 120 return false; 121 122 int index = FindKeyPairIndex(pub_key.get()); 123 if (index < 0) 124 return false; 125 126 private_key->reset(CopyEVP_PKEY(pairs_[index].private_key)); 127 return true; 128 } 129 130 void OpenSSLClientKeyStore::Flush() { 131 pairs_.clear(); 132 } 133 134 OpenSSLClientKeyStore* OpenSSLClientKeyStore::GetInstance() { 135 return Singleton<OpenSSLClientKeyStore>::get(); 136 } 137 138 } // namespace net 139 140 141