Home | History | Annotate | Download | only in webservd
      1 // Copyright 2015 The Android Open Source Project
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //      http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
     15 #include "webservd/protocol_handler.h"
     16 
     17 #include <linux/tcp.h>
     18 #include <microhttpd.h>
     19 #include <netinet/in.h>
     20 #include <sys/socket.h>
     21 
     22 #include <algorithm>
     23 #include <limits>
     24 #include <vector>
     25 
     26 #include <base/bind.h>
     27 #include <base/guid.h>
     28 #include <base/logging.h>
     29 #include <base/message_loop/message_loop.h>
     30 
     31 #include "webservd/request.h"
     32 #include "webservd/request_handler_interface.h"
     33 #include "webservd/server_interface.h"
     34 
     35 namespace webservd {
     36 
     37 // Helper class to provide static callback methods to libmicrohttpd library,
     38 // with the ability to access private methods of Server class.
     39 class ServerHelper final {
     40  public:
     41   static int ConnectionHandler(void *cls,
     42                                MHD_Connection* connection,
     43                                const char* url,
     44                                const char* method,
     45                                const char* version,
     46                                const char* upload_data,
     47                                size_t* upload_data_size,
     48                                void** con_cls) {
     49     auto handler = reinterpret_cast<ProtocolHandler*>(cls);
     50     if (nullptr == *con_cls) {
     51       std::string request_handler_id = handler->FindRequestHandler(url, method);
     52       std::unique_ptr<Request> request{new Request{
     53           request_handler_id, url, method, version, connection, handler
     54       }};
     55       if (!request->BeginRequestData())
     56         return MHD_NO;
     57 
     58       // Pass the raw pointer here in order to interface with libmicrohttpd's
     59       // old-style C API.
     60       *con_cls = request.release();
     61     } else {
     62       auto request = reinterpret_cast<Request*>(*con_cls);
     63       if (*upload_data_size) {
     64         if (!request->AddRequestData(upload_data, upload_data_size))
     65           return MHD_NO;
     66       } else {
     67         request->EndRequestData();
     68       }
     69     }
     70     return MHD_YES;
     71   }
     72 
     73   static void RequestCompleted(void* /* cls */,
     74                                MHD_Connection*  /* connection */,
     75                                void** con_cls,
     76                                MHD_RequestTerminationCode toe) {
     77     if (toe != MHD_REQUEST_TERMINATED_COMPLETED_OK) {
     78       LOG(ERROR) << "Web request terminated abnormally with error code: "
     79                  << toe;
     80     }
     81     auto request = reinterpret_cast<Request*>(*con_cls);
     82     *con_cls = nullptr;
     83     delete request;
     84   }
     85 };
     86 
     87 ProtocolHandler::ProtocolHandler(const std::string& name,
     88                                  ServerInterface* server_interface)
     89     : id_{base::GenerateGUID()},
     90       name_{name},
     91       server_interface_{server_interface} {}
     92 
     93 ProtocolHandler::~ProtocolHandler() {
     94   Stop();
     95 }
     96 
     97 std::string ProtocolHandler::AddRequestHandler(
     98     const std::string& url,
     99     const std::string& method,
    100     std::unique_ptr<RequestHandlerInterface> handler) {
    101   std::string handler_id = base::GenerateGUID();
    102   request_handlers_.emplace(handler_id,
    103                             HandlerMapEntry{url, method, std::move(handler)});
    104   return handler_id;
    105 }
    106 
    107 bool ProtocolHandler::RemoveRequestHandler(const std::string& handler_id) {
    108   return request_handlers_.erase(handler_id) == 1;
    109 }
    110 
    111 std::string ProtocolHandler::FindRequestHandler(
    112     const base::StringPiece& url,
    113     const base::StringPiece& method) const {
    114   size_t score = std::numeric_limits<size_t>::max();
    115   std::string handler_id;
    116   for (const auto& pair : request_handlers_) {
    117     std::string handler_url = pair.second.url;
    118     bool url_match = (handler_url == url);
    119     bool method_match = (pair.second.method == method);
    120 
    121     // Try exact match first. If everything matches, we have our handler.
    122     if (url_match && method_match)
    123       return pair.first;
    124 
    125     // Calculate the current handler's similarity score. The lower the score
    126     // the better the match is...
    127     size_t current_score = 0;
    128     if (!url_match && !handler_url.empty() && handler_url.back() == '/') {
    129       if (url.starts_with(handler_url)) {
    130         url_match = true;
    131         // Use the difference in URL length as URL match quality proxy.
    132         // The longer URL, the more specific (better) match is.
    133         // Multiply by 2 to allow for extra score point for matching the method.
    134         current_score = (url.size() - handler_url.size()) * 2;
    135       }
    136     }
    137 
    138     if (!method_match && pair.second.method.empty()) {
    139       // If the handler didn't specify the method it handles, this means
    140       // it doesn't care. However this isn't the exact match, so bump
    141       // the score up one point.
    142       method_match = true;
    143       ++current_score;
    144     }
    145 
    146     if (url_match && method_match && current_score < score) {
    147       score = current_score;
    148       handler_id = pair.first;
    149     }
    150   }
    151 
    152   return handler_id;
    153 }
    154 
    155 bool ProtocolHandler::Start(Config::ProtocolHandler* config) {
    156   if (server_) {
    157     LOG(ERROR) << "Protocol handler is already running.";
    158     return false;
    159   }
    160 
    161   // If using TLS, the certificate, private key and fingerprint must be
    162   // provided.
    163   CHECK_EQ(config->use_tls, !config->private_key.empty());
    164   CHECK_EQ(config->use_tls, !config->certificate.empty());
    165   CHECK_EQ(config->use_tls, !config->certificate_fingerprint.empty());
    166 
    167   LOG(INFO) << "Starting " << (config->use_tls ? "HTTPS" : "HTTP")
    168             << " protocol handler on port: " << config->port;
    169 
    170   port_ = config->port;
    171   protocol_ = (config->use_tls ? "https" : "http");
    172   certificate_fingerprint_ = config->certificate_fingerprint;
    173 
    174   auto callback_addr =
    175       reinterpret_cast<intptr_t>(&ServerHelper::RequestCompleted);
    176   uint32_t flags = MHD_NO_FLAG;
    177   if (server_interface_->GetConfig().use_debug)
    178     flags |= MHD_USE_DEBUG;
    179 
    180   // Enable IPv6 if supported.
    181   if (server_interface_->GetConfig().use_ipv6)
    182     flags |= MHD_USE_DUAL_STACK;
    183   flags |= MHD_USE_TCP_FASTOPEN;  // Use TCP Fast Open (see RFC 7413).
    184   flags |= MHD_USE_SUSPEND_RESUME;  // Allow suspending/resuming connections.
    185 
    186   // MHD uses timeout of 0 to mean there is no timeout.
    187   int timeout = server_interface_->GetConfig().default_request_timeout_seconds;
    188   if (timeout < 0)
    189     timeout = 0;
    190 
    191   std::vector<MHD_OptionItem> options{
    192     {MHD_OPTION_CONNECTION_LIMIT, 10, nullptr},
    193     {MHD_OPTION_CONNECTION_TIMEOUT, timeout, nullptr},
    194     {MHD_OPTION_NOTIFY_COMPLETED, callback_addr, nullptr},
    195   };
    196 
    197   if (config->socket_fd != -1) {
    198     // Take ownership of the socket.
    199     int socket_fd = config->socket_fd;
    200     config->socket_fd = -1;
    201 
    202     // Set some more socket options. These options were set in libmicrohttpd.
    203     int on = 1;
    204     if (setsockopt(socket_fd, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)) < 0) {
    205       // Treat this as a non-fatal failure. Just continue after logging.
    206       PLOG(WARNING) << "Failed to set SO_REUSEADDR option on listening socket.";
    207     }
    208     on = (MHD_USE_DUAL_STACK != (flags & MHD_USE_DUAL_STACK));
    209     if (setsockopt(socket_fd, IPPROTO_IPV6, IPV6_V6ONLY, &on, sizeof(on)) < 0) {
    210       PLOG(WARNING) << "Failed to set IPV6_V6ONLY option on listening socket.";
    211       close(socket_fd);
    212       return false;
    213     }
    214 
    215     // Bind socket to the port.
    216     sockaddr_in6 addr = {};
    217     addr.sin6_family = AF_INET6;
    218     addr.sin6_port = htons(config->port);
    219     if (bind(socket_fd, reinterpret_cast<const sockaddr*>(&addr),
    220              sizeof(addr)) < 0) {
    221       PLOG(ERROR) << "Failed to bind the socket to port " << config->port;
    222       close(socket_fd);
    223       return false;
    224     }
    225     if ((flags & MHD_USE_TCP_FASTOPEN) != 0) {
    226       // This is the default value from libmicrohttpd.
    227       int fastopen_queue_size = 10;
    228       if (setsockopt(socket_fd, IPPROTO_TCP, TCP_FASTOPEN,
    229                      &fastopen_queue_size, sizeof(fastopen_queue_size)) < 0) {
    230         // Treat this as a non-fatal failure. Just continue after logging.
    231         PLOG(WARNING) << "Failed to set TCP_FASTOPEN option on socket.";
    232       }
    233     }
    234 
    235     // Start listening on the socket.
    236     // 32 connections is the value used by libmicrohttpd.
    237     if (listen(socket_fd, 32) < 0) {
    238       PLOG(ERROR) << "Failed to listen for connections on the socket.";
    239       close(socket_fd);
    240       return false;
    241     }
    242 
    243     // Finally, pass the socket to libmicrohttpd.
    244     options.push_back(
    245         MHD_OptionItem{MHD_OPTION_LISTEN_SOCKET, socket_fd, nullptr});
    246   }
    247 
    248   // libmicrohttpd expects both the key and certificate to be zero-terminated
    249   // strings. Make sure they are terminated properly.
    250   brillo::SecureBlob private_key_copy = config->private_key;
    251   brillo::Blob certificate_copy = config->certificate;
    252   private_key_copy.push_back(0);
    253   certificate_copy.push_back(0);
    254 
    255   if (config->use_tls) {
    256     flags |= MHD_USE_SSL;
    257     options.push_back(
    258         MHD_OptionItem{MHD_OPTION_HTTPS_MEM_KEY, 0, private_key_copy.data()});
    259     options.push_back(
    260         MHD_OptionItem{MHD_OPTION_HTTPS_MEM_CERT, 0, certificate_copy.data()});
    261   }
    262 
    263   options.push_back(MHD_OptionItem{MHD_OPTION_END, 0, nullptr});
    264 
    265   server_ = MHD_start_daemon(flags, config->port, nullptr, nullptr,
    266                              &ServerHelper::ConnectionHandler, this,
    267                              MHD_OPTION_ARRAY, options.data(), MHD_OPTION_END);
    268   if (!server_) {
    269     PLOG(ERROR) << "Failed to create protocol handler on port " << config->port;
    270     return false;
    271   }
    272   server_interface_->ProtocolHandlerStarted(this);
    273   DoWork();
    274   LOG(INFO) << "Protocol handler started";
    275   return true;
    276 }
    277 
    278 bool ProtocolHandler::Stop() {
    279   if (server_) {
    280     LOG(INFO) << "Shutting down the protocol handler...";
    281     MHD_stop_daemon(server_);
    282     server_ = nullptr;
    283     server_interface_->ProtocolHandlerStopped(this);
    284     LOG(INFO) << "Protocol handler shutdown complete";
    285   }
    286   port_ = 0;
    287   protocol_.clear();
    288   certificate_fingerprint_.clear();
    289   return true;
    290 }
    291 
    292 void ProtocolHandler::AddRequest(Request* request) {
    293   requests_.emplace(request->GetID(), request);
    294 }
    295 
    296 void ProtocolHandler::RemoveRequest(Request* request) {
    297   requests_.erase(request->GetID());
    298 }
    299 
    300 Request* ProtocolHandler::GetRequest(const std::string& request_id) const {
    301   auto p = requests_.find(request_id);
    302   return (p != requests_.end()) ? p->second : nullptr;
    303 }
    304 
    305 // A file descriptor watcher class that oversees I/O operation notification
    306 // on particular socket file descriptor.
    307 class ProtocolHandler::Watcher final : public base::MessageLoopForIO::Watcher {
    308  public:
    309   Watcher(ProtocolHandler* handler, int fd) : fd_{fd}, handler_{handler} {}
    310 
    311   void Watch(bool read, bool write) {
    312     if (read == watching_read_ && write == watching_write_ && !triggered_)
    313       return;
    314 
    315     controller_.StopWatchingFileDescriptor();
    316     watching_read_ = read;
    317     watching_write_ = write;
    318     triggered_ = false;
    319 
    320     auto mode = base::MessageLoopForIO::WATCH_READ_WRITE;
    321     if (watching_read_ && watching_write_)
    322       mode = base::MessageLoopForIO::WATCH_READ_WRITE;
    323     else if (watching_read_)
    324       mode = base::MessageLoopForIO::WATCH_READ;
    325     else if (watching_write_)
    326       mode = base::MessageLoopForIO::WATCH_WRITE;
    327     base::MessageLoopForIO::current()->WatchFileDescriptor(fd_, false, mode,
    328                                                            &controller_, this);
    329   }
    330 
    331   // Overrides from base::MessageLoopForIO::Watcher.
    332   void OnFileCanReadWithoutBlocking(int /* fd */) override {
    333     triggered_ = true;
    334     handler_->ScheduleWork();
    335   }
    336 
    337   void OnFileCanWriteWithoutBlocking(int /* fd */) override {
    338     triggered_ = true;
    339     handler_->ScheduleWork();
    340   }
    341 
    342   int GetFileDescriptor() const { return fd_; }
    343 
    344  private:
    345   int fd_{-1};
    346   ProtocolHandler* handler_{nullptr};
    347   bool watching_read_{false};
    348   bool watching_write_{false};
    349   bool triggered_{false};
    350   base::MessageLoopForIO::FileDescriptorWatcher controller_;
    351 
    352   DISALLOW_COPY_AND_ASSIGN(Watcher);
    353 };
    354 
    355 void ProtocolHandler::ScheduleWork() {
    356   if (work_scheduled_)
    357     return;
    358 
    359   work_scheduled_ = true;
    360   base::MessageLoopForIO::current()->PostTask(
    361       FROM_HERE,
    362       base::Bind(&ProtocolHandler::DoWork, weak_ptr_factory_.GetWeakPtr()));
    363 }
    364 
    365 void ProtocolHandler::DoWork() {
    366   work_scheduled_ = false;
    367   weak_ptr_factory_.InvalidateWeakPtrs();
    368 
    369   // Check if there is any pending work to be done in libmicrohttpd.
    370   MHD_run(server_);
    371 
    372   // Get all the file descriptors from libmicrohttpd and watch for I/O
    373   // operations on them.
    374   fd_set rs;
    375   fd_set ws;
    376   fd_set es;
    377   int max_fd = MHD_INVALID_SOCKET;
    378   FD_ZERO(&rs);
    379   FD_ZERO(&ws);
    380   FD_ZERO(&es);
    381   CHECK_EQ(MHD_YES, MHD_get_fdset(server_, &rs, &ws, &es, &max_fd));
    382 
    383   for (auto& watcher : watchers_) {
    384     int fd = watcher->GetFileDescriptor();
    385     if (FD_ISSET(fd, &rs) || FD_ISSET(fd, &ws)) {
    386       watcher->Watch(FD_ISSET(fd, &rs), FD_ISSET(fd, &ws));
    387       FD_CLR(fd, &rs);
    388       FD_CLR(fd, &ws);
    389     } else {
    390       watcher.reset();
    391     }
    392   }
    393 
    394   watchers_.erase(std::remove(watchers_.begin(), watchers_.end(), nullptr),
    395                   watchers_.end());
    396 
    397   for (int fd = 0; fd <= max_fd; fd++) {
    398     // libmicrohttpd is not using exception FDs, so lets put our expectations
    399     // upfront.
    400     CHECK(!FD_ISSET(fd, &es));
    401     if (FD_ISSET(fd, &rs) || FD_ISSET(fd, &ws)) {
    402       // libmicrohttpd should never use any of stdin/stdout/stderr descriptors.
    403       CHECK_GT(fd, STDERR_FILENO);
    404       std::unique_ptr<Watcher> watcher{new Watcher{this, fd}};
    405       watcher->Watch(FD_ISSET(fd, &rs), FD_ISSET(fd, &ws));
    406       watchers_.push_back(std::move(watcher));
    407     }
    408   }
    409 
    410   // Schedule a time-out timer, if asked by libmicrohttpd.
    411   MHD_UNSIGNED_LONG_LONG mhd_timeout = 0;
    412   if (MHD_get_timeout(server_, &mhd_timeout) == MHD_YES) {
    413     base::MessageLoopForIO::current()->PostDelayedTask(
    414         FROM_HERE,
    415         base::Bind(&ProtocolHandler::DoWork, weak_ptr_factory_.GetWeakPtr()),
    416         base::TimeDelta::FromMilliseconds(mhd_timeout));
    417   }
    418 }
    419 
    420 }  // namespace webservd
    421