Home | History | Annotate | Download | only in streams
      1 // Copyright 2015 The Chromium OS Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include <brillo/streams/tls_stream.h>
      6 
      7 #include <algorithm>
      8 #include <limits>
      9 #include <string>
     10 #include <vector>
     11 
     12 #include <openssl/err.h>
     13 #include <openssl/ssl.h>
     14 
     15 #include <base/bind.h>
     16 #include <base/memory/weak_ptr.h>
     17 #include <brillo/message_loops/message_loop.h>
     18 #include <brillo/secure_blob.h>
     19 #include <brillo/streams/openssl_stream_bio.h>
     20 #include <brillo/streams/stream_utils.h>
     21 #include <brillo/strings/string_utils.h>
     22 
     23 namespace {
     24 
     25 // SSL info callback which is called by OpenSSL when we enable logging level of
     26 // at least 3. This logs the information about the internal TLS handshake.
     27 void TlsInfoCallback(const SSL* /* ssl */, int where, int ret) {
     28   std::string reason;
     29   std::vector<std::string> info;
     30   if (where & SSL_CB_LOOP)
     31     info.push_back("loop");
     32   if (where & SSL_CB_EXIT)
     33     info.push_back("exit");
     34   if (where & SSL_CB_READ)
     35     info.push_back("read");
     36   if (where & SSL_CB_WRITE)
     37     info.push_back("write");
     38   if (where & SSL_CB_ALERT) {
     39     info.push_back("alert");
     40     reason = ", reason: ";
     41     reason += SSL_alert_type_string_long(ret);
     42     reason += "/";
     43     reason += SSL_alert_desc_string_long(ret);
     44   }
     45   if (where & SSL_CB_HANDSHAKE_START)
     46     info.push_back("handshake_start");
     47   if (where & SSL_CB_HANDSHAKE_DONE)
     48     info.push_back("handshake_done");
     49 
     50   VLOG(3) << "TLS progress info: " << brillo::string_utils::Join(",", info)
     51           << ", with status: " << ret << reason;
     52 }
     53 
     54 // Static variable to store the index of TlsStream private data in SSL context
     55 // used to store custom data for OnCertVerifyResults().
     56 int ssl_ctx_private_data_index = -1;
     57 
     58 // Default trusted certificate store location.
     59 const char kCACertificatePath[] =
     60 #ifdef __ANDROID__
     61     "/system/etc/security/cacerts_google";
     62 #else
     63     "/usr/share/chromeos-ca-certificates";
     64 #endif
     65 
     66 }  // anonymous namespace
     67 
     68 namespace brillo {
     69 
     70 // Helper implementation of TLS stream used to hide most of OpenSSL inner
     71 // workings from the users of brillo::TlsStream.
     72 class TlsStream::TlsStreamImpl {
     73  public:
     74   TlsStreamImpl();
     75   ~TlsStreamImpl();
     76 
     77   bool Init(StreamPtr socket,
     78             const std::string& host,
     79             const base::Closure& success_callback,
     80             const Stream::ErrorCallback& error_callback,
     81             ErrorPtr* error);
     82 
     83   bool ReadNonBlocking(void* buffer,
     84                        size_t size_to_read,
     85                        size_t* size_read,
     86                        bool* end_of_stream,
     87                        ErrorPtr* error);
     88 
     89   bool WriteNonBlocking(const void* buffer,
     90                         size_t size_to_write,
     91                         size_t* size_written,
     92                         ErrorPtr* error);
     93 
     94   bool Flush(ErrorPtr* error);
     95   bool Close(ErrorPtr* error);
     96   bool WaitForData(AccessMode mode,
     97                    const base::Callback<void(AccessMode)>& callback,
     98                    ErrorPtr* error);
     99   bool WaitForDataBlocking(AccessMode in_mode,
    100                            base::TimeDelta timeout,
    101                            AccessMode* out_mode,
    102                            ErrorPtr* error);
    103   void CancelPendingAsyncOperations();
    104 
    105  private:
    106   bool ReportError(ErrorPtr* error,
    107                    const tracked_objects::Location& location,
    108                    const std::string& message);
    109   void DoHandshake(const base::Closure& success_callback,
    110                    const Stream::ErrorCallback& error_callback);
    111   void RetryHandshake(const base::Closure& success_callback,
    112                       const Stream::ErrorCallback& error_callback,
    113                       Stream::AccessMode mode);
    114 
    115   int OnCertVerifyResults(int ok, X509_STORE_CTX* ctx);
    116   static int OnCertVerifyResultsStatic(int ok, X509_STORE_CTX* ctx);
    117 
    118   StreamPtr socket_;
    119   std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)> ctx_{nullptr, SSL_CTX_free};
    120   std::unique_ptr<SSL, decltype(&SSL_free)> ssl_{nullptr, SSL_free};
    121   BIO* stream_bio_{nullptr};
    122   bool need_more_read_{false};
    123   bool need_more_write_{false};
    124 
    125   base::WeakPtrFactory<TlsStreamImpl> weak_ptr_factory_{this};
    126   DISALLOW_COPY_AND_ASSIGN(TlsStreamImpl);
    127 };
    128 
    129 TlsStream::TlsStreamImpl::TlsStreamImpl() {
    130   SSL_load_error_strings();
    131   SSL_library_init();
    132   if (ssl_ctx_private_data_index < 0) {
    133     ssl_ctx_private_data_index =
    134         SSL_CTX_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
    135   }
    136 }
    137 
    138 TlsStream::TlsStreamImpl::~TlsStreamImpl() {
    139   ssl_.reset();
    140   ctx_.reset();
    141 }
    142 
    143 bool TlsStream::TlsStreamImpl::ReadNonBlocking(void* buffer,
    144                                                size_t size_to_read,
    145                                                size_t* size_read,
    146                                                bool* end_of_stream,
    147                                                ErrorPtr* error) {
    148   const size_t max_int = std::numeric_limits<int>::max();
    149   int size_int = static_cast<int>(std::min(size_to_read, max_int));
    150   int ret = SSL_read(ssl_.get(), buffer, size_int);
    151   if (ret > 0) {
    152     *size_read = static_cast<size_t>(ret);
    153     if (end_of_stream)
    154       *end_of_stream = false;
    155     return true;
    156   }
    157 
    158   int err = SSL_get_error(ssl_.get(), ret);
    159   if (err == SSL_ERROR_ZERO_RETURN) {
    160     *size_read = 0;
    161     if (end_of_stream)
    162       *end_of_stream = true;
    163     return true;
    164   }
    165 
    166   if (err == SSL_ERROR_WANT_READ) {
    167     need_more_read_ = true;
    168   } else if (err == SSL_ERROR_WANT_WRITE) {
    169     // Writes might be required for SSL_read() because of possible TLS
    170     // re-negotiations which can happen at any time.
    171     need_more_write_ = true;
    172   } else {
    173     return ReportError(error, FROM_HERE, "Error reading from TLS socket");
    174   }
    175   *size_read = 0;
    176   if (end_of_stream)
    177     *end_of_stream = false;
    178   return true;
    179 }
    180 
    181 bool TlsStream::TlsStreamImpl::WriteNonBlocking(const void* buffer,
    182                                                 size_t size_to_write,
    183                                                 size_t* size_written,
    184                                                 ErrorPtr* error) {
    185   const size_t max_int = std::numeric_limits<int>::max();
    186   int size_int = static_cast<int>(std::min(size_to_write, max_int));
    187   int ret = SSL_write(ssl_.get(), buffer, size_int);
    188   if (ret > 0) {
    189     *size_written = static_cast<size_t>(ret);
    190     return true;
    191   }
    192 
    193   int err = SSL_get_error(ssl_.get(), ret);
    194   if (err == SSL_ERROR_WANT_READ) {
    195     // Reads might be required for SSL_write() because of possible TLS
    196     // re-negotiations which can happen at any time.
    197     need_more_read_ = true;
    198   } else if (err == SSL_ERROR_WANT_WRITE) {
    199     need_more_write_ = true;
    200   } else {
    201     return ReportError(error, FROM_HERE, "Error writing to TLS socket");
    202   }
    203   *size_written = 0;
    204   return true;
    205 }
    206 
    207 bool TlsStream::TlsStreamImpl::Flush(ErrorPtr* error) {
    208   return socket_->FlushBlocking(error);
    209 }
    210 
    211 bool TlsStream::TlsStreamImpl::Close(ErrorPtr* error) {
    212   // 2 seconds should be plenty here.
    213   const base::TimeDelta kTimeout = base::TimeDelta::FromSeconds(2);
    214   // The retry count of 4 below is just arbitrary, to ensure we don't get stuck
    215   // here forever. We should rarely need to repeat SSL_shutdown anyway.
    216   for (int retry_count = 0; retry_count < 4; retry_count++) {
    217     int ret = SSL_shutdown(ssl_.get());
    218     // We really don't care for bi-directional shutdown here.
    219     // Just make sure we only send the "close notify" alert to the remote peer.
    220     if (ret >= 0)
    221       break;
    222 
    223     int err = SSL_get_error(ssl_.get(), ret);
    224     if (err == SSL_ERROR_WANT_READ) {
    225       if (!socket_->WaitForDataBlocking(AccessMode::READ, kTimeout, nullptr,
    226                                         error)) {
    227         break;
    228       }
    229     } else if (err == SSL_ERROR_WANT_WRITE) {
    230       if (!socket_->WaitForDataBlocking(AccessMode::WRITE, kTimeout, nullptr,
    231                                         error)) {
    232         break;
    233       }
    234     } else {
    235       LOG(ERROR) << "SSL_shutdown returned error #" << err;
    236       ReportError(error, FROM_HERE, "Failed to shut down TLS socket");
    237       break;
    238     }
    239   }
    240   return socket_->CloseBlocking(error);
    241 }
    242 
    243 bool TlsStream::TlsStreamImpl::WaitForData(
    244     AccessMode mode,
    245     const base::Callback<void(AccessMode)>& callback,
    246     ErrorPtr* error) {
    247   bool is_read = stream_utils::IsReadAccessMode(mode);
    248   bool is_write = stream_utils::IsWriteAccessMode(mode);
    249   is_read |= need_more_read_;
    250   is_write |= need_more_write_;
    251   need_more_read_ = false;
    252   need_more_write_ = false;
    253   if (is_read && SSL_pending(ssl_.get()) > 0) {
    254     callback.Run(AccessMode::READ);
    255     return true;
    256   }
    257   mode = stream_utils::MakeAccessMode(is_read, is_write);
    258   return socket_->WaitForData(mode, callback, error);
    259 }
    260 
    261 bool TlsStream::TlsStreamImpl::WaitForDataBlocking(AccessMode in_mode,
    262                                                    base::TimeDelta timeout,
    263                                                    AccessMode* out_mode,
    264                                                    ErrorPtr* error) {
    265   bool is_read = stream_utils::IsReadAccessMode(in_mode);
    266   bool is_write = stream_utils::IsWriteAccessMode(in_mode);
    267   is_read |= need_more_read_;
    268   is_write |= need_more_write_;
    269   need_more_read_ = need_more_write_ = false;
    270   if (is_read && SSL_pending(ssl_.get()) > 0) {
    271     if (out_mode)
    272       *out_mode = AccessMode::READ;
    273     return true;
    274   }
    275   in_mode = stream_utils::MakeAccessMode(is_read, is_write);
    276   return socket_->WaitForDataBlocking(in_mode, timeout, out_mode, error);
    277 }
    278 
    279 void TlsStream::TlsStreamImpl::CancelPendingAsyncOperations() {
    280   socket_->CancelPendingAsyncOperations();
    281   weak_ptr_factory_.InvalidateWeakPtrs();
    282 }
    283 
    284 bool TlsStream::TlsStreamImpl::ReportError(
    285     ErrorPtr* error,
    286     const tracked_objects::Location& location,
    287     const std::string& message) {
    288   const char* file = nullptr;
    289   int line = 0;
    290   const char* data = 0;
    291   int flags = 0;
    292   while (auto errnum = ERR_get_error_line_data(&file, &line, &data, &flags)) {
    293     char buf[256];
    294     ERR_error_string_n(errnum, buf, sizeof(buf));
    295     tracked_objects::Location ssl_location{"Unknown", file, line, nullptr};
    296     std::string ssl_message = buf;
    297     if (flags & ERR_TXT_STRING) {
    298       ssl_message += ": ";
    299       ssl_message += data;
    300     }
    301     Error::AddTo(error, ssl_location, "openssl", std::to_string(errnum),
    302                  ssl_message);
    303   }
    304   Error::AddTo(error, location, "tls_stream", "failed", message);
    305   return false;
    306 }
    307 
    308 int TlsStream::TlsStreamImpl::OnCertVerifyResults(int ok, X509_STORE_CTX* ctx) {
    309   // OpenSSL already performs a comprehensive check of the certificate chain
    310   // (using X509_verify_cert() function) and calls back with the result of its
    311   // verification.
    312   // |ok| is set to 1 if the verification passed and 0 if an error was detected.
    313   // Here we can perform some additional checks if we need to, or simply log
    314   // the issues found.
    315 
    316   // For now, just log an error if it occurred.
    317   if (!ok) {
    318     LOG(ERROR) << "Server certificate validation failed: "
    319                << X509_verify_cert_error_string(X509_STORE_CTX_get_error(ctx));
    320   }
    321   return ok;
    322 }
    323 
    324 int TlsStream::TlsStreamImpl::OnCertVerifyResultsStatic(int ok,
    325                                                         X509_STORE_CTX* ctx) {
    326   // Obtain the pointer to the instance of TlsStream::TlsStreamImpl from the
    327   // SSL CTX object referenced by |ctx|.
    328   SSL* ssl = static_cast<SSL*>(X509_STORE_CTX_get_ex_data(
    329       ctx, SSL_get_ex_data_X509_STORE_CTX_idx()));
    330   SSL_CTX* ssl_ctx = ssl ? SSL_get_SSL_CTX(ssl) : nullptr;
    331   TlsStream::TlsStreamImpl* self = nullptr;
    332   if (ssl_ctx) {
    333     self = static_cast<TlsStream::TlsStreamImpl*>(SSL_CTX_get_ex_data(
    334         ssl_ctx, ssl_ctx_private_data_index));
    335   }
    336   return self ? self->OnCertVerifyResults(ok, ctx) : ok;
    337 }
    338 
    339 bool TlsStream::TlsStreamImpl::Init(StreamPtr socket,
    340                                     const std::string& host,
    341                                     const base::Closure& success_callback,
    342                                     const Stream::ErrorCallback& error_callback,
    343                                     ErrorPtr* error) {
    344   ctx_.reset(SSL_CTX_new(TLSv1_2_client_method()));
    345   if (!ctx_)
    346     return ReportError(error, FROM_HERE, "Cannot create SSL_CTX");
    347 
    348   // Top cipher suites supported by both Google GFEs and OpenSSL (in server
    349   // preferred order).
    350   int res = SSL_CTX_set_cipher_list(ctx_.get(),
    351                                     "ECDHE-ECDSA-AES128-GCM-SHA256:"
    352                                     "ECDHE-ECDSA-AES256-GCM-SHA384:"
    353                                     "ECDHE-RSA-AES128-GCM-SHA256:"
    354                                     "ECDHE-RSA-AES256-GCM-SHA384");
    355   if (res != 1)
    356     return ReportError(error, FROM_HERE, "Cannot set the cipher list");
    357 
    358   res = SSL_CTX_load_verify_locations(ctx_.get(), nullptr, kCACertificatePath);
    359   if (res != 1) {
    360     return ReportError(error, FROM_HERE,
    361                        "Failed to specify trusted certificate location");
    362   }
    363 
    364   // Store a pointer to "this" into SSL_CTX instance.
    365   SSL_CTX_set_ex_data(ctx_.get(), ssl_ctx_private_data_index, this);
    366 
    367   // Ask OpenSSL to validate the server host from the certificate to match
    368   // the expected host name we are given:
    369   X509_VERIFY_PARAM* param = SSL_CTX_get0_param(ctx_.get());
    370   X509_VERIFY_PARAM_set1_host(param, host.c_str(), host.size());
    371 
    372   SSL_CTX_set_verify(ctx_.get(), SSL_VERIFY_PEER,
    373                      &TlsStreamImpl::OnCertVerifyResultsStatic);
    374 
    375   socket_ = std::move(socket);
    376   ssl_.reset(SSL_new(ctx_.get()));
    377 
    378   // Enable TLS progress callback if VLOG level is >=3.
    379   if (VLOG_IS_ON(3))
    380     SSL_set_info_callback(ssl_.get(), TlsInfoCallback);
    381 
    382   stream_bio_ = BIO_new_stream(socket_.get());
    383   SSL_set_bio(ssl_.get(), stream_bio_, stream_bio_);
    384   SSL_set_connect_state(ssl_.get());
    385 
    386   // We might have no message loop (e.g. we are in unit tests).
    387   if (MessageLoop::ThreadHasCurrent()) {
    388     MessageLoop::current()->PostTask(
    389         FROM_HERE,
    390         base::Bind(&TlsStreamImpl::DoHandshake,
    391                    weak_ptr_factory_.GetWeakPtr(),
    392                    success_callback,
    393                    error_callback));
    394   } else {
    395     DoHandshake(success_callback, error_callback);
    396   }
    397   return true;
    398 }
    399 
    400 void TlsStream::TlsStreamImpl::RetryHandshake(
    401     const base::Closure& success_callback,
    402     const Stream::ErrorCallback& error_callback,
    403     Stream::AccessMode /* mode */) {
    404   VLOG(1) << "Retrying TLS handshake";
    405   DoHandshake(success_callback, error_callback);
    406 }
    407 
    408 void TlsStream::TlsStreamImpl::DoHandshake(
    409     const base::Closure& success_callback,
    410     const Stream::ErrorCallback& error_callback) {
    411   VLOG(1) << "Begin TLS handshake";
    412   int res = SSL_do_handshake(ssl_.get());
    413   if (res == 1) {
    414     VLOG(1) << "Handshake successful";
    415     success_callback.Run();
    416     return;
    417   }
    418   ErrorPtr error;
    419   int err = SSL_get_error(ssl_.get(), res);
    420   if (err == SSL_ERROR_WANT_READ) {
    421     VLOG(1) << "Waiting for read data...";
    422     bool ok = socket_->WaitForData(
    423         Stream::AccessMode::READ,
    424         base::Bind(&TlsStreamImpl::RetryHandshake,
    425                    weak_ptr_factory_.GetWeakPtr(),
    426                    success_callback, error_callback),
    427         &error);
    428     if (ok)
    429       return;
    430   } else if (err == SSL_ERROR_WANT_WRITE) {
    431     VLOG(1) << "Waiting for write data...";
    432     bool ok = socket_->WaitForData(
    433         Stream::AccessMode::WRITE,
    434         base::Bind(&TlsStreamImpl::RetryHandshake,
    435                    weak_ptr_factory_.GetWeakPtr(),
    436                    success_callback, error_callback),
    437         &error);
    438     if (ok)
    439       return;
    440   } else {
    441     ReportError(&error, FROM_HERE, "TLS handshake failed.");
    442   }
    443   error_callback.Run(error.get());
    444 }
    445 
    446 /////////////////////////////////////////////////////////////////////////////
    447 TlsStream::TlsStream(std::unique_ptr<TlsStreamImpl> impl)
    448     : impl_{std::move(impl)} {}
    449 
    450 TlsStream::~TlsStream() {
    451   if (impl_) {
    452     impl_->Close(nullptr);
    453   }
    454 }
    455 
    456 void TlsStream::Connect(StreamPtr socket,
    457                         const std::string& host,
    458                         const base::Callback<void(StreamPtr)>& success_callback,
    459                         const Stream::ErrorCallback& error_callback) {
    460   std::unique_ptr<TlsStreamImpl> impl{new TlsStreamImpl};
    461   std::unique_ptr<TlsStream> stream{new TlsStream{std::move(impl)}};
    462 
    463   TlsStreamImpl* pimpl = stream->impl_.get();
    464   ErrorPtr error;
    465   bool success = pimpl->Init(std::move(socket), host,
    466                              base::Bind(success_callback,
    467                                         base::Passed(std::move(stream))),
    468                              error_callback, &error);
    469 
    470   if (!success)
    471     error_callback.Run(error.get());
    472 }
    473 
    474 bool TlsStream::IsOpen() const {
    475   return impl_ ? true : false;
    476 }
    477 
    478 bool TlsStream::SetSizeBlocking(uint64_t /* size */, ErrorPtr* error) {
    479   return stream_utils::ErrorOperationNotSupported(FROM_HERE, error);
    480 }
    481 
    482 bool TlsStream::Seek(int64_t /* offset */,
    483                      Whence /* whence */,
    484                      uint64_t* /* new_position*/,
    485                      ErrorPtr* error) {
    486   return stream_utils::ErrorOperationNotSupported(FROM_HERE, error);
    487 }
    488 
    489 bool TlsStream::ReadNonBlocking(void* buffer,
    490                                 size_t size_to_read,
    491                                 size_t* size_read,
    492                                 bool* end_of_stream,
    493                                 ErrorPtr* error) {
    494   if (!impl_)
    495     return stream_utils::ErrorStreamClosed(FROM_HERE, error);
    496   return impl_->ReadNonBlocking(buffer, size_to_read, size_read, end_of_stream,
    497                                 error);
    498 }
    499 
    500 bool TlsStream::WriteNonBlocking(const void* buffer,
    501                                  size_t size_to_write,
    502                                  size_t* size_written,
    503                                  ErrorPtr* error) {
    504   if (!impl_)
    505     return stream_utils::ErrorStreamClosed(FROM_HERE, error);
    506   return impl_->WriteNonBlocking(buffer, size_to_write, size_written, error);
    507 }
    508 
    509 bool TlsStream::FlushBlocking(ErrorPtr* error) {
    510   if (!impl_)
    511     return stream_utils::ErrorStreamClosed(FROM_HERE, error);
    512   return impl_->Flush(error);
    513 }
    514 
    515 bool TlsStream::CloseBlocking(ErrorPtr* error) {
    516   if (impl_ && !impl_->Close(error))
    517     return false;
    518   impl_.reset();
    519   return true;
    520 }
    521 
    522 bool TlsStream::WaitForData(AccessMode mode,
    523                             const base::Callback<void(AccessMode)>& callback,
    524                             ErrorPtr* error) {
    525   if (!impl_)
    526     return stream_utils::ErrorStreamClosed(FROM_HERE, error);
    527   return impl_->WaitForData(mode, callback, error);
    528 }
    529 
    530 bool TlsStream::WaitForDataBlocking(AccessMode in_mode,
    531                                     base::TimeDelta timeout,
    532                                     AccessMode* out_mode,
    533                                     ErrorPtr* error) {
    534   if (!impl_)
    535     return stream_utils::ErrorStreamClosed(FROM_HERE, error);
    536   return impl_->WaitForDataBlocking(in_mode, timeout, out_mode, error);
    537 }
    538 
    539 void TlsStream::CancelPendingAsyncOperations() {
    540   if (impl_)
    541     impl_->CancelPendingAsyncOperations();
    542   Stream::CancelPendingAsyncOperations();
    543 }
    544 
    545 }  // namespace brillo
    546