Home | History | Annotate | Download | only in protocol
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "remoting/protocol/ssl_hmac_channel_authenticator.h"
      6 
      7 #include "base/bind.h"
      8 #include "base/bind_helpers.h"
      9 #include "crypto/secure_util.h"
     10 #include "net/base/host_port_pair.h"
     11 #include "net/base/io_buffer.h"
     12 #include "net/base/net_errors.h"
     13 #include "net/cert/x509_certificate.h"
     14 #include "net/http/transport_security_state.h"
     15 #include "net/socket/client_socket_factory.h"
     16 #include "net/socket/client_socket_handle.h"
     17 #include "net/socket/ssl_client_socket.h"
     18 #include "net/socket/ssl_client_socket_openssl.h"
     19 #include "net/socket/ssl_server_socket.h"
     20 #include "net/ssl/ssl_config_service.h"
     21 #include "remoting/base/rsa_key_pair.h"
     22 #include "remoting/protocol/auth_util.h"
     23 
     24 namespace remoting {
     25 namespace protocol {
     26 
     27 // static
     28 scoped_ptr<SslHmacChannelAuthenticator>
     29 SslHmacChannelAuthenticator::CreateForClient(
     30       const std::string& remote_cert,
     31       const std::string& auth_key) {
     32   scoped_ptr<SslHmacChannelAuthenticator> result(
     33       new SslHmacChannelAuthenticator(auth_key));
     34   result->remote_cert_ = remote_cert;
     35   return result.Pass();
     36 }
     37 
     38 scoped_ptr<SslHmacChannelAuthenticator>
     39 SslHmacChannelAuthenticator::CreateForHost(
     40     const std::string& local_cert,
     41     scoped_refptr<RsaKeyPair> key_pair,
     42     const std::string& auth_key) {
     43   scoped_ptr<SslHmacChannelAuthenticator> result(
     44       new SslHmacChannelAuthenticator(auth_key));
     45   result->local_cert_ = local_cert;
     46   result->local_key_pair_ = key_pair;
     47   return result.Pass();
     48 }
     49 
     50 SslHmacChannelAuthenticator::SslHmacChannelAuthenticator(
     51     const std::string& auth_key)
     52     : auth_key_(auth_key) {
     53 }
     54 
     55 SslHmacChannelAuthenticator::~SslHmacChannelAuthenticator() {
     56 }
     57 
     58 void SslHmacChannelAuthenticator::SecureAndAuthenticate(
     59     scoped_ptr<net::StreamSocket> socket, const DoneCallback& done_callback) {
     60   DCHECK(CalledOnValidThread());
     61   DCHECK(socket->IsConnected());
     62 
     63   done_callback_ = done_callback;
     64 
     65   int result;
     66   if (is_ssl_server()) {
     67 #if defined(OS_NACL)
     68     // Client plugin doesn't use server SSL sockets, and so SSLServerSocket
     69     // implementation is not compiled for NaCl as part of net_nacl.
     70     NOTREACHED();
     71     result = net::ERR_FAILED;
     72 #else
     73     scoped_refptr<net::X509Certificate> cert =
     74         net::X509Certificate::CreateFromBytes(
     75             local_cert_.data(), local_cert_.length());
     76     if (!cert.get()) {
     77       LOG(ERROR) << "Failed to parse X509Certificate";
     78       NotifyError(net::ERR_FAILED);
     79       return;
     80     }
     81 
     82     net::SSLConfig ssl_config;
     83     ssl_config.require_forward_secrecy = true;
     84 
     85     scoped_ptr<net::SSLServerSocket> server_socket =
     86         net::CreateSSLServerSocket(socket.Pass(),
     87                                    cert.get(),
     88                                    local_key_pair_->private_key(),
     89                                    ssl_config);
     90     net::SSLServerSocket* raw_server_socket = server_socket.get();
     91     socket_ = server_socket.Pass();
     92     result = raw_server_socket->Handshake(
     93         base::Bind(&SslHmacChannelAuthenticator::OnConnected,
     94                    base::Unretained(this)));
     95 #endif
     96   } else {
     97     transport_security_state_.reset(new net::TransportSecurityState);
     98 
     99     net::SSLConfig::CertAndStatus cert_and_status;
    100     cert_and_status.cert_status = net::CERT_STATUS_AUTHORITY_INVALID;
    101     cert_and_status.der_cert = remote_cert_;
    102 
    103     net::SSLConfig ssl_config;
    104     // Certificate verification and revocation checking are not needed
    105     // because we use self-signed certs. Disable it so that the SSL
    106     // layer doesn't try to initialize OCSP (OCSP works only on the IO
    107     // thread).
    108     ssl_config.cert_io_enabled = false;
    109     ssl_config.rev_checking_enabled = false;
    110     ssl_config.allowed_bad_certs.push_back(cert_and_status);
    111 
    112     net::HostPortPair host_and_port(kSslFakeHostName, 0);
    113     net::SSLClientSocketContext context;
    114     context.transport_security_state = transport_security_state_.get();
    115     scoped_ptr<net::ClientSocketHandle> socket_handle(
    116         new net::ClientSocketHandle);
    117     socket_handle->SetSocket(socket.Pass());
    118 
    119 #if defined(OS_NACL)
    120     // net_nacl doesn't include ClientSocketFactory.
    121     socket_.reset(new net::SSLClientSocketOpenSSL(
    122         socket_handle.Pass(), host_and_port, ssl_config, context));
    123 #else
    124     socket_ =
    125         net::ClientSocketFactory::GetDefaultFactory()->CreateSSLClientSocket(
    126             socket_handle.Pass(), host_and_port, ssl_config, context);
    127 #endif
    128 
    129     result = socket_->Connect(
    130         base::Bind(&SslHmacChannelAuthenticator::OnConnected,
    131                    base::Unretained(this)));
    132   }
    133 
    134   if (result == net::ERR_IO_PENDING)
    135     return;
    136 
    137   OnConnected(result);
    138 }
    139 
    140 bool SslHmacChannelAuthenticator::is_ssl_server() {
    141   return local_key_pair_.get() != NULL;
    142 }
    143 
    144 void SslHmacChannelAuthenticator::OnConnected(int result) {
    145   if (result != net::OK) {
    146     LOG(WARNING) << "Failed to establish SSL connection";
    147     NotifyError(result);
    148     return;
    149   }
    150 
    151   // Generate authentication digest to write to the socket.
    152   std::string auth_bytes = GetAuthBytes(
    153       socket_.get(), is_ssl_server() ?
    154       kHostAuthSslExporterLabel : kClientAuthSslExporterLabel, auth_key_);
    155   if (auth_bytes.empty()) {
    156     NotifyError(net::ERR_FAILED);
    157     return;
    158   }
    159 
    160   // Allocate a buffer to write the digest.
    161   auth_write_buf_ = new net::DrainableIOBuffer(
    162       new net::StringIOBuffer(auth_bytes), auth_bytes.size());
    163 
    164   // Read an incoming token.
    165   auth_read_buf_ = new net::GrowableIOBuffer();
    166   auth_read_buf_->SetCapacity(kAuthDigestLength);
    167 
    168   // If WriteAuthenticationBytes() results in |done_callback_| being
    169   // called then we must not do anything else because this object may
    170   // be destroyed at that point.
    171   bool callback_called = false;
    172   WriteAuthenticationBytes(&callback_called);
    173   if (!callback_called)
    174     ReadAuthenticationBytes();
    175 }
    176 
    177 void SslHmacChannelAuthenticator::WriteAuthenticationBytes(
    178     bool* callback_called) {
    179   while (true) {
    180     int result = socket_->Write(
    181         auth_write_buf_.get(),
    182         auth_write_buf_->BytesRemaining(),
    183         base::Bind(&SslHmacChannelAuthenticator::OnAuthBytesWritten,
    184                    base::Unretained(this)));
    185     if (result == net::ERR_IO_PENDING)
    186       break;
    187     if (!HandleAuthBytesWritten(result, callback_called))
    188       break;
    189   }
    190 }
    191 
    192 void SslHmacChannelAuthenticator::OnAuthBytesWritten(int result) {
    193   DCHECK(CalledOnValidThread());
    194 
    195   if (HandleAuthBytesWritten(result, NULL))
    196     WriteAuthenticationBytes(NULL);
    197 }
    198 
    199 bool SslHmacChannelAuthenticator::HandleAuthBytesWritten(
    200     int result, bool* callback_called) {
    201   if (result <= 0) {
    202     LOG(ERROR) << "Error writing authentication: " << result;
    203     if (callback_called)
    204       *callback_called = false;
    205     NotifyError(result);
    206     return false;
    207   }
    208 
    209   auth_write_buf_->DidConsume(result);
    210   if (auth_write_buf_->BytesRemaining() > 0)
    211     return true;
    212 
    213   auth_write_buf_ = NULL;
    214   CheckDone(callback_called);
    215   return false;
    216 }
    217 
    218 void SslHmacChannelAuthenticator::ReadAuthenticationBytes() {
    219   while (true) {
    220     int result =
    221         socket_->Read(auth_read_buf_.get(),
    222                       auth_read_buf_->RemainingCapacity(),
    223                       base::Bind(&SslHmacChannelAuthenticator::OnAuthBytesRead,
    224                                  base::Unretained(this)));
    225     if (result == net::ERR_IO_PENDING)
    226       break;
    227     if (!HandleAuthBytesRead(result))
    228       break;
    229   }
    230 }
    231 
    232 void SslHmacChannelAuthenticator::OnAuthBytesRead(int result) {
    233   DCHECK(CalledOnValidThread());
    234 
    235   if (HandleAuthBytesRead(result))
    236     ReadAuthenticationBytes();
    237 }
    238 
    239 bool SslHmacChannelAuthenticator::HandleAuthBytesRead(int read_result) {
    240   if (read_result <= 0) {
    241     NotifyError(read_result);
    242     return false;
    243   }
    244 
    245   auth_read_buf_->set_offset(auth_read_buf_->offset() + read_result);
    246   if (auth_read_buf_->RemainingCapacity() > 0)
    247     return true;
    248 
    249   if (!VerifyAuthBytes(std::string(
    250           auth_read_buf_->StartOfBuffer(),
    251           auth_read_buf_->StartOfBuffer() + kAuthDigestLength))) {
    252     LOG(WARNING) << "Mismatched authentication";
    253     NotifyError(net::ERR_FAILED);
    254     return false;
    255   }
    256 
    257   auth_read_buf_ = NULL;
    258   CheckDone(NULL);
    259   return false;
    260 }
    261 
    262 bool SslHmacChannelAuthenticator::VerifyAuthBytes(
    263     const std::string& received_auth_bytes) {
    264   DCHECK(received_auth_bytes.length() == kAuthDigestLength);
    265 
    266   // Compute expected auth bytes.
    267   std::string auth_bytes = GetAuthBytes(
    268       socket_.get(), is_ssl_server() ?
    269       kClientAuthSslExporterLabel : kHostAuthSslExporterLabel, auth_key_);
    270   if (auth_bytes.empty())
    271     return false;
    272 
    273   return crypto::SecureMemEqual(received_auth_bytes.data(),
    274                                 &(auth_bytes[0]), kAuthDigestLength);
    275 }
    276 
    277 void SslHmacChannelAuthenticator::CheckDone(bool* callback_called) {
    278   if (auth_write_buf_.get() == NULL && auth_read_buf_.get() == NULL) {
    279     DCHECK(socket_.get() != NULL);
    280     if (callback_called)
    281       *callback_called = true;
    282     done_callback_.Run(net::OK, socket_.PassAs<net::StreamSocket>());
    283   }
    284 }
    285 
    286 void SslHmacChannelAuthenticator::NotifyError(int error) {
    287   done_callback_.Run(static_cast<net::Error>(error),
    288                      scoped_ptr<net::StreamSocket>());
    289 }
    290 
    291 }  // namespace protocol
    292 }  // namespace remoting
    293