Home | History | Annotate | Download | only in provider
      1 // Copyright 2015 The Weave Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
#include "examples/provider/ssl_stream.h"

#include <openssl/err.h>

#include <algorithm>
#include <limits>
#include <string>

#include <base/bind.h>
#include <base/bind_helpers.h>
#include <weave/provider/task_runner.h>
     12 
     13 namespace weave {
     14 namespace examples {
     15 
     16 namespace {
     17 
     18 void AddSslError(ErrorPtr* error,
     19                  const tracked_objects::Location& location,
     20                  const std::string& error_code,
     21                  unsigned long ssl_error_code) {
     22   ERR_load_BIO_strings();
     23   SSL_load_error_strings();
     24   Error::AddToPrintf(error, location, error_code, "%s: %s",
     25                      ERR_lib_error_string(ssl_error_code),
     26                      ERR_reason_error_string(ssl_error_code));
     27 }
     28 
     29 void RetryAsyncTask(provider::TaskRunner* task_runner,
     30                     const tracked_objects::Location& location,
     31                     const base::Closure& task) {
     32   task_runner->PostDelayedTask(FROM_HERE, task,
     33                                base::TimeDelta::FromMilliseconds(100));
     34 }
     35 
     36 }  // namespace
     37 
     38 void SSLStream::SslDeleter::operator()(BIO* bio) const {
     39   BIO_free(bio);
     40 }
     41 
     42 void SSLStream::SslDeleter::operator()(SSL* ssl) const {
     43   SSL_free(ssl);
     44 }
     45 
     46 void SSLStream::SslDeleter::operator()(SSL_CTX* ctx) const {
     47   SSL_CTX_free(ctx);
     48 }
     49 
     50 SSLStream::SSLStream(provider::TaskRunner* task_runner,
     51                      std::unique_ptr<BIO, SslDeleter> stream_bio)
     52     : task_runner_{task_runner} {
     53   ctx_.reset(SSL_CTX_new(TLSv1_2_client_method()));
     54   CHECK(ctx_);
     55   ssl_.reset(SSL_new(ctx_.get()));
     56 
     57   SSL_set_bio(ssl_.get(), stream_bio.get(), stream_bio.get());
     58   stream_bio.release();  // Owned by ssl now.
     59   SSL_set_connect_state(ssl_.get());
     60 }
     61 
     62 SSLStream::~SSLStream() {
     63   CancelPendingOperations();
     64 }
     65 
// Runs |task| synchronously.
// Read()/Write() post user callbacks as
// Bind(&SSLStream::RunTask, weak_ptr_factory_.GetWeakPtr(), callback), so the
// callback is reached only through this stream's WeakPtr. That WeakPtr stops
// being valid after CancelPendingOperations() or destruction.
void SSLStream::RunTask(const base::Closure& task) {
  task.Run();
}
     69 
     70 void SSLStream::Read(void* buffer,
     71                      size_t size_to_read,
     72                      const ReadCallback& callback) {
     73   int res = SSL_read(ssl_.get(), buffer, size_to_read);
     74   if (res > 0) {
     75     task_runner_->PostDelayedTask(
     76         FROM_HERE,
     77         base::Bind(&SSLStream::RunTask, weak_ptr_factory_.GetWeakPtr(),
     78                    base::Bind(callback, res, nullptr)),
     79         {});
     80     return;
     81   }
     82 
     83   int err = SSL_get_error(ssl_.get(), res);
     84 
     85   if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
     86     return RetryAsyncTask(
     87         task_runner_, FROM_HERE,
     88         base::Bind(&SSLStream::Read, weak_ptr_factory_.GetWeakPtr(), buffer,
     89                    size_to_read, callback));
     90   }
     91 
     92   ErrorPtr weave_error;
     93   AddSslError(&weave_error, FROM_HERE, "read_failed", err);
     94   return task_runner_->PostDelayedTask(
     95       FROM_HERE,
     96       base::Bind(&SSLStream::RunTask, weak_ptr_factory_.GetWeakPtr(),
     97                  base::Bind(callback, 0, base::Passed(&weave_error))),
     98       {});
     99 }
    100 
    101 void SSLStream::Write(const void* buffer,
    102                       size_t size_to_write,
    103                       const WriteCallback& callback) {
    104   int res = SSL_write(ssl_.get(), buffer, size_to_write);
    105   if (res > 0) {
    106     buffer = static_cast<const char*>(buffer) + res;
    107     size_to_write -= res;
    108     if (size_to_write == 0) {
    109       return task_runner_->PostDelayedTask(
    110           FROM_HERE,
    111           base::Bind(&SSLStream::RunTask, weak_ptr_factory_.GetWeakPtr(),
    112                      base::Bind(callback, nullptr)),
    113           {});
    114     }
    115 
    116     return RetryAsyncTask(
    117         task_runner_, FROM_HERE,
    118         base::Bind(&SSLStream::Write, weak_ptr_factory_.GetWeakPtr(), buffer,
    119                    size_to_write, callback));
    120   }
    121 
    122   int err = SSL_get_error(ssl_.get(), res);
    123 
    124   if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
    125     return RetryAsyncTask(
    126         task_runner_, FROM_HERE,
    127         base::Bind(&SSLStream::Write, weak_ptr_factory_.GetWeakPtr(), buffer,
    128                    size_to_write, callback));
    129   }
    130 
    131   ErrorPtr weave_error;
    132   AddSslError(&weave_error, FROM_HERE, "write_failed", err);
    133   task_runner_->PostDelayedTask(
    134       FROM_HERE, base::Bind(&SSLStream::RunTask, weak_ptr_factory_.GetWeakPtr(),
    135                             base::Bind(callback, base::Passed(&weave_error))),
    136       {});
    137 }
    138 
// Invalidates every WeakPtr previously obtained from weak_ptr_factory_.
// Read/Write retries and RunTask callback trampolines posted by this stream
// are bound to those WeakPtrs, so none of them reach this object afterwards.
void SSLStream::CancelPendingOperations() {
  weak_ptr_factory_.InvalidateWeakPtrs();
}
    142 
    143 void SSLStream::Connect(
    144     provider::TaskRunner* task_runner,
    145     const std::string& host,
    146     uint16_t port,
    147     const provider::Network::OpenSslSocketCallback& callback) {
    148   SSL_library_init();
    149 
    150   char end_point[255];
    151   snprintf(end_point, sizeof(end_point), "%s:%u", host.c_str(), port);
    152 
    153   std::unique_ptr<BIO, SslDeleter> stream_bio(BIO_new_connect(end_point));
    154   CHECK(stream_bio);
    155   BIO_set_nbio(stream_bio.get(), 1);
    156 
    157   std::unique_ptr<SSLStream> stream{
    158       new SSLStream{task_runner, std::move(stream_bio)}};
    159   ConnectBio(std::move(stream), callback);
    160 }
    161 
    162 void SSLStream::ConnectBio(
    163     std::unique_ptr<SSLStream> stream,
    164     const provider::Network::OpenSslSocketCallback& callback) {
    165   BIO* bio = SSL_get_rbio(stream->ssl_.get());
    166   if (BIO_do_connect(bio) == 1)
    167     return DoHandshake(std::move(stream), callback);
    168 
    169   auto task_runner = stream->task_runner_;
    170   if (BIO_should_retry(bio)) {
    171     return RetryAsyncTask(
    172         task_runner, FROM_HERE,
    173         base::Bind(&SSLStream::ConnectBio, base::Passed(&stream), callback));
    174   }
    175 
    176   ErrorPtr error;
    177   AddSslError(&error, FROM_HERE, "connect_failed", ERR_get_error());
    178   task_runner->PostDelayedTask(
    179       FROM_HERE, base::Bind(callback, nullptr, base::Passed(&error)), {});
    180 }
    181 
    182 void SSLStream::DoHandshake(
    183     std::unique_ptr<SSLStream> stream,
    184     const provider::Network::OpenSslSocketCallback& callback) {
    185   int res = SSL_do_handshake(stream->ssl_.get());
    186   auto task_runner = stream->task_runner_;
    187   if (res == 1) {
    188     return task_runner->PostDelayedTask(
    189         FROM_HERE, base::Bind(callback, base::Passed(&stream), nullptr), {});
    190   }
    191 
    192   res = SSL_get_error(stream->ssl_.get(), res);
    193 
    194   if (res == SSL_ERROR_WANT_READ || res == SSL_ERROR_WANT_WRITE) {
    195     return RetryAsyncTask(
    196         task_runner, FROM_HERE,
    197         base::Bind(&SSLStream::DoHandshake, base::Passed(&stream), callback));
    198   }
    199 
    200   ErrorPtr error;
    201   AddSslError(&error, FROM_HERE, "handshake_failed", res);
    202   task_runner->PostDelayedTask(
    203       FROM_HERE, base::Bind(callback, nullptr, base::Passed(&error)), {});
    204 }
    205 
    206 }  // namespace examples
    207 }  // namespace weave
    208