Home | History | Annotate | Download | only in http
      1 // Copyright 2014 The Chromium OS Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include <brillo/http/http_transport_curl.h>
      6 
      7 #include <limits>
      8 
      9 #include <base/bind.h>
     10 #include <base/logging.h>
     11 #include <base/message_loop/message_loop.h>
     12 #include <brillo/http/http_connection_curl.h>
     13 #include <brillo/http/http_request.h>
     14 #include <brillo/strings/string_utils.h>
     15 
     namespace {

     // Directory of trusted CA certificates passed to CURL as CURLOPT_CAPATH
     // (see Transport::CreateConnection). Android builds point at
     // /system/etc/security/cacerts_google; every other build uses
     // /usr/share/brillo-ca-certificates.
     constexpr char kCACertificatePath[] =
     #ifdef __ANDROID__
         "/system/etc/security/cacerts_google";
     #else
         "/usr/share/brillo-ca-certificates";
     #endif

     }  // namespace
     26 
     27 namespace brillo {
     28 namespace http {
     29 namespace curl {
     30 
     31 // This is a class that stores connection data on particular CURL socket
     32 // and provides file descriptor watcher to monitor read and/or write operations
     33 // on the socket's file descriptor.
     34 class Transport::SocketPollData : public base::MessageLoopForIO::Watcher {
     35  public:
     36   SocketPollData(const std::shared_ptr<CurlInterface>& curl_interface,
     37                  CURLM* curl_multi_handle,
     38                  Transport* transport,
     39                  curl_socket_t socket_fd)
     40       : curl_interface_(curl_interface),
     41         curl_multi_handle_(curl_multi_handle),
     42         transport_(transport),
     43         socket_fd_(socket_fd) {}
     44 
     45   // Returns the pointer for the socket-specific file descriptor watcher.
     46   base::MessageLoopForIO::FileDescriptorWatcher* GetWatcher() {
     47     return &file_descriptor_watcher_;
     48   }
     49 
     50  private:
     51   // Overrides from base::MessageLoopForIO::Watcher.
     52   void OnFileCanReadWithoutBlocking(int fd) override {
     53     OnSocketReady(fd, CURL_CSELECT_IN);
     54   }
     55   void OnFileCanWriteWithoutBlocking(int fd) override {
     56     OnSocketReady(fd, CURL_CSELECT_OUT);
     57   }
     58 
     59   // Data on the socket is available to be read from or written to.
     60   // Notify CURL of the action it needs to take on the socket file descriptor.
     61   void OnSocketReady(int fd, int action) {
     62     CHECK_EQ(socket_fd_, fd) << "Unexpected socket file descriptor";
     63     int still_running_count = 0;
     64     CURLMcode code = curl_interface_->MultiSocketAction(
     65         curl_multi_handle_, socket_fd_, action, &still_running_count);
     66     CHECK_NE(CURLM_CALL_MULTI_PERFORM, code)
     67         << "CURL should no longer return CURLM_CALL_MULTI_PERFORM here";
     68 
     69     if (code == CURLM_OK)
     70       transport_->ProcessAsyncCurlMessages();
     71   }
     72 
     73   // The CURL interface to use.
     74   std::shared_ptr<CurlInterface> curl_interface_;
     75   // CURL multi-handle associated with the transport.
     76   CURLM* curl_multi_handle_;
     77   // Transport object itself.
     78   Transport* transport_;
     79   // The socket file descriptor for the connection.
     80   curl_socket_t socket_fd_;
     81   // File descriptor watcher to notify us of asynchronous I/O on the FD.
     82   base::MessageLoopForIO::FileDescriptorWatcher file_descriptor_watcher_;
     83 
     84   DISALLOW_COPY_AND_ASSIGN(SocketPollData);
     85 };
     86 
     87 // The request data associated with an asynchronous operation on a particular
     88 // connection.
     89 struct Transport::AsyncRequestData {
     90   // Success/error callbacks to be invoked at the end of the request.
     91   SuccessCallback success_callback;
     92   ErrorCallback error_callback;
     93   // We store a connection here to make sure the object is alive for
     94   // as long as asynchronous operation is running.
     95   std::shared_ptr<Connection> connection;
     96   // The ID of this request.
     97   RequestID request_id;
     98 };
     99 
    100 Transport::Transport(const std::shared_ptr<CurlInterface>& curl_interface)
    101     : curl_interface_{curl_interface} {
    102   VLOG(2) << "curl::Transport created";
    103 }
    104 
    105 Transport::Transport(const std::shared_ptr<CurlInterface>& curl_interface,
    106                      const std::string& proxy)
    107     : curl_interface_{curl_interface}, proxy_{proxy} {
    108   VLOG(2) << "curl::Transport created with proxy " << proxy;
    109 }
    110 
    111 Transport::~Transport() {
    112   ShutDownAsyncCurl();
    113   VLOG(2) << "curl::Transport destroyed";
    114 }
    115 
    116 std::shared_ptr<http::Connection> Transport::CreateConnection(
    117     const std::string& url,
    118     const std::string& method,
    119     const HeaderList& headers,
    120     const std::string& user_agent,
    121     const std::string& referer,
    122     brillo::ErrorPtr* error) {
    123   std::shared_ptr<http::Connection> connection;
    124   CURL* curl_handle = curl_interface_->EasyInit();
    125   if (!curl_handle) {
    126     LOG(ERROR) << "Failed to initialize CURL";
    127     brillo::Error::AddTo(error, FROM_HERE, http::kErrorDomain,
    128                          "curl_init_failed", "Failed to initialize CURL");
    129     return connection;
    130   }
    131 
    132   LOG(INFO) << "Sending a " << method << " request to " << url;
    133   CURLcode code = curl_interface_->EasySetOptStr(curl_handle, CURLOPT_URL, url);
    134 
    135   if (code == CURLE_OK) {
    136     code = curl_interface_->EasySetOptStr(curl_handle, CURLOPT_CAPATH,
    137                                           kCACertificatePath);
    138   }
    139   if (code == CURLE_OK) {
    140     code =
    141         curl_interface_->EasySetOptInt(curl_handle, CURLOPT_SSL_VERIFYPEER, 1);
    142   }
    143   if (code == CURLE_OK) {
    144     code =
    145         curl_interface_->EasySetOptInt(curl_handle, CURLOPT_SSL_VERIFYHOST, 2);
    146   }
    147   if (code == CURLE_OK && !user_agent.empty()) {
    148     code = curl_interface_->EasySetOptStr(
    149         curl_handle, CURLOPT_USERAGENT, user_agent);
    150   }
    151   if (code == CURLE_OK && !referer.empty()) {
    152     code =
    153         curl_interface_->EasySetOptStr(curl_handle, CURLOPT_REFERER, referer);
    154   }
    155   if (code == CURLE_OK && !proxy_.empty()) {
    156     code = curl_interface_->EasySetOptStr(curl_handle, CURLOPT_PROXY, proxy_);
    157   }
    158   if (code == CURLE_OK) {
    159     int64_t timeout_ms = connection_timeout_.InMillisecondsRoundedUp();
    160 
    161     if (timeout_ms > 0 && timeout_ms <= std::numeric_limits<int>::max()) {
    162       code = curl_interface_->EasySetOptInt(
    163           curl_handle, CURLOPT_TIMEOUT_MS,
    164           static_cast<int>(timeout_ms));
    165     }
    166   }
    167 
    168   // Setup HTTP request method and optional request body.
    169   if (code == CURLE_OK) {
    170     if (method == request_type::kGet) {
    171       code = curl_interface_->EasySetOptInt(curl_handle, CURLOPT_HTTPGET, 1);
    172     } else if (method == request_type::kHead) {
    173       code = curl_interface_->EasySetOptInt(curl_handle, CURLOPT_NOBODY, 1);
    174     } else if (method == request_type::kPut) {
    175       code = curl_interface_->EasySetOptInt(curl_handle, CURLOPT_UPLOAD, 1);
    176     } else {
    177       // POST and custom request methods
    178       code = curl_interface_->EasySetOptInt(curl_handle, CURLOPT_POST, 1);
    179       if (code == CURLE_OK) {
    180         code = curl_interface_->EasySetOptPtr(
    181             curl_handle, CURLOPT_POSTFIELDS, nullptr);
    182       }
    183       if (code == CURLE_OK && method != request_type::kPost) {
    184         code = curl_interface_->EasySetOptStr(
    185             curl_handle, CURLOPT_CUSTOMREQUEST, method);
    186       }
    187     }
    188   }
    189 
    190   if (code != CURLE_OK) {
    191     AddEasyCurlError(error, FROM_HERE, code, curl_interface_.get());
    192     curl_interface_->EasyCleanup(curl_handle);
    193     return connection;
    194   }
    195 
    196   connection = std::make_shared<http::curl::Connection>(
    197       curl_handle, method, curl_interface_, shared_from_this());
    198   if (!connection->SendHeaders(headers, error)) {
    199     connection.reset();
    200   }
    201   return connection;
    202 }
    203 
    204 void Transport::RunCallbackAsync(const tracked_objects::Location& from_here,
    205                                  const base::Closure& callback) {
    206   base::MessageLoopForIO::current()->PostTask(from_here, callback);
    207 }
    208 
    209 RequestID Transport::StartAsyncTransfer(http::Connection* connection,
    210                                         const SuccessCallback& success_callback,
    211                                         const ErrorCallback& error_callback) {
    212   brillo::ErrorPtr error;
    213   if (!SetupAsyncCurl(&error)) {
    214     RunCallbackAsync(
    215         FROM_HERE, base::Bind(error_callback, 0, base::Owned(error.release())));
    216     return 0;
    217   }
    218 
    219   RequestID request_id = ++last_request_id_;
    220 
    221   auto curl_connection = static_cast<http::curl::Connection*>(connection);
    222   std::unique_ptr<AsyncRequestData> request_data{new AsyncRequestData};
    223   // Add the request data to |async_requests_| before adding the CURL handle
    224   // in case CURL feels like calling the socket callback synchronously which
    225   // will need the data to be in |async_requests_| map already.
    226   request_data->success_callback = success_callback;
    227   request_data->error_callback = error_callback;
    228   request_data->connection =
    229       std::static_pointer_cast<Connection>(curl_connection->shared_from_this());
    230   request_data->request_id = request_id;
    231   async_requests_.emplace(curl_connection, std::move(request_data));
    232   request_id_map_.emplace(request_id, curl_connection);
    233 
    234   // Add the connection's CURL handle to the multi-handle.
    235   CURLMcode code = curl_interface_->MultiAddHandle(
    236       curl_multi_handle_, curl_connection->curl_handle_);
    237   if (code != CURLM_OK) {
    238     brillo::ErrorPtr error;
    239     AddMultiCurlError(&error, FROM_HERE, code, curl_interface_.get());
    240     RunCallbackAsync(
    241         FROM_HERE, base::Bind(error_callback, 0, base::Owned(error.release())));
    242     async_requests_.erase(curl_connection);
    243     request_id_map_.erase(request_id);
    244     return 0;
    245   }
    246   LOG(INFO) << "Started asynchronous HTTP request with ID " << request_id;
    247   return request_id;
    248 }
    249 
    250 bool Transport::CancelRequest(RequestID request_id) {
    251   auto p = request_id_map_.find(request_id);
    252   if (p == request_id_map_.end()) {
    253     // The request must have been completed already...
    254     // This is not necessarily an error condition, so fail gracefully.
    255     LOG(WARNING) << "HTTP request #" << request_id << " not found";
    256     return false;
    257   }
    258   LOG(INFO) << "Canceling HTTP request #" << request_id;
    259   CleanAsyncConnection(p->second);
    260   return true;
    261 }
    262 
    263 void Transport::SetDefaultTimeout(base::TimeDelta timeout) {
    264   connection_timeout_ = timeout;
    265 }
    266 
    267 void Transport::AddEasyCurlError(brillo::ErrorPtr* error,
    268                                  const tracked_objects::Location& location,
    269                                  CURLcode code,
    270                                  CurlInterface* curl_interface) {
    271   brillo::Error::AddTo(error, location, "curl_easy_error",
    272                        brillo::string_utils::ToString(code),
    273                        curl_interface->EasyStrError(code));
    274 }
    275 
    276 void Transport::AddMultiCurlError(brillo::ErrorPtr* error,
    277                                   const tracked_objects::Location& location,
    278                                   CURLMcode code,
    279                                   CurlInterface* curl_interface) {
    280   brillo::Error::AddTo(error, location, "curl_multi_error",
    281                        brillo::string_utils::ToString(code),
    282                        curl_interface->MultiStrError(code));
    283 }
    284 
    285 bool Transport::SetupAsyncCurl(brillo::ErrorPtr* error) {
    286   if (curl_multi_handle_)
    287     return true;
    288 
    289   curl_multi_handle_ = curl_interface_->MultiInit();
    290   if (!curl_multi_handle_) {
    291     LOG(ERROR) << "Failed to initialize CURL";
    292     brillo::Error::AddTo(error, FROM_HERE, http::kErrorDomain,
    293                          "curl_init_failed", "Failed to initialize CURL");
    294     return false;
    295   }
    296 
    297   CURLMcode code = curl_interface_->MultiSetSocketCallback(
    298       curl_multi_handle_, &Transport::MultiSocketCallback, this);
    299   if (code == CURLM_OK) {
    300     code = curl_interface_->MultiSetTimerCallback(
    301         curl_multi_handle_, &Transport::MultiTimerCallback, this);
    302   }
    303   if (code != CURLM_OK) {
    304     AddMultiCurlError(error, FROM_HERE, code, curl_interface_.get());
    305     return false;
    306   }
    307   return true;
    308 }
    309 
    310 void Transport::ShutDownAsyncCurl() {
    311   if (!curl_multi_handle_)
    312     return;
    313   LOG_IF(WARNING, !poll_data_map_.empty())
    314       << "There are pending requests at the time of transport's shutdown";
    315   // Make sure we are not leaking any memory here.
    316   for (const auto& pair : poll_data_map_)
    317     delete pair.second;
    318   poll_data_map_.clear();
    319   curl_interface_->MultiCleanup(curl_multi_handle_);
    320   curl_multi_handle_ = nullptr;
    321 }
    322 
    323 int Transport::MultiSocketCallback(CURL* easy,
    324                                    curl_socket_t s,
    325                                    int what,
    326                                    void* userp,
    327                                    void* socketp) {
    328   auto transport = static_cast<Transport*>(userp);
    329   CHECK(transport) << "Transport must be set for this callback";
    330   auto poll_data = static_cast<SocketPollData*>(socketp);
    331   if (!poll_data) {
    332     // We haven't attached polling data to this socket yet. Let's do this now.
    333     poll_data = new SocketPollData{transport->curl_interface_,
    334                                    transport->curl_multi_handle_,
    335                                    transport,
    336                                    s};
    337     transport->poll_data_map_.emplace(std::make_pair(easy, s), poll_data);
    338     transport->curl_interface_->MultiAssign(
    339         transport->curl_multi_handle_, s, poll_data);
    340   }
    341 
    342   if (what == CURL_POLL_NONE) {
    343     return 0;
    344   } else if (what == CURL_POLL_REMOVE) {
    345     // Remove the attached data from the socket.
    346     transport->curl_interface_->MultiAssign(
    347         transport->curl_multi_handle_, s, nullptr);
    348     transport->poll_data_map_.erase(std::make_pair(easy, s));
    349 
    350     // Make sure we stop watching the socket file descriptor now, before
    351     // we schedule the SocketPollData for deletion.
    352     poll_data->GetWatcher()->StopWatchingFileDescriptor();
    353     // This method can be called indirectly from SocketPollData::OnSocketReady,
    354     // so delay destruction of SocketPollData object till the next loop cycle.
    355     base::MessageLoopForIO::current()->DeleteSoon(FROM_HERE, poll_data);
    356     return 0;
    357   }
    358 
    359   base::MessageLoopForIO::Mode watch_mode = base::MessageLoopForIO::WATCH_READ;
    360   switch (what) {
    361     case CURL_POLL_IN:
    362       watch_mode = base::MessageLoopForIO::WATCH_READ;
    363       break;
    364     case CURL_POLL_OUT:
    365       watch_mode = base::MessageLoopForIO::WATCH_WRITE;
    366       break;
    367     case CURL_POLL_INOUT:
    368       watch_mode = base::MessageLoopForIO::WATCH_READ_WRITE;
    369       break;
    370     default:
    371       LOG(FATAL) << "Unknown CURL socket action: " << what;
    372       break;
    373   }
    374 
    375   // WatchFileDescriptor() can be called with the same controller object
    376   // (watcher) to amend the watch mode, however this has cumulative effect.
    377   // For example, if we were watching a file descriptor for READ operations
    378   // and now call it to watch for WRITE, it will end up watching for both
    379   // READ and WRITE. This is not what we want here, so stop watching the
    380   // file descriptor on previous controller before starting with a different
    381   // mode.
    382   if (!poll_data->GetWatcher()->StopWatchingFileDescriptor())
    383     LOG(WARNING) << "Failed to stop watching the previous socket descriptor";
    384   CHECK(base::MessageLoopForIO::current()->WatchFileDescriptor(
    385       s, true, watch_mode, poll_data->GetWatcher(), poll_data))
    386       << "Failed to watch the CURL socket.";
    387   return 0;
    388 }
    389 
    390 // CURL actually uses "long" types in callback signatures, so we must comply.
    391 int Transport::MultiTimerCallback(CURLM* /* multi */,
    392                                   long timeout_ms,  // NOLINT(runtime/int)
    393                                   void* userp) {
    394   auto transport = static_cast<Transport*>(userp);
    395   // Cancel any previous timer callbacks.
    396   transport->weak_ptr_factory_for_timer_.InvalidateWeakPtrs();
    397   if (timeout_ms >= 0) {
    398     base::MessageLoopForIO::current()->PostDelayedTask(
    399       FROM_HERE,
    400       base::Bind(&Transport::OnTimer,
    401                  transport->weak_ptr_factory_for_timer_.GetWeakPtr()),
    402       base::TimeDelta::FromMilliseconds(timeout_ms));
    403   }
    404   return 0;
    405 }
    406 
    407 void Transport::OnTimer() {
    408   if (curl_multi_handle_) {
    409     int still_running_count = 0;
    410     curl_interface_->MultiSocketAction(
    411         curl_multi_handle_, CURL_SOCKET_TIMEOUT, 0, &still_running_count);
    412     ProcessAsyncCurlMessages();
    413   }
    414 }
    415 
    416 void Transport::ProcessAsyncCurlMessages() {
    417   CURLMsg* msg = nullptr;
    418   int msgs_left = 0;
    419   while ((msg = curl_interface_->MultiInfoRead(curl_multi_handle_,
    420                                                &msgs_left))) {
    421     if (msg->msg == CURLMSG_DONE) {
    422       // Async I/O complete for a connection. Invoke the user callbacks.
    423       Connection* connection = nullptr;
    424       CHECK_EQ(CURLE_OK,
    425                curl_interface_->EasyGetInfoPtr(
    426                    msg->easy_handle,
    427                    CURLINFO_PRIVATE,
    428                    reinterpret_cast<void**>(&connection)));
    429       CHECK(connection != nullptr);
    430       OnTransferComplete(connection, msg->data.result);
    431     }
    432   }
    433 }
    434 
    435 void Transport::OnTransferComplete(Connection* connection, CURLcode code) {
    436   auto p = async_requests_.find(connection);
    437   CHECK(p != async_requests_.end()) << "Unknown connection";
    438   AsyncRequestData* request_data = p->second.get();
    439   LOG(INFO) << "HTTP request # " << request_data->request_id
    440             << " has completed "
    441             << (code == CURLE_OK ? "successfully" : "with an error");
    442   if (code != CURLE_OK) {
    443     brillo::ErrorPtr error;
    444     AddEasyCurlError(&error, FROM_HERE, code, curl_interface_.get());
    445     RunCallbackAsync(FROM_HERE,
    446                      base::Bind(request_data->error_callback,
    447                                 p->second->request_id,
    448                                 base::Owned(error.release())));
    449   } else {
    450     LOG(INFO) << "Response: " << connection->GetResponseStatusCode() << " ("
    451               << connection->GetResponseStatusText() << ")";
    452     brillo::ErrorPtr error;
    453     // Rewind the response data stream to the beginning so the clients can
    454     // read the data back.
    455     const auto& stream = request_data->connection->response_data_stream_;
    456     if (stream && stream->CanSeek() && !stream->SetPosition(0, &error)) {
    457       RunCallbackAsync(FROM_HERE,
    458                        base::Bind(request_data->error_callback,
    459                                   p->second->request_id,
    460                                   base::Owned(error.release())));
    461     } else {
    462       std::unique_ptr<Response> resp{new Response{request_data->connection}};
    463       RunCallbackAsync(FROM_HERE,
    464                        base::Bind(request_data->success_callback,
    465                                   p->second->request_id,
    466                                   base::Passed(&resp)));
    467     }
    468   }
    469   // In case of an error on CURL side, we would have dispatched the error
    470   // callback and we need to clean up the current connection, however the
    471   // error callback has no reference to the connection itself and
    472   // |async_requests_| is the only reference to the shared pointer that
    473   // maintains the lifetime of |connection| and possibly even this Transport
    474   // object instance. As a result, if we call CleanAsyncConnection() directly,
    475   // there is a chance that this object might be deleted.
    476   // Instead, schedule an asynchronous task to clean up the connection.
    477   RunCallbackAsync(FROM_HERE,
    478                    base::Bind(&Transport::CleanAsyncConnection,
    479                               weak_ptr_factory_.GetWeakPtr(),
    480                               connection));
    481 }
    482 
    483 void Transport::CleanAsyncConnection(Connection* connection) {
    484   auto p = async_requests_.find(connection);
    485   CHECK(p != async_requests_.end()) << "Unknown connection";
    486   // Remove the request data from the map first, since this might be the only
    487   // reference to the Connection class and even possibly to this Transport.
    488   auto request_data = std::move(p->second);
    489 
    490   // Remove associated request ID.
    491   request_id_map_.erase(request_data->request_id);
    492 
    493   // Remove the connection's CURL handle from multi-handle.
    494   curl_interface_->MultiRemoveHandle(curl_multi_handle_,
    495                                      connection->curl_handle_);
    496 
    497   // Remove all the socket data associated with this connection.
    498   auto iter = poll_data_map_.begin();
    499   while (iter != poll_data_map_.end()) {
    500     if (iter->first.first == connection->curl_handle_)
    501       iter = poll_data_map_.erase(iter);
    502     else
    503       ++iter;
    504   }
    505   // Remove pending asynchronous request data.
    506   // This must be last since there is a chance of this object being
    507   // destroyed as the result. See the comment in Transport::OnTransferComplete.
    508   async_requests_.erase(p);
    509 }
    510 
    511 }  // namespace curl
    512 }  // namespace http
    513 }  // namespace brillo
    514