1 // Copyright 2015 The Weave Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "examples/provider/ssl_stream.h" 6 7 #include <openssl/err.h> 8 9 #include <base/bind.h> 10 #include <base/bind_helpers.h> 11 #include <weave/provider/task_runner.h> 12 13 namespace weave { 14 namespace examples { 15 16 namespace { 17 18 void AddSslError(ErrorPtr* error, 19 const tracked_objects::Location& location, 20 const std::string& error_code, 21 unsigned long ssl_error_code) { 22 ERR_load_BIO_strings(); 23 SSL_load_error_strings(); 24 Error::AddToPrintf(error, location, error_code, "%s: %s", 25 ERR_lib_error_string(ssl_error_code), 26 ERR_reason_error_string(ssl_error_code)); 27 } 28 29 void RetryAsyncTask(provider::TaskRunner* task_runner, 30 const tracked_objects::Location& location, 31 const base::Closure& task) { 32 task_runner->PostDelayedTask(FROM_HERE, task, 33 base::TimeDelta::FromMilliseconds(100)); 34 } 35 36 } // namespace 37 38 void SSLStream::SslDeleter::operator()(BIO* bio) const { 39 BIO_free(bio); 40 } 41 42 void SSLStream::SslDeleter::operator()(SSL* ssl) const { 43 SSL_free(ssl); 44 } 45 46 void SSLStream::SslDeleter::operator()(SSL_CTX* ctx) const { 47 SSL_CTX_free(ctx); 48 } 49 50 SSLStream::SSLStream(provider::TaskRunner* task_runner, 51 std::unique_ptr<BIO, SslDeleter> stream_bio) 52 : task_runner_{task_runner} { 53 ctx_.reset(SSL_CTX_new(TLSv1_2_client_method())); 54 CHECK(ctx_); 55 ssl_.reset(SSL_new(ctx_.get())); 56 57 SSL_set_bio(ssl_.get(), stream_bio.get(), stream_bio.get()); 58 stream_bio.release(); // Owned by ssl now. 59 SSL_set_connect_state(ssl_.get()); 60 } 61 62 SSLStream::~SSLStream() { 63 CancelPendingOperations(); 64 } 65 66 void SSLStream::RunTask(const base::Closure& task) { 67 task.Run(); 68 } 69 70 void SSLStream::Read(void* buffer, 71 size_t size_to_read, 72 const ReadCallback& callback) { 73 int res = SSL_read(ssl_.get(), buffer, size_to_read); 74 if (res > 0) { 75 task_runner_->PostDelayedTask( 76 FROM_HERE, 77 base::Bind(&SSLStream::RunTask, weak_ptr_factory_.GetWeakPtr(), 78 base::Bind(callback, res, nullptr)), 79 {}); 80 return; 81 } 82 83 int err = SSL_get_error(ssl_.get(), res); 84 85 if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) { 86 return RetryAsyncTask( 87 task_runner_, FROM_HERE, 88 base::Bind(&SSLStream::Read, weak_ptr_factory_.GetWeakPtr(), buffer, 89 size_to_read, callback)); 90 } 91 92 ErrorPtr weave_error; 93 AddSslError(&weave_error, FROM_HERE, "read_failed", err); 94 return task_runner_->PostDelayedTask( 95 FROM_HERE, 96 base::Bind(&SSLStream::RunTask, weak_ptr_factory_.GetWeakPtr(), 97 base::Bind(callback, 0, base::Passed(&weave_error))), 98 {}); 99 } 100 101 void SSLStream::Write(const void* buffer, 102 size_t size_to_write, 103 const WriteCallback& callback) { 104 int res = SSL_write(ssl_.get(), buffer, size_to_write); 105 if (res > 0) { 106 buffer = static_cast<const char*>(buffer) + res; 107 size_to_write -= res; 108 if (size_to_write == 0) { 109 return task_runner_->PostDelayedTask( 110 FROM_HERE, 111 base::Bind(&SSLStream::RunTask, weak_ptr_factory_.GetWeakPtr(), 112 base::Bind(callback, nullptr)), 113 {}); 114 } 115 116 return RetryAsyncTask( 117 task_runner_, FROM_HERE, 118 base::Bind(&SSLStream::Write, weak_ptr_factory_.GetWeakPtr(), buffer, 119 size_to_write, callback)); 120 } 121 122 int err = SSL_get_error(ssl_.get(), res); 123 124 if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) { 125 return RetryAsyncTask( 126 task_runner_, FROM_HERE, 127 base::Bind(&SSLStream::Write, weak_ptr_factory_.GetWeakPtr(), buffer, 128 size_to_write, callback)); 129 } 130 131 ErrorPtr weave_error; 132 AddSslError(&weave_error, FROM_HERE, "write_failed", err); 133 task_runner_->PostDelayedTask( 134 FROM_HERE, base::Bind(&SSLStream::RunTask, weak_ptr_factory_.GetWeakPtr(), 135 base::Bind(callback, base::Passed(&weave_error))), 136 {}); 137 } 138 139 void SSLStream::CancelPendingOperations() { 140 weak_ptr_factory_.InvalidateWeakPtrs(); 141 } 142 143 void SSLStream::Connect( 144 provider::TaskRunner* task_runner, 145 const std::string& host, 146 uint16_t port, 147 const provider::Network::OpenSslSocketCallback& callback) { 148 SSL_library_init(); 149 150 char end_point[255]; 151 snprintf(end_point, sizeof(end_point), "%s:%u", host.c_str(), port); 152 153 std::unique_ptr<BIO, SslDeleter> stream_bio(BIO_new_connect(end_point)); 154 CHECK(stream_bio); 155 BIO_set_nbio(stream_bio.get(), 1); 156 157 std::unique_ptr<SSLStream> stream{ 158 new SSLStream{task_runner, std::move(stream_bio)}}; 159 ConnectBio(std::move(stream), callback); 160 } 161 162 void SSLStream::ConnectBio( 163 std::unique_ptr<SSLStream> stream, 164 const provider::Network::OpenSslSocketCallback& callback) { 165 BIO* bio = SSL_get_rbio(stream->ssl_.get()); 166 if (BIO_do_connect(bio) == 1) 167 return DoHandshake(std::move(stream), callback); 168 169 auto task_runner = stream->task_runner_; 170 if (BIO_should_retry(bio)) { 171 return RetryAsyncTask( 172 task_runner, FROM_HERE, 173 base::Bind(&SSLStream::ConnectBio, base::Passed(&stream), callback)); 174 } 175 176 ErrorPtr error; 177 AddSslError(&error, FROM_HERE, "connect_failed", ERR_get_error()); 178 task_runner->PostDelayedTask( 179 FROM_HERE, base::Bind(callback, nullptr, base::Passed(&error)), {}); 180 } 181 182 void SSLStream::DoHandshake( 183 std::unique_ptr<SSLStream> stream, 184 const provider::Network::OpenSslSocketCallback& callback) { 185 int res = SSL_do_handshake(stream->ssl_.get()); 186 auto task_runner = stream->task_runner_; 187 if (res == 1) { 188 return task_runner->PostDelayedTask( 189 FROM_HERE, base::Bind(callback, base::Passed(&stream), nullptr), {}); 190 } 191 192 res = SSL_get_error(stream->ssl_.get(), res); 193 194 if (res == SSL_ERROR_WANT_READ || res == SSL_ERROR_WANT_WRITE) { 195 return RetryAsyncTask( 196 task_runner, FROM_HERE, 197 base::Bind(&SSLStream::DoHandshake, base::Passed(&stream), callback)); 198 } 199 200 ErrorPtr error; 201 AddSslError(&error, FROM_HERE, "handshake_failed", res); 202 task_runner->PostDelayedTask( 203 FROM_HERE, base::Bind(callback, nullptr, base::Passed(&error)), {}); 204 } 205 206 } // namespace examples 207 } // namespace weave 208