1 // Copyright 2015 The Chromium OS Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include <brillo/streams/tls_stream.h> 6 7 #include <algorithm> 8 #include <limits> 9 #include <string> 10 #include <vector> 11 12 #include <openssl/err.h> 13 #include <openssl/ssl.h> 14 15 #include <base/bind.h> 16 #include <base/memory/weak_ptr.h> 17 #include <brillo/message_loops/message_loop.h> 18 #include <brillo/secure_blob.h> 19 #include <brillo/streams/openssl_stream_bio.h> 20 #include <brillo/streams/stream_utils.h> 21 #include <brillo/strings/string_utils.h> 22 23 namespace { 24 25 // SSL info callback which is called by OpenSSL when we enable logging level of 26 // at least 3. This logs the information about the internal TLS handshake. 27 void TlsInfoCallback(const SSL* /* ssl */, int where, int ret) { 28 std::string reason; 29 std::vector<std::string> info; 30 if (where & SSL_CB_LOOP) 31 info.push_back("loop"); 32 if (where & SSL_CB_EXIT) 33 info.push_back("exit"); 34 if (where & SSL_CB_READ) 35 info.push_back("read"); 36 if (where & SSL_CB_WRITE) 37 info.push_back("write"); 38 if (where & SSL_CB_ALERT) { 39 info.push_back("alert"); 40 reason = ", reason: "; 41 reason += SSL_alert_type_string_long(ret); 42 reason += "/"; 43 reason += SSL_alert_desc_string_long(ret); 44 } 45 if (where & SSL_CB_HANDSHAKE_START) 46 info.push_back("handshake_start"); 47 if (where & SSL_CB_HANDSHAKE_DONE) 48 info.push_back("handshake_done"); 49 50 VLOG(3) << "TLS progress info: " << brillo::string_utils::Join(",", info) 51 << ", with status: " << ret << reason; 52 } 53 54 // Static variable to store the index of TlsStream private data in SSL context 55 // used to store custom data for OnCertVerifyResults(). 56 int ssl_ctx_private_data_index = -1; 57 58 // Default trusted certificate store location. 59 const char kCACertificatePath[] = 60 #ifdef __ANDROID__ 61 "/system/etc/security/cacerts_google"; 62 #else 63 "/usr/share/chromeos-ca-certificates"; 64 #endif 65 66 } // anonymous namespace 67 68 namespace brillo { 69 70 // Helper implementation of TLS stream used to hide most of OpenSSL inner 71 // workings from the users of brillo::TlsStream. 72 class TlsStream::TlsStreamImpl { 73 public: 74 TlsStreamImpl(); 75 ~TlsStreamImpl(); 76 77 bool Init(StreamPtr socket, 78 const std::string& host, 79 const base::Closure& success_callback, 80 const Stream::ErrorCallback& error_callback, 81 ErrorPtr* error); 82 83 bool ReadNonBlocking(void* buffer, 84 size_t size_to_read, 85 size_t* size_read, 86 bool* end_of_stream, 87 ErrorPtr* error); 88 89 bool WriteNonBlocking(const void* buffer, 90 size_t size_to_write, 91 size_t* size_written, 92 ErrorPtr* error); 93 94 bool Flush(ErrorPtr* error); 95 bool Close(ErrorPtr* error); 96 bool WaitForData(AccessMode mode, 97 const base::Callback<void(AccessMode)>& callback, 98 ErrorPtr* error); 99 bool WaitForDataBlocking(AccessMode in_mode, 100 base::TimeDelta timeout, 101 AccessMode* out_mode, 102 ErrorPtr* error); 103 void CancelPendingAsyncOperations(); 104 105 private: 106 bool ReportError(ErrorPtr* error, 107 const tracked_objects::Location& location, 108 const std::string& message); 109 void DoHandshake(const base::Closure& success_callback, 110 const Stream::ErrorCallback& error_callback); 111 void RetryHandshake(const base::Closure& success_callback, 112 const Stream::ErrorCallback& error_callback, 113 Stream::AccessMode mode); 114 115 int OnCertVerifyResults(int ok, X509_STORE_CTX* ctx); 116 static int OnCertVerifyResultsStatic(int ok, X509_STORE_CTX* ctx); 117 118 StreamPtr socket_; 119 std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)> ctx_{nullptr, SSL_CTX_free}; 120 std::unique_ptr<SSL, decltype(&SSL_free)> ssl_{nullptr, SSL_free}; 121 BIO* stream_bio_{nullptr}; 122 bool need_more_read_{false}; 123 bool need_more_write_{false}; 124 125 base::WeakPtrFactory<TlsStreamImpl> weak_ptr_factory_{this}; 126 DISALLOW_COPY_AND_ASSIGN(TlsStreamImpl); 127 }; 128 129 TlsStream::TlsStreamImpl::TlsStreamImpl() { 130 SSL_load_error_strings(); 131 SSL_library_init(); 132 if (ssl_ctx_private_data_index < 0) { 133 ssl_ctx_private_data_index = 134 SSL_CTX_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr); 135 } 136 } 137 138 TlsStream::TlsStreamImpl::~TlsStreamImpl() { 139 ssl_.reset(); 140 ctx_.reset(); 141 } 142 143 bool TlsStream::TlsStreamImpl::ReadNonBlocking(void* buffer, 144 size_t size_to_read, 145 size_t* size_read, 146 bool* end_of_stream, 147 ErrorPtr* error) { 148 const size_t max_int = std::numeric_limits<int>::max(); 149 int size_int = static_cast<int>(std::min(size_to_read, max_int)); 150 int ret = SSL_read(ssl_.get(), buffer, size_int); 151 if (ret > 0) { 152 *size_read = static_cast<size_t>(ret); 153 if (end_of_stream) 154 *end_of_stream = false; 155 return true; 156 } 157 158 int err = SSL_get_error(ssl_.get(), ret); 159 if (err == SSL_ERROR_ZERO_RETURN) { 160 *size_read = 0; 161 if (end_of_stream) 162 *end_of_stream = true; 163 return true; 164 } 165 166 if (err == SSL_ERROR_WANT_READ) { 167 need_more_read_ = true; 168 } else if (err == SSL_ERROR_WANT_WRITE) { 169 // Writes might be required for SSL_read() because of possible TLS 170 // re-negotiations which can happen at any time. 171 need_more_write_ = true; 172 } else { 173 return ReportError(error, FROM_HERE, "Error reading from TLS socket"); 174 } 175 *size_read = 0; 176 if (end_of_stream) 177 *end_of_stream = false; 178 return true; 179 } 180 181 bool TlsStream::TlsStreamImpl::WriteNonBlocking(const void* buffer, 182 size_t size_to_write, 183 size_t* size_written, 184 ErrorPtr* error) { 185 const size_t max_int = std::numeric_limits<int>::max(); 186 int size_int = static_cast<int>(std::min(size_to_write, max_int)); 187 int ret = SSL_write(ssl_.get(), buffer, size_int); 188 if (ret > 0) { 189 *size_written = static_cast<size_t>(ret); 190 return true; 191 } 192 193 int err = SSL_get_error(ssl_.get(), ret); 194 if (err == SSL_ERROR_WANT_READ) { 195 // Reads might be required for SSL_write() because of possible TLS 196 // re-negotiations which can happen at any time. 197 need_more_read_ = true; 198 } else if (err == SSL_ERROR_WANT_WRITE) { 199 need_more_write_ = true; 200 } else { 201 return ReportError(error, FROM_HERE, "Error writing to TLS socket"); 202 } 203 *size_written = 0; 204 return true; 205 } 206 207 bool TlsStream::TlsStreamImpl::Flush(ErrorPtr* error) { 208 return socket_->FlushBlocking(error); 209 } 210 211 bool TlsStream::TlsStreamImpl::Close(ErrorPtr* error) { 212 // 2 seconds should be plenty here. 213 const base::TimeDelta kTimeout = base::TimeDelta::FromSeconds(2); 214 // The retry count of 4 below is just arbitrary, to ensure we don't get stuck 215 // here forever. We should rarely need to repeat SSL_shutdown anyway. 216 for (int retry_count = 0; retry_count < 4; retry_count++) { 217 int ret = SSL_shutdown(ssl_.get()); 218 // We really don't care for bi-directional shutdown here. 219 // Just make sure we only send the "close notify" alert to the remote peer. 220 if (ret >= 0) 221 break; 222 223 int err = SSL_get_error(ssl_.get(), ret); 224 if (err == SSL_ERROR_WANT_READ) { 225 if (!socket_->WaitForDataBlocking(AccessMode::READ, kTimeout, nullptr, 226 error)) { 227 break; 228 } 229 } else if (err == SSL_ERROR_WANT_WRITE) { 230 if (!socket_->WaitForDataBlocking(AccessMode::WRITE, kTimeout, nullptr, 231 error)) { 232 break; 233 } 234 } else { 235 LOG(ERROR) << "SSL_shutdown returned error #" << err; 236 ReportError(error, FROM_HERE, "Failed to shut down TLS socket"); 237 break; 238 } 239 } 240 return socket_->CloseBlocking(error); 241 } 242 243 bool TlsStream::TlsStreamImpl::WaitForData( 244 AccessMode mode, 245 const base::Callback<void(AccessMode)>& callback, 246 ErrorPtr* error) { 247 bool is_read = stream_utils::IsReadAccessMode(mode); 248 bool is_write = stream_utils::IsWriteAccessMode(mode); 249 is_read |= need_more_read_; 250 is_write |= need_more_write_; 251 need_more_read_ = false; 252 need_more_write_ = false; 253 if (is_read && SSL_pending(ssl_.get()) > 0) { 254 callback.Run(AccessMode::READ); 255 return true; 256 } 257 mode = stream_utils::MakeAccessMode(is_read, is_write); 258 return socket_->WaitForData(mode, callback, error); 259 } 260 261 bool TlsStream::TlsStreamImpl::WaitForDataBlocking(AccessMode in_mode, 262 base::TimeDelta timeout, 263 AccessMode* out_mode, 264 ErrorPtr* error) { 265 bool is_read = stream_utils::IsReadAccessMode(in_mode); 266 bool is_write = stream_utils::IsWriteAccessMode(in_mode); 267 is_read |= need_more_read_; 268 is_write |= need_more_write_; 269 need_more_read_ = need_more_write_ = false; 270 if (is_read && SSL_pending(ssl_.get()) > 0) { 271 if (out_mode) 272 *out_mode = AccessMode::READ; 273 return true; 274 } 275 in_mode = stream_utils::MakeAccessMode(is_read, is_write); 276 return socket_->WaitForDataBlocking(in_mode, timeout, out_mode, error); 277 } 278 279 void TlsStream::TlsStreamImpl::CancelPendingAsyncOperations() { 280 socket_->CancelPendingAsyncOperations(); 281 weak_ptr_factory_.InvalidateWeakPtrs(); 282 } 283 284 bool TlsStream::TlsStreamImpl::ReportError( 285 ErrorPtr* error, 286 const tracked_objects::Location& location, 287 const std::string& message) { 288 const char* file = nullptr; 289 int line = 0; 290 const char* data = 0; 291 int flags = 0; 292 while (auto errnum = ERR_get_error_line_data(&file, &line, &data, &flags)) { 293 char buf[256]; 294 ERR_error_string_n(errnum, buf, sizeof(buf)); 295 tracked_objects::Location ssl_location{"Unknown", file, line, nullptr}; 296 std::string ssl_message = buf; 297 if (flags & ERR_TXT_STRING) { 298 ssl_message += ": "; 299 ssl_message += data; 300 } 301 Error::AddTo(error, ssl_location, "openssl", std::to_string(errnum), 302 ssl_message); 303 } 304 Error::AddTo(error, location, "tls_stream", "failed", message); 305 return false; 306 } 307 308 int TlsStream::TlsStreamImpl::OnCertVerifyResults(int ok, X509_STORE_CTX* ctx) { 309 // OpenSSL already performs a comprehensive check of the certificate chain 310 // (using X509_verify_cert() function) and calls back with the result of its 311 // verification. 312 // |ok| is set to 1 if the verification passed and 0 if an error was detected. 313 // Here we can perform some additional checks if we need to, or simply log 314 // the issues found. 315 316 // For now, just log an error if it occurred. 317 if (!ok) { 318 LOG(ERROR) << "Server certificate validation failed: " 319 << X509_verify_cert_error_string(X509_STORE_CTX_get_error(ctx)); 320 } 321 return ok; 322 } 323 324 int TlsStream::TlsStreamImpl::OnCertVerifyResultsStatic(int ok, 325 X509_STORE_CTX* ctx) { 326 // Obtain the pointer to the instance of TlsStream::TlsStreamImpl from the 327 // SSL CTX object referenced by |ctx|. 328 SSL* ssl = static_cast<SSL*>(X509_STORE_CTX_get_ex_data( 329 ctx, SSL_get_ex_data_X509_STORE_CTX_idx())); 330 SSL_CTX* ssl_ctx = ssl ? SSL_get_SSL_CTX(ssl) : nullptr; 331 TlsStream::TlsStreamImpl* self = nullptr; 332 if (ssl_ctx) { 333 self = static_cast<TlsStream::TlsStreamImpl*>(SSL_CTX_get_ex_data( 334 ssl_ctx, ssl_ctx_private_data_index)); 335 } 336 return self ? self->OnCertVerifyResults(ok, ctx) : ok; 337 } 338 339 bool TlsStream::TlsStreamImpl::Init(StreamPtr socket, 340 const std::string& host, 341 const base::Closure& success_callback, 342 const Stream::ErrorCallback& error_callback, 343 ErrorPtr* error) { 344 ctx_.reset(SSL_CTX_new(TLSv1_2_client_method())); 345 if (!ctx_) 346 return ReportError(error, FROM_HERE, "Cannot create SSL_CTX"); 347 348 // Top cipher suites supported by both Google GFEs and OpenSSL (in server 349 // preferred order). 350 int res = SSL_CTX_set_cipher_list(ctx_.get(), 351 "ECDHE-ECDSA-AES128-GCM-SHA256:" 352 "ECDHE-ECDSA-AES256-GCM-SHA384:" 353 "ECDHE-RSA-AES128-GCM-SHA256:" 354 "ECDHE-RSA-AES256-GCM-SHA384"); 355 if (res != 1) 356 return ReportError(error, FROM_HERE, "Cannot set the cipher list"); 357 358 res = SSL_CTX_load_verify_locations(ctx_.get(), nullptr, kCACertificatePath); 359 if (res != 1) { 360 return ReportError(error, FROM_HERE, 361 "Failed to specify trusted certificate location"); 362 } 363 364 // Store a pointer to "this" into SSL_CTX instance. 365 SSL_CTX_set_ex_data(ctx_.get(), ssl_ctx_private_data_index, this); 366 367 // Ask OpenSSL to validate the server host from the certificate to match 368 // the expected host name we are given: 369 X509_VERIFY_PARAM* param = SSL_CTX_get0_param(ctx_.get()); 370 X509_VERIFY_PARAM_set1_host(param, host.c_str(), host.size()); 371 372 SSL_CTX_set_verify(ctx_.get(), SSL_VERIFY_PEER, 373 &TlsStreamImpl::OnCertVerifyResultsStatic); 374 375 socket_ = std::move(socket); 376 ssl_.reset(SSL_new(ctx_.get())); 377 378 // Enable TLS progress callback if VLOG level is >=3. 379 if (VLOG_IS_ON(3)) 380 SSL_set_info_callback(ssl_.get(), TlsInfoCallback); 381 382 stream_bio_ = BIO_new_stream(socket_.get()); 383 SSL_set_bio(ssl_.get(), stream_bio_, stream_bio_); 384 SSL_set_connect_state(ssl_.get()); 385 386 // We might have no message loop (e.g. we are in unit tests). 387 if (MessageLoop::ThreadHasCurrent()) { 388 MessageLoop::current()->PostTask( 389 FROM_HERE, 390 base::Bind(&TlsStreamImpl::DoHandshake, 391 weak_ptr_factory_.GetWeakPtr(), 392 success_callback, 393 error_callback)); 394 } else { 395 DoHandshake(success_callback, error_callback); 396 } 397 return true; 398 } 399 400 void TlsStream::TlsStreamImpl::RetryHandshake( 401 const base::Closure& success_callback, 402 const Stream::ErrorCallback& error_callback, 403 Stream::AccessMode /* mode */) { 404 VLOG(1) << "Retrying TLS handshake"; 405 DoHandshake(success_callback, error_callback); 406 } 407 408 void TlsStream::TlsStreamImpl::DoHandshake( 409 const base::Closure& success_callback, 410 const Stream::ErrorCallback& error_callback) { 411 VLOG(1) << "Begin TLS handshake"; 412 int res = SSL_do_handshake(ssl_.get()); 413 if (res == 1) { 414 VLOG(1) << "Handshake successful"; 415 success_callback.Run(); 416 return; 417 } 418 ErrorPtr error; 419 int err = SSL_get_error(ssl_.get(), res); 420 if (err == SSL_ERROR_WANT_READ) { 421 VLOG(1) << "Waiting for read data..."; 422 bool ok = socket_->WaitForData( 423 Stream::AccessMode::READ, 424 base::Bind(&TlsStreamImpl::RetryHandshake, 425 weak_ptr_factory_.GetWeakPtr(), 426 success_callback, error_callback), 427 &error); 428 if (ok) 429 return; 430 } else if (err == SSL_ERROR_WANT_WRITE) { 431 VLOG(1) << "Waiting for write data..."; 432 bool ok = socket_->WaitForData( 433 Stream::AccessMode::WRITE, 434 base::Bind(&TlsStreamImpl::RetryHandshake, 435 weak_ptr_factory_.GetWeakPtr(), 436 success_callback, error_callback), 437 &error); 438 if (ok) 439 return; 440 } else { 441 ReportError(&error, FROM_HERE, "TLS handshake failed."); 442 } 443 error_callback.Run(error.get()); 444 } 445 446 ///////////////////////////////////////////////////////////////////////////// 447 TlsStream::TlsStream(std::unique_ptr<TlsStreamImpl> impl) 448 : impl_{std::move(impl)} {} 449 450 TlsStream::~TlsStream() { 451 if (impl_) { 452 impl_->Close(nullptr); 453 } 454 } 455 456 void TlsStream::Connect(StreamPtr socket, 457 const std::string& host, 458 const base::Callback<void(StreamPtr)>& success_callback, 459 const Stream::ErrorCallback& error_callback) { 460 std::unique_ptr<TlsStreamImpl> impl{new TlsStreamImpl}; 461 std::unique_ptr<TlsStream> stream{new TlsStream{std::move(impl)}}; 462 463 TlsStreamImpl* pimpl = stream->impl_.get(); 464 ErrorPtr error; 465 bool success = pimpl->Init(std::move(socket), host, 466 base::Bind(success_callback, 467 base::Passed(std::move(stream))), 468 error_callback, &error); 469 470 if (!success) 471 error_callback.Run(error.get()); 472 } 473 474 bool TlsStream::IsOpen() const { 475 return impl_ ? true : false; 476 } 477 478 bool TlsStream::SetSizeBlocking(uint64_t /* size */, ErrorPtr* error) { 479 return stream_utils::ErrorOperationNotSupported(FROM_HERE, error); 480 } 481 482 bool TlsStream::Seek(int64_t /* offset */, 483 Whence /* whence */, 484 uint64_t* /* new_position*/, 485 ErrorPtr* error) { 486 return stream_utils::ErrorOperationNotSupported(FROM_HERE, error); 487 } 488 489 bool TlsStream::ReadNonBlocking(void* buffer, 490 size_t size_to_read, 491 size_t* size_read, 492 bool* end_of_stream, 493 ErrorPtr* error) { 494 if (!impl_) 495 return stream_utils::ErrorStreamClosed(FROM_HERE, error); 496 return impl_->ReadNonBlocking(buffer, size_to_read, size_read, end_of_stream, 497 error); 498 } 499 500 bool TlsStream::WriteNonBlocking(const void* buffer, 501 size_t size_to_write, 502 size_t* size_written, 503 ErrorPtr* error) { 504 if (!impl_) 505 return stream_utils::ErrorStreamClosed(FROM_HERE, error); 506 return impl_->WriteNonBlocking(buffer, size_to_write, size_written, error); 507 } 508 509 bool TlsStream::FlushBlocking(ErrorPtr* error) { 510 if (!impl_) 511 return stream_utils::ErrorStreamClosed(FROM_HERE, error); 512 return impl_->Flush(error); 513 } 514 515 bool TlsStream::CloseBlocking(ErrorPtr* error) { 516 if (impl_ && !impl_->Close(error)) 517 return false; 518 impl_.reset(); 519 return true; 520 } 521 522 bool TlsStream::WaitForData(AccessMode mode, 523 const base::Callback<void(AccessMode)>& callback, 524 ErrorPtr* error) { 525 if (!impl_) 526 return stream_utils::ErrorStreamClosed(FROM_HERE, error); 527 return impl_->WaitForData(mode, callback, error); 528 } 529 530 bool TlsStream::WaitForDataBlocking(AccessMode in_mode, 531 base::TimeDelta timeout, 532 AccessMode* out_mode, 533 ErrorPtr* error) { 534 if (!impl_) 535 return stream_utils::ErrorStreamClosed(FROM_HERE, error); 536 return impl_->WaitForDataBlocking(in_mode, timeout, out_mode, error); 537 } 538 539 void TlsStream::CancelPendingAsyncOperations() { 540 if (impl_) 541 impl_->CancelPendingAsyncOperations(); 542 Stream::CancelPendingAsyncOperations(); 543 } 544 545 } // namespace brillo 546