1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "remoting/protocol/ssl_hmac_channel_authenticator.h" 6 7 #include "base/bind.h" 8 #include "base/bind_helpers.h" 9 #include "crypto/secure_util.h" 10 #include "net/base/host_port_pair.h" 11 #include "net/base/io_buffer.h" 12 #include "net/base/net_errors.h" 13 #include "net/cert/cert_verifier.h" 14 #include "net/cert/x509_certificate.h" 15 #include "net/http/transport_security_state.h" 16 #include "net/socket/client_socket_factory.h" 17 #include "net/socket/ssl_client_socket.h" 18 #include "net/socket/ssl_server_socket.h" 19 #include "net/ssl/ssl_config_service.h" 20 #include "remoting/base/rsa_key_pair.h" 21 #include "remoting/protocol/auth_util.h" 22 23 namespace remoting { 24 namespace protocol { 25 26 // static 27 scoped_ptr<SslHmacChannelAuthenticator> 28 SslHmacChannelAuthenticator::CreateForClient( 29 const std::string& remote_cert, 30 const std::string& auth_key) { 31 scoped_ptr<SslHmacChannelAuthenticator> result( 32 new SslHmacChannelAuthenticator(auth_key)); 33 result->remote_cert_ = remote_cert; 34 return result.Pass(); 35 } 36 37 scoped_ptr<SslHmacChannelAuthenticator> 38 SslHmacChannelAuthenticator::CreateForHost( 39 const std::string& local_cert, 40 scoped_refptr<RsaKeyPair> key_pair, 41 const std::string& auth_key) { 42 scoped_ptr<SslHmacChannelAuthenticator> result( 43 new SslHmacChannelAuthenticator(auth_key)); 44 result->local_cert_ = local_cert; 45 result->local_key_pair_ = key_pair; 46 return result.Pass(); 47 } 48 49 SslHmacChannelAuthenticator::SslHmacChannelAuthenticator( 50 const std::string& auth_key) 51 : auth_key_(auth_key) { 52 } 53 54 SslHmacChannelAuthenticator::~SslHmacChannelAuthenticator() { 55 } 56 57 void SslHmacChannelAuthenticator::SecureAndAuthenticate( 58 scoped_ptr<net::StreamSocket> socket, const DoneCallback& done_callback) { 59 DCHECK(CalledOnValidThread()); 60 DCHECK(socket->IsConnected()); 61 62 done_callback_ = done_callback; 63 64 int result; 65 if (is_ssl_server()) { 66 scoped_refptr<net::X509Certificate> cert = 67 net::X509Certificate::CreateFromBytes( 68 local_cert_.data(), local_cert_.length()); 69 if (!cert.get()) { 70 LOG(ERROR) << "Failed to parse X509Certificate"; 71 NotifyError(net::ERR_FAILED); 72 return; 73 } 74 75 net::SSLConfig ssl_config; 76 net::SSLServerSocket* server_socket = 77 net::CreateSSLServerSocket(socket.release(), 78 cert.get(), 79 local_key_pair_->private_key(), 80 ssl_config); 81 socket_.reset(server_socket); 82 83 result = server_socket->Handshake(base::Bind( 84 &SslHmacChannelAuthenticator::OnConnected, base::Unretained(this))); 85 } else { 86 cert_verifier_.reset(net::CertVerifier::CreateDefault()); 87 transport_security_state_.reset(new net::TransportSecurityState); 88 89 net::SSLConfig::CertAndStatus cert_and_status; 90 cert_and_status.cert_status = net::CERT_STATUS_AUTHORITY_INVALID; 91 cert_and_status.der_cert = remote_cert_; 92 93 net::SSLConfig ssl_config; 94 // Certificate verification and revocation checking are not needed 95 // because we use self-signed certs. Disable it so that the SSL 96 // layer doesn't try to initialize OCSP (OCSP works only on the IO 97 // thread). 98 ssl_config.cert_io_enabled = false; 99 ssl_config.rev_checking_enabled = false; 100 ssl_config.allowed_bad_certs.push_back(cert_and_status); 101 102 net::HostPortPair host_and_port(kSslFakeHostName, 0); 103 net::SSLClientSocketContext context; 104 context.cert_verifier = cert_verifier_.get(); 105 context.transport_security_state = transport_security_state_.get(); 106 socket_.reset( 107 net::ClientSocketFactory::GetDefaultFactory()->CreateSSLClientSocket( 108 socket.release(), host_and_port, ssl_config, context)); 109 110 result = socket_->Connect( 111 base::Bind(&SslHmacChannelAuthenticator::OnConnected, 112 base::Unretained(this))); 113 } 114 115 if (result == net::ERR_IO_PENDING) 116 return; 117 118 OnConnected(result); 119 } 120 121 bool SslHmacChannelAuthenticator::is_ssl_server() { 122 return local_key_pair_.get() != NULL; 123 } 124 125 void SslHmacChannelAuthenticator::OnConnected(int result) { 126 if (result != net::OK) { 127 LOG(WARNING) << "Failed to establish SSL connection"; 128 NotifyError(result); 129 return; 130 } 131 132 // Generate authentication digest to write to the socket. 133 std::string auth_bytes = GetAuthBytes( 134 socket_.get(), is_ssl_server() ? 135 kHostAuthSslExporterLabel : kClientAuthSslExporterLabel, auth_key_); 136 if (auth_bytes.empty()) { 137 NotifyError(net::ERR_FAILED); 138 return; 139 } 140 141 // Allocate a buffer to write the digest. 142 auth_write_buf_ = new net::DrainableIOBuffer( 143 new net::StringIOBuffer(auth_bytes), auth_bytes.size()); 144 145 // Read an incoming token. 146 auth_read_buf_ = new net::GrowableIOBuffer(); 147 auth_read_buf_->SetCapacity(kAuthDigestLength); 148 149 // If WriteAuthenticationBytes() results in |done_callback_| being 150 // called then we must not do anything else because this object may 151 // be destroyed at that point. 152 bool callback_called = false; 153 WriteAuthenticationBytes(&callback_called); 154 if (!callback_called) 155 ReadAuthenticationBytes(); 156 } 157 158 void SslHmacChannelAuthenticator::WriteAuthenticationBytes( 159 bool* callback_called) { 160 while (true) { 161 int result = socket_->Write( 162 auth_write_buf_.get(), 163 auth_write_buf_->BytesRemaining(), 164 base::Bind(&SslHmacChannelAuthenticator::OnAuthBytesWritten, 165 base::Unretained(this))); 166 if (result == net::ERR_IO_PENDING) 167 break; 168 if (!HandleAuthBytesWritten(result, callback_called)) 169 break; 170 } 171 } 172 173 void SslHmacChannelAuthenticator::OnAuthBytesWritten(int result) { 174 DCHECK(CalledOnValidThread()); 175 176 if (HandleAuthBytesWritten(result, NULL)) 177 WriteAuthenticationBytes(NULL); 178 } 179 180 bool SslHmacChannelAuthenticator::HandleAuthBytesWritten( 181 int result, bool* callback_called) { 182 if (result <= 0) { 183 LOG(ERROR) << "Error writing authentication: " << result; 184 if (callback_called) 185 *callback_called = false; 186 NotifyError(result); 187 return false; 188 } 189 190 auth_write_buf_->DidConsume(result); 191 if (auth_write_buf_->BytesRemaining() > 0) 192 return true; 193 194 auth_write_buf_ = NULL; 195 CheckDone(callback_called); 196 return false; 197 } 198 199 void SslHmacChannelAuthenticator::ReadAuthenticationBytes() { 200 while (true) { 201 int result = 202 socket_->Read(auth_read_buf_.get(), 203 auth_read_buf_->RemainingCapacity(), 204 base::Bind(&SslHmacChannelAuthenticator::OnAuthBytesRead, 205 base::Unretained(this))); 206 if (result == net::ERR_IO_PENDING) 207 break; 208 if (!HandleAuthBytesRead(result)) 209 break; 210 } 211 } 212 213 void SslHmacChannelAuthenticator::OnAuthBytesRead(int result) { 214 DCHECK(CalledOnValidThread()); 215 216 if (HandleAuthBytesRead(result)) 217 ReadAuthenticationBytes(); 218 } 219 220 bool SslHmacChannelAuthenticator::HandleAuthBytesRead(int read_result) { 221 if (read_result <= 0) { 222 NotifyError(read_result); 223 return false; 224 } 225 226 auth_read_buf_->set_offset(auth_read_buf_->offset() + read_result); 227 if (auth_read_buf_->RemainingCapacity() > 0) 228 return true; 229 230 if (!VerifyAuthBytes(std::string( 231 auth_read_buf_->StartOfBuffer(), 232 auth_read_buf_->StartOfBuffer() + kAuthDigestLength))) { 233 LOG(WARNING) << "Mismatched authentication"; 234 NotifyError(net::ERR_FAILED); 235 return false; 236 } 237 238 auth_read_buf_ = NULL; 239 CheckDone(NULL); 240 return false; 241 } 242 243 bool SslHmacChannelAuthenticator::VerifyAuthBytes( 244 const std::string& received_auth_bytes) { 245 DCHECK(received_auth_bytes.length() == kAuthDigestLength); 246 247 // Compute expected auth bytes. 248 std::string auth_bytes = GetAuthBytes( 249 socket_.get(), is_ssl_server() ? 250 kClientAuthSslExporterLabel : kHostAuthSslExporterLabel, auth_key_); 251 if (auth_bytes.empty()) 252 return false; 253 254 return crypto::SecureMemEqual(received_auth_bytes.data(), 255 &(auth_bytes[0]), kAuthDigestLength); 256 } 257 258 void SslHmacChannelAuthenticator::CheckDone(bool* callback_called) { 259 if (auth_write_buf_.get() == NULL && auth_read_buf_.get() == NULL) { 260 DCHECK(socket_.get() != NULL); 261 if (callback_called) 262 *callback_called = true; 263 done_callback_.Run(net::OK, socket_.PassAs<net::StreamSocket>()); 264 } 265 } 266 267 void SslHmacChannelAuthenticator::NotifyError(int error) { 268 done_callback_.Run(static_cast<net::Error>(error), 269 scoped_ptr<net::StreamSocket>()); 270 } 271 272 } // namespace protocol 273 } // namespace remoting 274