Home | History | Annotate | Download | only in cup
      1 // Copyright 2013 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "google_apis/cup/client_update_protocol.h"
      6 
      7 #include "base/base64.h"
      8 #include "base/logging.h"
      9 #include "base/memory/scoped_ptr.h"
     10 #include "base/sha1.h"
     11 #include "base/strings/string_util.h"
     12 #include "base/strings/stringprintf.h"
     13 #include "crypto/hmac.h"
     14 #include "crypto/random.h"
     15 
     16 namespace {
     17 
     18 base::StringPiece ByteVectorToSP(const std::vector<uint8>& vec) {
     19   if (vec.empty())
     20     return base::StringPiece();
     21 
     22   return base::StringPiece(reinterpret_cast<const char*>(&vec[0]), vec.size());
     23 }
     24 
     25 // This class needs to implement the same hashing and signing functions as the
     26 // Google Update server; for now, this is SHA-1 and HMAC-SHA1, but this may
     27 // change to SHA-256 in the near future.  For this reason, all primitives are
     28 // wrapped.  The name "SymSign" is used to mirror the CUP specification.
     29 size_t HashDigestSize() {
     30   return base::kSHA1Length;
     31 }
     32 
     33 std::vector<uint8> Hash(const std::vector<uint8>& data) {
     34   std::vector<uint8> result(HashDigestSize());
     35   base::SHA1HashBytes(data.empty() ? NULL : &data[0],
     36                       data.size(),
     37                       &result[0]);
     38   return result;
     39 }
     40 
     41 std::vector<uint8> Hash(const base::StringPiece& sdata) {
     42   std::vector<uint8> result(HashDigestSize());
     43   base::SHA1HashBytes(sdata.empty() ?
     44                           NULL :
     45                           reinterpret_cast<const unsigned char*>(sdata.data()),
     46                       sdata.length(),
     47                       &result[0]);
     48   return result;
     49 }
     50 
     51 std::vector<uint8> SymConcat(uint8 id,
     52                              const std::vector<uint8>* h1,
     53                              const std::vector<uint8>* h2,
     54                              const std::vector<uint8>* h3) {
     55   std::vector<uint8> result;
     56   result.push_back(id);
     57   const std::vector<uint8>* args[] = { h1, h2, h3 };
     58   for (size_t i = 0; i != arraysize(args); ++i) {
     59     if (args[i]) {
     60       DCHECK_EQ(args[i]->size(), HashDigestSize());
     61       result.insert(result.end(), args[i]->begin(), args[i]->end());
     62     }
     63   }
     64 
     65   return result;
     66 }
     67 
     68 std::vector<uint8> SymSign(const std::vector<uint8>& key,
     69                            const std::vector<uint8>& hashes) {
     70   DCHECK(!key.empty());
     71   DCHECK(!hashes.empty());
     72 
     73   crypto::HMAC hmac(crypto::HMAC::SHA1);
     74   if (!hmac.Init(&key[0], key.size()))
     75     return std::vector<uint8>();
     76 
     77   std::vector<uint8> result(hmac.DigestLength());
     78   if (!hmac.Sign(ByteVectorToSP(hashes), &result[0], result.size()))
     79     return std::vector<uint8>();
     80 
     81   return result;
     82 }
     83 
     84 bool SymSignVerify(const std::vector<uint8>& key,
     85                    const std::vector<uint8>& hashes,
     86                    const std::vector<uint8>& server_proof) {
     87   DCHECK(!key.empty());
     88   DCHECK(!hashes.empty());
     89   DCHECK(!server_proof.empty());
     90 
     91   crypto::HMAC hmac(crypto::HMAC::SHA1);
     92   if (!hmac.Init(&key[0], key.size()))
     93     return false;
     94 
     95   return hmac.Verify(ByteVectorToSP(hashes), ByteVectorToSP(server_proof));
     96 }
     97 
     98 // RsaPad() is implemented as described in the CUP spec.  It is NOT a general
     99 // purpose padding algorithm.
    100 std::vector<uint8> RsaPad(size_t rsa_key_size,
    101                           const std::vector<uint8>& entropy) {
    102   DCHECK_GE(rsa_key_size, HashDigestSize());
    103 
    104   // The result gets padded with zeros if the result size is greater than
    105   // the size of the buffer provided by the caller.
    106   std::vector<uint8> result(entropy);
    107   result.resize(rsa_key_size - HashDigestSize());
    108 
    109   // For use with RSA, the input needs to be smaller than the RSA modulus,
    110   // which has always the msb set.
    111   result[0] &= 127;  // Reset msb
    112   result[0] |= 64;   // Set second highest bit.
    113 
    114   std::vector<uint8> digest = Hash(result);
    115   result.insert(result.end(), digest.begin(), digest.end());
    116   DCHECK_EQ(result.size(), rsa_key_size);
    117   return result;
    118 }
    119 
    120 // CUP passes the versioned secret in the query portion of the URL for the
    121 // update check service -- and that means that a URL-safe variant of Base64 is
    122 // needed.  Call the standard Base64 encoder/decoder and then apply fixups.
    123 std::string UrlSafeB64Encode(const std::vector<uint8>& data) {
    124   std::string result;
    125   if (!base::Base64Encode(ByteVectorToSP(data), &result))
    126     return std::string();
    127 
    128   // Do an tr|+/|-_| on the output, and strip any '=' padding.
    129   for (std::string::iterator it = result.begin(); it != result.end(); ++it) {
    130     switch (*it) {
    131       case '+':
    132         *it = '-';
    133         break;
    134       case '/':
    135         *it = '_';
    136         break;
    137       default:
    138         break;
    139     }
    140   }
    141   TrimString(result, "=", &result);
    142 
    143   return result;
    144 }
    145 
    146 std::vector<uint8> UrlSafeB64Decode(const base::StringPiece& input) {
    147   std::string unsafe(input.begin(), input.end());
    148   for (std::string::iterator it = unsafe.begin(); it != unsafe.end(); ++it) {
    149     switch (*it) {
    150       case '-':
    151         *it = '+';
    152         break;
    153       case '_':
    154         *it = '/';
    155         break;
    156       default:
    157         break;
    158     }
    159   }
    160   if (unsafe.length() % 4)
    161     unsafe.append(4 - (unsafe.length() % 4), '=');
    162 
    163   std::string decoded;
    164   if (!base::Base64Decode(unsafe, &decoded))
    165     return std::vector<uint8>();
    166 
    167   return std::vector<uint8>(decoded.begin(), decoded.end());
    168 }
    169 
    170 }  // end namespace
    171 
    172 ClientUpdateProtocol::ClientUpdateProtocol(int key_version)
    173     : pub_key_version_(key_version) {
    174 }
    175 
    176 scoped_ptr<ClientUpdateProtocol> ClientUpdateProtocol::Create(
    177     int key_version,
    178     const base::StringPiece& public_key) {
    179   DCHECK_GT(key_version, 0);
    180   DCHECK(!public_key.empty());
    181 
    182   scoped_ptr<ClientUpdateProtocol> result(
    183       new ClientUpdateProtocol(key_version));
    184   if (!result)
    185     return scoped_ptr<ClientUpdateProtocol>();
    186 
    187   if (!result->LoadPublicKey(public_key))
    188     return scoped_ptr<ClientUpdateProtocol>();
    189 
    190   if (!result->BuildRandomSharedKey())
    191     return scoped_ptr<ClientUpdateProtocol>();
    192 
    193   return result.Pass();
    194 }
    195 
    196 std::string ClientUpdateProtocol::GetVersionedSecret() const {
    197   return base::StringPrintf("%d:%s",
    198                             pub_key_version_,
    199                             UrlSafeB64Encode(encrypted_key_source_).c_str());
    200 }
    201 
    202 bool ClientUpdateProtocol::SignRequest(const base::StringPiece& url,
    203                                        const base::StringPiece& request_body,
    204                                        std::string* client_proof) {
    205   DCHECK(!encrypted_key_source_.empty());
    206   DCHECK(!url.empty());
    207   DCHECK(!request_body.empty());
    208   DCHECK(client_proof);
    209 
    210   // Compute the challenge hash:
    211   //   hw = HASH(HASH(v|w)|HASH(request_url)|HASH(body)).
    212   // Keep the challenge hash for later to validate the server's response.
    213   std::vector<uint8> internal_hashes;
    214 
    215   std::vector<uint8> h;
    216   h = Hash(GetVersionedSecret());
    217   internal_hashes.insert(internal_hashes.end(), h.begin(), h.end());
    218   h = Hash(url);
    219   internal_hashes.insert(internal_hashes.end(), h.begin(), h.end());
    220   h = Hash(request_body);
    221   internal_hashes.insert(internal_hashes.end(), h.begin(), h.end());
    222   DCHECK_EQ(internal_hashes.size(), 3 * HashDigestSize());
    223 
    224   client_challenge_hash_ = Hash(internal_hashes);
    225 
    226   // Sign the challenge hash (hw) using the shared key (sk) to produce the
    227   // client proof (cp).
    228   std::vector<uint8> raw_client_proof =
    229       SymSign(shared_key_, SymConcat(3, &client_challenge_hash_, NULL, NULL));
    230   if (raw_client_proof.empty()) {
    231     client_challenge_hash_.clear();
    232     return false;
    233   }
    234 
    235   *client_proof = UrlSafeB64Encode(raw_client_proof);
    236   return true;
    237 }
    238 
    239 bool ClientUpdateProtocol::ValidateResponse(
    240     const base::StringPiece& response_body,
    241     const base::StringPiece& server_cookie,
    242     const base::StringPiece& server_proof) {
    243   DCHECK(!client_challenge_hash_.empty());
    244 
    245   if (response_body.empty() || server_cookie.empty() || server_proof.empty())
    246     return false;
    247 
    248   // Decode the server proof from URL-safe Base64 to a binary HMAC for the
    249   // response.
    250   std::vector<uint8> sp_decoded = UrlSafeB64Decode(server_proof);
    251   if (sp_decoded.empty())
    252     return false;
    253 
    254   // If the request was received by the server, the server will use its
    255   // private key to decrypt |w_|, yielding the original contents of |r_|.
    256   // The server can then recreate |sk_|, compute |hw_|, and SymSign(3|hw)
    257   // to ensure that the cp matches the contents.  It will then use |sk_|
    258   // to sign its response, producing the server proof |sp|.
    259   std::vector<uint8> hm = Hash(response_body);
    260   std::vector<uint8> hc = Hash(server_cookie);
    261   return SymSignVerify(shared_key_,
    262                        SymConcat(1, &client_challenge_hash_, &hm, &hc),
    263                        sp_decoded);
    264 }
    265 
    266 bool ClientUpdateProtocol::BuildRandomSharedKey() {
    267   DCHECK_GE(PublicKeyLength(), HashDigestSize());
    268 
    269   // Start by generating some random bytes that are suitable to be encrypted;
    270   // this will be the source of the shared HMAC key that client and server use.
    271   // (CUP specification calls this "r".)
    272   std::vector<uint8> key_source;
    273   std::vector<uint8> entropy(PublicKeyLength() - HashDigestSize());
    274   crypto::RandBytes(&entropy[0], entropy.size());
    275   key_source = RsaPad(PublicKeyLength(), entropy);
    276 
    277   return DeriveSharedKey(key_source);
    278 }
    279 
    280 bool ClientUpdateProtocol::SetSharedKeyForTesting(
    281   const base::StringPiece& key_source) {
    282   DCHECK_EQ(key_source.length(), PublicKeyLength());
    283 
    284   return DeriveSharedKey(std::vector<uint8>(key_source.begin(),
    285                                             key_source.end()));
    286 }
    287 
    288 bool ClientUpdateProtocol::DeriveSharedKey(const std::vector<uint8>& source) {
    289   DCHECK(!source.empty());
    290   DCHECK_GE(source.size(), HashDigestSize());
    291   DCHECK_EQ(source.size(), PublicKeyLength());
    292 
    293   // Hash the key source (r) to generate a new shared HMAC key (sk').
    294   shared_key_ = Hash(source);
    295 
    296   // Encrypt the key source (r) using the public key (pk[v]) to generate the
    297   // encrypted key source (w).
    298   if (!EncryptKeySource(source))
    299     return false;
    300   if (encrypted_key_source_.size() != PublicKeyLength())
    301     return false;
    302 
    303   return true;
    304 }
    305 
    306