Home | History | Annotate | Download | only in protocol
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "remoting/protocol/ssl_hmac_channel_authenticator.h"
      6 
      7 #include "base/bind.h"
      8 #include "base/bind_helpers.h"
      9 #include "crypto/secure_util.h"
     10 #include "net/base/host_port_pair.h"
     11 #include "net/base/io_buffer.h"
     12 #include "net/base/net_errors.h"
     13 #include "net/cert/cert_verifier.h"
     14 #include "net/cert/x509_certificate.h"
     15 #include "net/http/transport_security_state.h"
     16 #include "net/socket/client_socket_factory.h"
     17 #include "net/socket/ssl_client_socket.h"
     18 #include "net/socket/ssl_server_socket.h"
     19 #include "net/ssl/ssl_config_service.h"
     20 #include "remoting/base/rsa_key_pair.h"
     21 #include "remoting/protocol/auth_util.h"
     22 
     23 namespace remoting {
     24 namespace protocol {
     25 
     26 // static
     27 scoped_ptr<SslHmacChannelAuthenticator>
     28 SslHmacChannelAuthenticator::CreateForClient(
     29       const std::string& remote_cert,
     30       const std::string& auth_key) {
     31   scoped_ptr<SslHmacChannelAuthenticator> result(
     32       new SslHmacChannelAuthenticator(auth_key));
     33   result->remote_cert_ = remote_cert;
     34   return result.Pass();
     35 }
     36 
     37 scoped_ptr<SslHmacChannelAuthenticator>
     38 SslHmacChannelAuthenticator::CreateForHost(
     39     const std::string& local_cert,
     40     scoped_refptr<RsaKeyPair> key_pair,
     41     const std::string& auth_key) {
     42   scoped_ptr<SslHmacChannelAuthenticator> result(
     43       new SslHmacChannelAuthenticator(auth_key));
     44   result->local_cert_ = local_cert;
     45   result->local_key_pair_ = key_pair;
     46   return result.Pass();
     47 }
     48 
     49 SslHmacChannelAuthenticator::SslHmacChannelAuthenticator(
     50     const std::string& auth_key)
     51     : auth_key_(auth_key) {
     52 }
     53 
     54 SslHmacChannelAuthenticator::~SslHmacChannelAuthenticator() {
     55 }
     56 
     57 void SslHmacChannelAuthenticator::SecureAndAuthenticate(
     58     scoped_ptr<net::StreamSocket> socket, const DoneCallback& done_callback) {
     59   DCHECK(CalledOnValidThread());
     60   DCHECK(socket->IsConnected());
     61 
     62   done_callback_ = done_callback;
     63 
     64   int result;
     65   if (is_ssl_server()) {
     66     scoped_refptr<net::X509Certificate> cert =
     67         net::X509Certificate::CreateFromBytes(
     68             local_cert_.data(), local_cert_.length());
     69     if (!cert.get()) {
     70       LOG(ERROR) << "Failed to parse X509Certificate";
     71       NotifyError(net::ERR_FAILED);
     72       return;
     73     }
     74 
     75     net::SSLConfig ssl_config;
     76     net::SSLServerSocket* server_socket =
     77         net::CreateSSLServerSocket(socket.release(),
     78                                    cert.get(),
     79                                    local_key_pair_->private_key(),
     80                                    ssl_config);
     81     socket_.reset(server_socket);
     82 
     83     result = server_socket->Handshake(base::Bind(
     84         &SslHmacChannelAuthenticator::OnConnected, base::Unretained(this)));
     85   } else {
     86     cert_verifier_.reset(net::CertVerifier::CreateDefault());
     87     transport_security_state_.reset(new net::TransportSecurityState);
     88 
     89     net::SSLConfig::CertAndStatus cert_and_status;
     90     cert_and_status.cert_status = net::CERT_STATUS_AUTHORITY_INVALID;
     91     cert_and_status.der_cert = remote_cert_;
     92 
     93     net::SSLConfig ssl_config;
     94     // Certificate verification and revocation checking are not needed
     95     // because we use self-signed certs. Disable it so that the SSL
     96     // layer doesn't try to initialize OCSP (OCSP works only on the IO
     97     // thread).
     98     ssl_config.cert_io_enabled = false;
     99     ssl_config.rev_checking_enabled = false;
    100     ssl_config.allowed_bad_certs.push_back(cert_and_status);
    101 
    102     net::HostPortPair host_and_port(kSslFakeHostName, 0);
    103     net::SSLClientSocketContext context;
    104     context.cert_verifier = cert_verifier_.get();
    105     context.transport_security_state = transport_security_state_.get();
    106     socket_.reset(
    107         net::ClientSocketFactory::GetDefaultFactory()->CreateSSLClientSocket(
    108             socket.release(), host_and_port, ssl_config, context));
    109 
    110     result = socket_->Connect(
    111         base::Bind(&SslHmacChannelAuthenticator::OnConnected,
    112                    base::Unretained(this)));
    113   }
    114 
    115   if (result == net::ERR_IO_PENDING)
    116     return;
    117 
    118   OnConnected(result);
    119 }
    120 
    121 bool SslHmacChannelAuthenticator::is_ssl_server() {
    122   return local_key_pair_.get() != NULL;
    123 }
    124 
    125 void SslHmacChannelAuthenticator::OnConnected(int result) {
    126   if (result != net::OK) {
    127     LOG(WARNING) << "Failed to establish SSL connection";
    128     NotifyError(result);
    129     return;
    130   }
    131 
    132   // Generate authentication digest to write to the socket.
    133   std::string auth_bytes = GetAuthBytes(
    134       socket_.get(), is_ssl_server() ?
    135       kHostAuthSslExporterLabel : kClientAuthSslExporterLabel, auth_key_);
    136   if (auth_bytes.empty()) {
    137     NotifyError(net::ERR_FAILED);
    138     return;
    139   }
    140 
    141   // Allocate a buffer to write the digest.
    142   auth_write_buf_ = new net::DrainableIOBuffer(
    143       new net::StringIOBuffer(auth_bytes), auth_bytes.size());
    144 
    145   // Read an incoming token.
    146   auth_read_buf_ = new net::GrowableIOBuffer();
    147   auth_read_buf_->SetCapacity(kAuthDigestLength);
    148 
    149   // If WriteAuthenticationBytes() results in |done_callback_| being
    150   // called then we must not do anything else because this object may
    151   // be destroyed at that point.
    152   bool callback_called = false;
    153   WriteAuthenticationBytes(&callback_called);
    154   if (!callback_called)
    155     ReadAuthenticationBytes();
    156 }
    157 
    158 void SslHmacChannelAuthenticator::WriteAuthenticationBytes(
    159     bool* callback_called) {
    160   while (true) {
    161     int result = socket_->Write(
    162         auth_write_buf_.get(),
    163         auth_write_buf_->BytesRemaining(),
    164         base::Bind(&SslHmacChannelAuthenticator::OnAuthBytesWritten,
    165                    base::Unretained(this)));
    166     if (result == net::ERR_IO_PENDING)
    167       break;
    168     if (!HandleAuthBytesWritten(result, callback_called))
    169       break;
    170   }
    171 }
    172 
    173 void SslHmacChannelAuthenticator::OnAuthBytesWritten(int result) {
    174   DCHECK(CalledOnValidThread());
    175 
    176   if (HandleAuthBytesWritten(result, NULL))
    177     WriteAuthenticationBytes(NULL);
    178 }
    179 
    180 bool SslHmacChannelAuthenticator::HandleAuthBytesWritten(
    181     int result, bool* callback_called) {
    182   if (result <= 0) {
    183     LOG(ERROR) << "Error writing authentication: " << result;
    184     if (callback_called)
    185       *callback_called = false;
    186     NotifyError(result);
    187     return false;
    188   }
    189 
    190   auth_write_buf_->DidConsume(result);
    191   if (auth_write_buf_->BytesRemaining() > 0)
    192     return true;
    193 
    194   auth_write_buf_ = NULL;
    195   CheckDone(callback_called);
    196   return false;
    197 }
    198 
    199 void SslHmacChannelAuthenticator::ReadAuthenticationBytes() {
    200   while (true) {
    201     int result =
    202         socket_->Read(auth_read_buf_.get(),
    203                       auth_read_buf_->RemainingCapacity(),
    204                       base::Bind(&SslHmacChannelAuthenticator::OnAuthBytesRead,
    205                                  base::Unretained(this)));
    206     if (result == net::ERR_IO_PENDING)
    207       break;
    208     if (!HandleAuthBytesRead(result))
    209       break;
    210   }
    211 }
    212 
    213 void SslHmacChannelAuthenticator::OnAuthBytesRead(int result) {
    214   DCHECK(CalledOnValidThread());
    215 
    216   if (HandleAuthBytesRead(result))
    217     ReadAuthenticationBytes();
    218 }
    219 
    220 bool SslHmacChannelAuthenticator::HandleAuthBytesRead(int read_result) {
    221   if (read_result <= 0) {
    222     NotifyError(read_result);
    223     return false;
    224   }
    225 
    226   auth_read_buf_->set_offset(auth_read_buf_->offset() + read_result);
    227   if (auth_read_buf_->RemainingCapacity() > 0)
    228     return true;
    229 
    230   if (!VerifyAuthBytes(std::string(
    231           auth_read_buf_->StartOfBuffer(),
    232           auth_read_buf_->StartOfBuffer() + kAuthDigestLength))) {
    233     LOG(WARNING) << "Mismatched authentication";
    234     NotifyError(net::ERR_FAILED);
    235     return false;
    236   }
    237 
    238   auth_read_buf_ = NULL;
    239   CheckDone(NULL);
    240   return false;
    241 }
    242 
    243 bool SslHmacChannelAuthenticator::VerifyAuthBytes(
    244     const std::string& received_auth_bytes) {
    245   DCHECK(received_auth_bytes.length() == kAuthDigestLength);
    246 
    247   // Compute expected auth bytes.
    248   std::string auth_bytes = GetAuthBytes(
    249       socket_.get(), is_ssl_server() ?
    250       kClientAuthSslExporterLabel : kHostAuthSslExporterLabel, auth_key_);
    251   if (auth_bytes.empty())
    252     return false;
    253 
    254   return crypto::SecureMemEqual(received_auth_bytes.data(),
    255                                 &(auth_bytes[0]), kAuthDigestLength);
    256 }
    257 
    258 void SslHmacChannelAuthenticator::CheckDone(bool* callback_called) {
    259   if (auth_write_buf_.get() == NULL && auth_read_buf_.get() == NULL) {
    260     DCHECK(socket_.get() != NULL);
    261     if (callback_called)
    262       *callback_called = true;
    263     done_callback_.Run(net::OK, socket_.PassAs<net::StreamSocket>());
    264   }
    265 }
    266 
    267 void SslHmacChannelAuthenticator::NotifyError(int error) {
    268   done_callback_.Run(static_cast<net::Error>(error),
    269                      scoped_ptr<net::StreamSocket>());
    270 }
    271 
    272 }  // namespace protocol
    273 }  // namespace remoting
    274