Home | History | Annotate | Download | only in protocol
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "remoting/protocol/ssl_hmac_channel_authenticator.h"
      6 
      7 #include "base/bind.h"
      8 #include "base/bind_helpers.h"
      9 #include "crypto/secure_util.h"
     10 #include "net/base/host_port_pair.h"
     11 #include "net/base/io_buffer.h"
     12 #include "net/base/net_errors.h"
     13 #include "net/cert/cert_verifier.h"
     14 #include "net/cert/x509_certificate.h"
     15 #include "net/http/transport_security_state.h"
     16 #include "net/socket/client_socket_factory.h"
     17 #include "net/socket/client_socket_handle.h"
     18 #include "net/socket/ssl_client_socket.h"
     19 #include "net/socket/ssl_server_socket.h"
     20 #include "net/ssl/ssl_config_service.h"
     21 #include "remoting/base/rsa_key_pair.h"
     22 #include "remoting/protocol/auth_util.h"
     23 
     24 namespace remoting {
     25 namespace protocol {
     26 
     27 // static
     28 scoped_ptr<SslHmacChannelAuthenticator>
     29 SslHmacChannelAuthenticator::CreateForClient(
     30       const std::string& remote_cert,
     31       const std::string& auth_key) {
     32   scoped_ptr<SslHmacChannelAuthenticator> result(
     33       new SslHmacChannelAuthenticator(auth_key));
     34   result->remote_cert_ = remote_cert;
     35   return result.Pass();
     36 }
     37 
     38 scoped_ptr<SslHmacChannelAuthenticator>
     39 SslHmacChannelAuthenticator::CreateForHost(
     40     const std::string& local_cert,
     41     scoped_refptr<RsaKeyPair> key_pair,
     42     const std::string& auth_key) {
     43   scoped_ptr<SslHmacChannelAuthenticator> result(
     44       new SslHmacChannelAuthenticator(auth_key));
     45   result->local_cert_ = local_cert;
     46   result->local_key_pair_ = key_pair;
     47   return result.Pass();
     48 }
     49 
     50 SslHmacChannelAuthenticator::SslHmacChannelAuthenticator(
     51     const std::string& auth_key)
     52     : auth_key_(auth_key) {
     53 }
     54 
     55 SslHmacChannelAuthenticator::~SslHmacChannelAuthenticator() {
     56 }
     57 
     58 void SslHmacChannelAuthenticator::SecureAndAuthenticate(
     59     scoped_ptr<net::StreamSocket> socket, const DoneCallback& done_callback) {
     60   DCHECK(CalledOnValidThread());
     61   DCHECK(socket->IsConnected());
     62 
     63   done_callback_ = done_callback;
     64 
     65   int result;
     66   if (is_ssl_server()) {
     67     scoped_refptr<net::X509Certificate> cert =
     68         net::X509Certificate::CreateFromBytes(
     69             local_cert_.data(), local_cert_.length());
     70     if (!cert.get()) {
     71       LOG(ERROR) << "Failed to parse X509Certificate";
     72       NotifyError(net::ERR_FAILED);
     73       return;
     74     }
     75 
     76     net::SSLConfig ssl_config;
     77     ssl_config.require_forward_secrecy = true;
     78 
     79     scoped_ptr<net::SSLServerSocket> server_socket =
     80         net::CreateSSLServerSocket(socket.Pass(),
     81                                    cert.get(),
     82                                    local_key_pair_->private_key(),
     83                                    ssl_config);
     84     net::SSLServerSocket* raw_server_socket = server_socket.get();
     85     socket_ = server_socket.Pass();
     86     result = raw_server_socket->Handshake(
     87         base::Bind(&SslHmacChannelAuthenticator::OnConnected,
     88                    base::Unretained(this)));
     89   } else {
     90     cert_verifier_.reset(net::CertVerifier::CreateDefault());
     91     transport_security_state_.reset(new net::TransportSecurityState);
     92 
     93     net::SSLConfig::CertAndStatus cert_and_status;
     94     cert_and_status.cert_status = net::CERT_STATUS_AUTHORITY_INVALID;
     95     cert_and_status.der_cert = remote_cert_;
     96 
     97     net::SSLConfig ssl_config;
     98     // Certificate verification and revocation checking are not needed
     99     // because we use self-signed certs. Disable it so that the SSL
    100     // layer doesn't try to initialize OCSP (OCSP works only on the IO
    101     // thread).
    102     ssl_config.cert_io_enabled = false;
    103     ssl_config.rev_checking_enabled = false;
    104     ssl_config.allowed_bad_certs.push_back(cert_and_status);
    105 
    106     net::HostPortPair host_and_port(kSslFakeHostName, 0);
    107     net::SSLClientSocketContext context;
    108     context.cert_verifier = cert_verifier_.get();
    109     context.transport_security_state = transport_security_state_.get();
    110     scoped_ptr<net::ClientSocketHandle> connection(new net::ClientSocketHandle);
    111     connection->SetSocket(socket.Pass());
    112     socket_ =
    113         net::ClientSocketFactory::GetDefaultFactory()->CreateSSLClientSocket(
    114             connection.Pass(), host_and_port, ssl_config, context);
    115 
    116     result = socket_->Connect(
    117         base::Bind(&SslHmacChannelAuthenticator::OnConnected,
    118                    base::Unretained(this)));
    119   }
    120 
    121   if (result == net::ERR_IO_PENDING)
    122     return;
    123 
    124   OnConnected(result);
    125 }
    126 
    127 bool SslHmacChannelAuthenticator::is_ssl_server() {
    128   return local_key_pair_.get() != NULL;
    129 }
    130 
    131 void SslHmacChannelAuthenticator::OnConnected(int result) {
    132   if (result != net::OK) {
    133     LOG(WARNING) << "Failed to establish SSL connection";
    134     NotifyError(result);
    135     return;
    136   }
    137 
    138   // Generate authentication digest to write to the socket.
    139   std::string auth_bytes = GetAuthBytes(
    140       socket_.get(), is_ssl_server() ?
    141       kHostAuthSslExporterLabel : kClientAuthSslExporterLabel, auth_key_);
    142   if (auth_bytes.empty()) {
    143     NotifyError(net::ERR_FAILED);
    144     return;
    145   }
    146 
    147   // Allocate a buffer to write the digest.
    148   auth_write_buf_ = new net::DrainableIOBuffer(
    149       new net::StringIOBuffer(auth_bytes), auth_bytes.size());
    150 
    151   // Read an incoming token.
    152   auth_read_buf_ = new net::GrowableIOBuffer();
    153   auth_read_buf_->SetCapacity(kAuthDigestLength);
    154 
    155   // If WriteAuthenticationBytes() results in |done_callback_| being
    156   // called then we must not do anything else because this object may
    157   // be destroyed at that point.
    158   bool callback_called = false;
    159   WriteAuthenticationBytes(&callback_called);
    160   if (!callback_called)
    161     ReadAuthenticationBytes();
    162 }
    163 
    164 void SslHmacChannelAuthenticator::WriteAuthenticationBytes(
    165     bool* callback_called) {
    166   while (true) {
    167     int result = socket_->Write(
    168         auth_write_buf_.get(),
    169         auth_write_buf_->BytesRemaining(),
    170         base::Bind(&SslHmacChannelAuthenticator::OnAuthBytesWritten,
    171                    base::Unretained(this)));
    172     if (result == net::ERR_IO_PENDING)
    173       break;
    174     if (!HandleAuthBytesWritten(result, callback_called))
    175       break;
    176   }
    177 }
    178 
    179 void SslHmacChannelAuthenticator::OnAuthBytesWritten(int result) {
    180   DCHECK(CalledOnValidThread());
    181 
    182   if (HandleAuthBytesWritten(result, NULL))
    183     WriteAuthenticationBytes(NULL);
    184 }
    185 
    186 bool SslHmacChannelAuthenticator::HandleAuthBytesWritten(
    187     int result, bool* callback_called) {
    188   if (result <= 0) {
    189     LOG(ERROR) << "Error writing authentication: " << result;
    190     if (callback_called)
    191       *callback_called = false;
    192     NotifyError(result);
    193     return false;
    194   }
    195 
    196   auth_write_buf_->DidConsume(result);
    197   if (auth_write_buf_->BytesRemaining() > 0)
    198     return true;
    199 
    200   auth_write_buf_ = NULL;
    201   CheckDone(callback_called);
    202   return false;
    203 }
    204 
    205 void SslHmacChannelAuthenticator::ReadAuthenticationBytes() {
    206   while (true) {
    207     int result =
    208         socket_->Read(auth_read_buf_.get(),
    209                       auth_read_buf_->RemainingCapacity(),
    210                       base::Bind(&SslHmacChannelAuthenticator::OnAuthBytesRead,
    211                                  base::Unretained(this)));
    212     if (result == net::ERR_IO_PENDING)
    213       break;
    214     if (!HandleAuthBytesRead(result))
    215       break;
    216   }
    217 }
    218 
    219 void SslHmacChannelAuthenticator::OnAuthBytesRead(int result) {
    220   DCHECK(CalledOnValidThread());
    221 
    222   if (HandleAuthBytesRead(result))
    223     ReadAuthenticationBytes();
    224 }
    225 
    226 bool SslHmacChannelAuthenticator::HandleAuthBytesRead(int read_result) {
    227   if (read_result <= 0) {
    228     NotifyError(read_result);
    229     return false;
    230   }
    231 
    232   auth_read_buf_->set_offset(auth_read_buf_->offset() + read_result);
    233   if (auth_read_buf_->RemainingCapacity() > 0)
    234     return true;
    235 
    236   if (!VerifyAuthBytes(std::string(
    237           auth_read_buf_->StartOfBuffer(),
    238           auth_read_buf_->StartOfBuffer() + kAuthDigestLength))) {
    239     LOG(WARNING) << "Mismatched authentication";
    240     NotifyError(net::ERR_FAILED);
    241     return false;
    242   }
    243 
    244   auth_read_buf_ = NULL;
    245   CheckDone(NULL);
    246   return false;
    247 }
    248 
    249 bool SslHmacChannelAuthenticator::VerifyAuthBytes(
    250     const std::string& received_auth_bytes) {
    251   DCHECK(received_auth_bytes.length() == kAuthDigestLength);
    252 
    253   // Compute expected auth bytes.
    254   std::string auth_bytes = GetAuthBytes(
    255       socket_.get(), is_ssl_server() ?
    256       kClientAuthSslExporterLabel : kHostAuthSslExporterLabel, auth_key_);
    257   if (auth_bytes.empty())
    258     return false;
    259 
    260   return crypto::SecureMemEqual(received_auth_bytes.data(),
    261                                 &(auth_bytes[0]), kAuthDigestLength);
    262 }
    263 
    264 void SslHmacChannelAuthenticator::CheckDone(bool* callback_called) {
    265   if (auth_write_buf_.get() == NULL && auth_read_buf_.get() == NULL) {
    266     DCHECK(socket_.get() != NULL);
    267     if (callback_called)
    268       *callback_called = true;
    269     done_callback_.Run(net::OK, socket_.PassAs<net::StreamSocket>());
    270   }
    271 }
    272 
    273 void SslHmacChannelAuthenticator::NotifyError(int error) {
    274   done_callback_.Run(static_cast<net::Error>(error),
    275                      scoped_ptr<net::StreamSocket>());
    276 }
    277 
    278 }  // namespace protocol
    279 }  // namespace remoting
    280