1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "remoting/protocol/ssl_hmac_channel_authenticator.h" 6 7 #include "base/bind.h" 8 #include "base/bind_helpers.h" 9 #include "crypto/secure_util.h" 10 #include "net/base/host_port_pair.h" 11 #include "net/base/io_buffer.h" 12 #include "net/base/net_errors.h" 13 #include "net/cert/x509_certificate.h" 14 #include "net/http/transport_security_state.h" 15 #include "net/socket/client_socket_factory.h" 16 #include "net/socket/client_socket_handle.h" 17 #include "net/socket/ssl_client_socket.h" 18 #include "net/socket/ssl_client_socket_openssl.h" 19 #include "net/socket/ssl_server_socket.h" 20 #include "net/ssl/ssl_config_service.h" 21 #include "remoting/base/rsa_key_pair.h" 22 #include "remoting/protocol/auth_util.h" 23 24 namespace remoting { 25 namespace protocol { 26 27 // static 28 scoped_ptr<SslHmacChannelAuthenticator> 29 SslHmacChannelAuthenticator::CreateForClient( 30 const std::string& remote_cert, 31 const std::string& auth_key) { 32 scoped_ptr<SslHmacChannelAuthenticator> result( 33 new SslHmacChannelAuthenticator(auth_key)); 34 result->remote_cert_ = remote_cert; 35 return result.Pass(); 36 } 37 38 scoped_ptr<SslHmacChannelAuthenticator> 39 SslHmacChannelAuthenticator::CreateForHost( 40 const std::string& local_cert, 41 scoped_refptr<RsaKeyPair> key_pair, 42 const std::string& auth_key) { 43 scoped_ptr<SslHmacChannelAuthenticator> result( 44 new SslHmacChannelAuthenticator(auth_key)); 45 result->local_cert_ = local_cert; 46 result->local_key_pair_ = key_pair; 47 return result.Pass(); 48 } 49 50 SslHmacChannelAuthenticator::SslHmacChannelAuthenticator( 51 const std::string& auth_key) 52 : auth_key_(auth_key) { 53 } 54 55 SslHmacChannelAuthenticator::~SslHmacChannelAuthenticator() { 56 } 57 58 void SslHmacChannelAuthenticator::SecureAndAuthenticate( 59 scoped_ptr<net::StreamSocket> socket, const DoneCallback& done_callback) { 60 DCHECK(CalledOnValidThread()); 61 DCHECK(socket->IsConnected()); 62 63 done_callback_ = done_callback; 64 65 int result; 66 if (is_ssl_server()) { 67 #if defined(OS_NACL) 68 // Client plugin doesn't use server SSL sockets, and so SSLServerSocket 69 // implementation is not compiled for NaCl as part of net_nacl. 70 NOTREACHED(); 71 result = net::ERR_FAILED; 72 #else 73 scoped_refptr<net::X509Certificate> cert = 74 net::X509Certificate::CreateFromBytes( 75 local_cert_.data(), local_cert_.length()); 76 if (!cert.get()) { 77 LOG(ERROR) << "Failed to parse X509Certificate"; 78 NotifyError(net::ERR_FAILED); 79 return; 80 } 81 82 net::SSLConfig ssl_config; 83 ssl_config.require_forward_secrecy = true; 84 85 scoped_ptr<net::SSLServerSocket> server_socket = 86 net::CreateSSLServerSocket(socket.Pass(), 87 cert.get(), 88 local_key_pair_->private_key(), 89 ssl_config); 90 net::SSLServerSocket* raw_server_socket = server_socket.get(); 91 socket_ = server_socket.Pass(); 92 result = raw_server_socket->Handshake( 93 base::Bind(&SslHmacChannelAuthenticator::OnConnected, 94 base::Unretained(this))); 95 #endif 96 } else { 97 transport_security_state_.reset(new net::TransportSecurityState); 98 99 net::SSLConfig::CertAndStatus cert_and_status; 100 cert_and_status.cert_status = net::CERT_STATUS_AUTHORITY_INVALID; 101 cert_and_status.der_cert = remote_cert_; 102 103 net::SSLConfig ssl_config; 104 // Certificate verification and revocation checking are not needed 105 // because we use self-signed certs. Disable it so that the SSL 106 // layer doesn't try to initialize OCSP (OCSP works only on the IO 107 // thread). 108 ssl_config.cert_io_enabled = false; 109 ssl_config.rev_checking_enabled = false; 110 ssl_config.allowed_bad_certs.push_back(cert_and_status); 111 112 net::HostPortPair host_and_port(kSslFakeHostName, 0); 113 net::SSLClientSocketContext context; 114 context.transport_security_state = transport_security_state_.get(); 115 scoped_ptr<net::ClientSocketHandle> socket_handle( 116 new net::ClientSocketHandle); 117 socket_handle->SetSocket(socket.Pass()); 118 119 #if defined(OS_NACL) 120 // net_nacl doesn't include ClientSocketFactory. 121 socket_.reset(new net::SSLClientSocketOpenSSL( 122 socket_handle.Pass(), host_and_port, ssl_config, context)); 123 #else 124 socket_ = 125 net::ClientSocketFactory::GetDefaultFactory()->CreateSSLClientSocket( 126 socket_handle.Pass(), host_and_port, ssl_config, context); 127 #endif 128 129 result = socket_->Connect( 130 base::Bind(&SslHmacChannelAuthenticator::OnConnected, 131 base::Unretained(this))); 132 } 133 134 if (result == net::ERR_IO_PENDING) 135 return; 136 137 OnConnected(result); 138 } 139 140 bool SslHmacChannelAuthenticator::is_ssl_server() { 141 return local_key_pair_.get() != NULL; 142 } 143 144 void SslHmacChannelAuthenticator::OnConnected(int result) { 145 if (result != net::OK) { 146 LOG(WARNING) << "Failed to establish SSL connection"; 147 NotifyError(result); 148 return; 149 } 150 151 // Generate authentication digest to write to the socket. 152 std::string auth_bytes = GetAuthBytes( 153 socket_.get(), is_ssl_server() ? 154 kHostAuthSslExporterLabel : kClientAuthSslExporterLabel, auth_key_); 155 if (auth_bytes.empty()) { 156 NotifyError(net::ERR_FAILED); 157 return; 158 } 159 160 // Allocate a buffer to write the digest. 161 auth_write_buf_ = new net::DrainableIOBuffer( 162 new net::StringIOBuffer(auth_bytes), auth_bytes.size()); 163 164 // Read an incoming token. 165 auth_read_buf_ = new net::GrowableIOBuffer(); 166 auth_read_buf_->SetCapacity(kAuthDigestLength); 167 168 // If WriteAuthenticationBytes() results in |done_callback_| being 169 // called then we must not do anything else because this object may 170 // be destroyed at that point. 171 bool callback_called = false; 172 WriteAuthenticationBytes(&callback_called); 173 if (!callback_called) 174 ReadAuthenticationBytes(); 175 } 176 177 void SslHmacChannelAuthenticator::WriteAuthenticationBytes( 178 bool* callback_called) { 179 while (true) { 180 int result = socket_->Write( 181 auth_write_buf_.get(), 182 auth_write_buf_->BytesRemaining(), 183 base::Bind(&SslHmacChannelAuthenticator::OnAuthBytesWritten, 184 base::Unretained(this))); 185 if (result == net::ERR_IO_PENDING) 186 break; 187 if (!HandleAuthBytesWritten(result, callback_called)) 188 break; 189 } 190 } 191 192 void SslHmacChannelAuthenticator::OnAuthBytesWritten(int result) { 193 DCHECK(CalledOnValidThread()); 194 195 if (HandleAuthBytesWritten(result, NULL)) 196 WriteAuthenticationBytes(NULL); 197 } 198 199 bool SslHmacChannelAuthenticator::HandleAuthBytesWritten( 200 int result, bool* callback_called) { 201 if (result <= 0) { 202 LOG(ERROR) << "Error writing authentication: " << result; 203 if (callback_called) 204 *callback_called = false; 205 NotifyError(result); 206 return false; 207 } 208 209 auth_write_buf_->DidConsume(result); 210 if (auth_write_buf_->BytesRemaining() > 0) 211 return true; 212 213 auth_write_buf_ = NULL; 214 CheckDone(callback_called); 215 return false; 216 } 217 218 void SslHmacChannelAuthenticator::ReadAuthenticationBytes() { 219 while (true) { 220 int result = 221 socket_->Read(auth_read_buf_.get(), 222 auth_read_buf_->RemainingCapacity(), 223 base::Bind(&SslHmacChannelAuthenticator::OnAuthBytesRead, 224 base::Unretained(this))); 225 if (result == net::ERR_IO_PENDING) 226 break; 227 if (!HandleAuthBytesRead(result)) 228 break; 229 } 230 } 231 232 void SslHmacChannelAuthenticator::OnAuthBytesRead(int result) { 233 DCHECK(CalledOnValidThread()); 234 235 if (HandleAuthBytesRead(result)) 236 ReadAuthenticationBytes(); 237 } 238 239 bool SslHmacChannelAuthenticator::HandleAuthBytesRead(int read_result) { 240 if (read_result <= 0) { 241 NotifyError(read_result); 242 return false; 243 } 244 245 auth_read_buf_->set_offset(auth_read_buf_->offset() + read_result); 246 if (auth_read_buf_->RemainingCapacity() > 0) 247 return true; 248 249 if (!VerifyAuthBytes(std::string( 250 auth_read_buf_->StartOfBuffer(), 251 auth_read_buf_->StartOfBuffer() + kAuthDigestLength))) { 252 LOG(WARNING) << "Mismatched authentication"; 253 NotifyError(net::ERR_FAILED); 254 return false; 255 } 256 257 auth_read_buf_ = NULL; 258 CheckDone(NULL); 259 return false; 260 } 261 262 bool SslHmacChannelAuthenticator::VerifyAuthBytes( 263 const std::string& received_auth_bytes) { 264 DCHECK(received_auth_bytes.length() == kAuthDigestLength); 265 266 // Compute expected auth bytes. 267 std::string auth_bytes = GetAuthBytes( 268 socket_.get(), is_ssl_server() ? 269 kClientAuthSslExporterLabel : kHostAuthSslExporterLabel, auth_key_); 270 if (auth_bytes.empty()) 271 return false; 272 273 return crypto::SecureMemEqual(received_auth_bytes.data(), 274 &(auth_bytes[0]), kAuthDigestLength); 275 } 276 277 void SslHmacChannelAuthenticator::CheckDone(bool* callback_called) { 278 if (auth_write_buf_.get() == NULL && auth_read_buf_.get() == NULL) { 279 DCHECK(socket_.get() != NULL); 280 if (callback_called) 281 *callback_called = true; 282 done_callback_.Run(net::OK, socket_.PassAs<net::StreamSocket>()); 283 } 284 } 285 286 void SslHmacChannelAuthenticator::NotifyError(int error) { 287 done_callback_.Run(static_cast<net::Error>(error), 288 scoped_ptr<net::StreamSocket>()); 289 } 290 291 } // namespace protocol 292 } // namespace remoting 293