1 // Copyright 2015 The Android Open Source Project 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #include "webservd/protocol_handler.h" 16 17 #include <linux/tcp.h> 18 #include <microhttpd.h> 19 #include <netinet/in.h> 20 #include <sys/socket.h> 21 22 #include <algorithm> 23 #include <limits> 24 #include <vector> 25 26 #include <base/bind.h> 27 #include <base/guid.h> 28 #include <base/logging.h> 29 #include <base/message_loop/message_loop.h> 30 31 #include "webservd/request.h" 32 #include "webservd/request_handler_interface.h" 33 #include "webservd/server_interface.h" 34 35 namespace webservd { 36 37 // Helper class to provide static callback methods to libmicrohttpd library, 38 // with the ability to access private methods of Server class. 39 class ServerHelper final { 40 public: 41 static int ConnectionHandler(void *cls, 42 MHD_Connection* connection, 43 const char* url, 44 const char* method, 45 const char* version, 46 const char* upload_data, 47 size_t* upload_data_size, 48 void** con_cls) { 49 auto handler = reinterpret_cast<ProtocolHandler*>(cls); 50 if (nullptr == *con_cls) { 51 std::string request_handler_id = handler->FindRequestHandler(url, method); 52 std::unique_ptr<Request> request{new Request{ 53 request_handler_id, url, method, version, connection, handler 54 }}; 55 if (!request->BeginRequestData()) 56 return MHD_NO; 57 58 // Pass the raw pointer here in order to interface with libmicrohttpd's 59 // old-style C API. 60 *con_cls = request.release(); 61 } else { 62 auto request = reinterpret_cast<Request*>(*con_cls); 63 if (*upload_data_size) { 64 if (!request->AddRequestData(upload_data, upload_data_size)) 65 return MHD_NO; 66 } else { 67 request->EndRequestData(); 68 } 69 } 70 return MHD_YES; 71 } 72 73 static void RequestCompleted(void* /* cls */, 74 MHD_Connection* /* connection */, 75 void** con_cls, 76 MHD_RequestTerminationCode toe) { 77 if (toe != MHD_REQUEST_TERMINATED_COMPLETED_OK) { 78 LOG(ERROR) << "Web request terminated abnormally with error code: " 79 << toe; 80 } 81 auto request = reinterpret_cast<Request*>(*con_cls); 82 *con_cls = nullptr; 83 delete request; 84 } 85 }; 86 87 ProtocolHandler::ProtocolHandler(const std::string& name, 88 ServerInterface* server_interface) 89 : id_{base::GenerateGUID()}, 90 name_{name}, 91 server_interface_{server_interface} {} 92 93 ProtocolHandler::~ProtocolHandler() { 94 Stop(); 95 } 96 97 std::string ProtocolHandler::AddRequestHandler( 98 const std::string& url, 99 const std::string& method, 100 std::unique_ptr<RequestHandlerInterface> handler) { 101 std::string handler_id = base::GenerateGUID(); 102 request_handlers_.emplace(handler_id, 103 HandlerMapEntry{url, method, std::move(handler)}); 104 return handler_id; 105 } 106 107 bool ProtocolHandler::RemoveRequestHandler(const std::string& handler_id) { 108 return request_handlers_.erase(handler_id) == 1; 109 } 110 111 std::string ProtocolHandler::FindRequestHandler( 112 const base::StringPiece& url, 113 const base::StringPiece& method) const { 114 size_t score = std::numeric_limits<size_t>::max(); 115 std::string handler_id; 116 for (const auto& pair : request_handlers_) { 117 std::string handler_url = pair.second.url; 118 bool url_match = (handler_url == url); 119 bool method_match = (pair.second.method == method); 120 121 // Try exact match first. If everything matches, we have our handler. 122 if (url_match && method_match) 123 return pair.first; 124 125 // Calculate the current handler's similarity score. The lower the score 126 // the better the match is... 127 size_t current_score = 0; 128 if (!url_match && !handler_url.empty() && handler_url.back() == '/') { 129 if (url.starts_with(handler_url)) { 130 url_match = true; 131 // Use the difference in URL length as URL match quality proxy. 132 // The longer URL, the more specific (better) match is. 133 // Multiply by 2 to allow for extra score point for matching the method. 134 current_score = (url.size() - handler_url.size()) * 2; 135 } 136 } 137 138 if (!method_match && pair.second.method.empty()) { 139 // If the handler didn't specify the method it handles, this means 140 // it doesn't care. However this isn't the exact match, so bump 141 // the score up one point. 142 method_match = true; 143 ++current_score; 144 } 145 146 if (url_match && method_match && current_score < score) { 147 score = current_score; 148 handler_id = pair.first; 149 } 150 } 151 152 return handler_id; 153 } 154 155 bool ProtocolHandler::Start(Config::ProtocolHandler* config) { 156 if (server_) { 157 LOG(ERROR) << "Protocol handler is already running."; 158 return false; 159 } 160 161 // If using TLS, the certificate, private key and fingerprint must be 162 // provided. 163 CHECK_EQ(config->use_tls, !config->private_key.empty()); 164 CHECK_EQ(config->use_tls, !config->certificate.empty()); 165 CHECK_EQ(config->use_tls, !config->certificate_fingerprint.empty()); 166 167 LOG(INFO) << "Starting " << (config->use_tls ? "HTTPS" : "HTTP") 168 << " protocol handler on port: " << config->port; 169 170 port_ = config->port; 171 protocol_ = (config->use_tls ? "https" : "http"); 172 certificate_fingerprint_ = config->certificate_fingerprint; 173 174 auto callback_addr = 175 reinterpret_cast<intptr_t>(&ServerHelper::RequestCompleted); 176 uint32_t flags = MHD_NO_FLAG; 177 if (server_interface_->GetConfig().use_debug) 178 flags |= MHD_USE_DEBUG; 179 180 // Enable IPv6 if supported. 181 if (server_interface_->GetConfig().use_ipv6) 182 flags |= MHD_USE_DUAL_STACK; 183 flags |= MHD_USE_TCP_FASTOPEN; // Use TCP Fast Open (see RFC 7413). 184 flags |= MHD_USE_SUSPEND_RESUME; // Allow suspending/resuming connections. 185 186 // MHD uses timeout of 0 to mean there is no timeout. 187 int timeout = server_interface_->GetConfig().default_request_timeout_seconds; 188 if (timeout < 0) 189 timeout = 0; 190 191 std::vector<MHD_OptionItem> options{ 192 {MHD_OPTION_CONNECTION_LIMIT, 10, nullptr}, 193 {MHD_OPTION_CONNECTION_TIMEOUT, timeout, nullptr}, 194 {MHD_OPTION_NOTIFY_COMPLETED, callback_addr, nullptr}, 195 }; 196 197 if (config->socket_fd != -1) { 198 // Take ownership of the socket. 199 int socket_fd = config->socket_fd; 200 config->socket_fd = -1; 201 202 // Set some more socket options. These options were set in libmicrohttpd. 203 int on = 1; 204 if (setsockopt(socket_fd, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)) < 0) { 205 // Treat this as a non-fatal failure. Just continue after logging. 206 PLOG(WARNING) << "Failed to set SO_REUSEADDR option on listening socket."; 207 } 208 on = (MHD_USE_DUAL_STACK != (flags & MHD_USE_DUAL_STACK)); 209 if (setsockopt(socket_fd, IPPROTO_IPV6, IPV6_V6ONLY, &on, sizeof(on)) < 0) { 210 PLOG(WARNING) << "Failed to set IPV6_V6ONLY option on listening socket."; 211 close(socket_fd); 212 return false; 213 } 214 215 // Bind socket to the port. 216 sockaddr_in6 addr = {}; 217 addr.sin6_family = AF_INET6; 218 addr.sin6_port = htons(config->port); 219 if (bind(socket_fd, reinterpret_cast<const sockaddr*>(&addr), 220 sizeof(addr)) < 0) { 221 PLOG(ERROR) << "Failed to bind the socket to port " << config->port; 222 close(socket_fd); 223 return false; 224 } 225 if ((flags & MHD_USE_TCP_FASTOPEN) != 0) { 226 // This is the default value from libmicrohttpd. 227 int fastopen_queue_size = 10; 228 if (setsockopt(socket_fd, IPPROTO_TCP, TCP_FASTOPEN, 229 &fastopen_queue_size, sizeof(fastopen_queue_size)) < 0) { 230 // Treat this as a non-fatal failure. Just continue after logging. 231 PLOG(WARNING) << "Failed to set TCP_FASTOPEN option on socket."; 232 } 233 } 234 235 // Start listening on the socket. 236 // 32 connections is the value used by libmicrohttpd. 237 if (listen(socket_fd, 32) < 0) { 238 PLOG(ERROR) << "Failed to listen for connections on the socket."; 239 close(socket_fd); 240 return false; 241 } 242 243 // Finally, pass the socket to libmicrohttpd. 244 options.push_back( 245 MHD_OptionItem{MHD_OPTION_LISTEN_SOCKET, socket_fd, nullptr}); 246 } 247 248 // libmicrohttpd expects both the key and certificate to be zero-terminated 249 // strings. Make sure they are terminated properly. 250 brillo::SecureBlob private_key_copy = config->private_key; 251 brillo::Blob certificate_copy = config->certificate; 252 private_key_copy.push_back(0); 253 certificate_copy.push_back(0); 254 255 if (config->use_tls) { 256 flags |= MHD_USE_SSL; 257 options.push_back( 258 MHD_OptionItem{MHD_OPTION_HTTPS_MEM_KEY, 0, private_key_copy.data()}); 259 options.push_back( 260 MHD_OptionItem{MHD_OPTION_HTTPS_MEM_CERT, 0, certificate_copy.data()}); 261 } 262 263 options.push_back(MHD_OptionItem{MHD_OPTION_END, 0, nullptr}); 264 265 server_ = MHD_start_daemon(flags, config->port, nullptr, nullptr, 266 &ServerHelper::ConnectionHandler, this, 267 MHD_OPTION_ARRAY, options.data(), MHD_OPTION_END); 268 if (!server_) { 269 PLOG(ERROR) << "Failed to create protocol handler on port " << config->port; 270 return false; 271 } 272 server_interface_->ProtocolHandlerStarted(this); 273 DoWork(); 274 LOG(INFO) << "Protocol handler started"; 275 return true; 276 } 277 278 bool ProtocolHandler::Stop() { 279 if (server_) { 280 LOG(INFO) << "Shutting down the protocol handler..."; 281 MHD_stop_daemon(server_); 282 server_ = nullptr; 283 server_interface_->ProtocolHandlerStopped(this); 284 LOG(INFO) << "Protocol handler shutdown complete"; 285 } 286 port_ = 0; 287 protocol_.clear(); 288 certificate_fingerprint_.clear(); 289 return true; 290 } 291 292 void ProtocolHandler::AddRequest(Request* request) { 293 requests_.emplace(request->GetID(), request); 294 } 295 296 void ProtocolHandler::RemoveRequest(Request* request) { 297 requests_.erase(request->GetID()); 298 } 299 300 Request* ProtocolHandler::GetRequest(const std::string& request_id) const { 301 auto p = requests_.find(request_id); 302 return (p != requests_.end()) ? p->second : nullptr; 303 } 304 305 // A file descriptor watcher class that oversees I/O operation notification 306 // on particular socket file descriptor. 307 class ProtocolHandler::Watcher final : public base::MessageLoopForIO::Watcher { 308 public: 309 Watcher(ProtocolHandler* handler, int fd) : fd_{fd}, handler_{handler} {} 310 311 void Watch(bool read, bool write) { 312 if (read == watching_read_ && write == watching_write_ && !triggered_) 313 return; 314 315 controller_.StopWatchingFileDescriptor(); 316 watching_read_ = read; 317 watching_write_ = write; 318 triggered_ = false; 319 320 auto mode = base::MessageLoopForIO::WATCH_READ_WRITE; 321 if (watching_read_ && watching_write_) 322 mode = base::MessageLoopForIO::WATCH_READ_WRITE; 323 else if (watching_read_) 324 mode = base::MessageLoopForIO::WATCH_READ; 325 else if (watching_write_) 326 mode = base::MessageLoopForIO::WATCH_WRITE; 327 base::MessageLoopForIO::current()->WatchFileDescriptor(fd_, false, mode, 328 &controller_, this); 329 } 330 331 // Overrides from base::MessageLoopForIO::Watcher. 332 void OnFileCanReadWithoutBlocking(int /* fd */) override { 333 triggered_ = true; 334 handler_->ScheduleWork(); 335 } 336 337 void OnFileCanWriteWithoutBlocking(int /* fd */) override { 338 triggered_ = true; 339 handler_->ScheduleWork(); 340 } 341 342 int GetFileDescriptor() const { return fd_; } 343 344 private: 345 int fd_{-1}; 346 ProtocolHandler* handler_{nullptr}; 347 bool watching_read_{false}; 348 bool watching_write_{false}; 349 bool triggered_{false}; 350 base::MessageLoopForIO::FileDescriptorWatcher controller_; 351 352 DISALLOW_COPY_AND_ASSIGN(Watcher); 353 }; 354 355 void ProtocolHandler::ScheduleWork() { 356 if (work_scheduled_) 357 return; 358 359 work_scheduled_ = true; 360 base::MessageLoopForIO::current()->PostTask( 361 FROM_HERE, 362 base::Bind(&ProtocolHandler::DoWork, weak_ptr_factory_.GetWeakPtr())); 363 } 364 365 void ProtocolHandler::DoWork() { 366 work_scheduled_ = false; 367 weak_ptr_factory_.InvalidateWeakPtrs(); 368 369 // Check if there is any pending work to be done in libmicrohttpd. 370 MHD_run(server_); 371 372 // Get all the file descriptors from libmicrohttpd and watch for I/O 373 // operations on them. 374 fd_set rs; 375 fd_set ws; 376 fd_set es; 377 int max_fd = MHD_INVALID_SOCKET; 378 FD_ZERO(&rs); 379 FD_ZERO(&ws); 380 FD_ZERO(&es); 381 CHECK_EQ(MHD_YES, MHD_get_fdset(server_, &rs, &ws, &es, &max_fd)); 382 383 for (auto& watcher : watchers_) { 384 int fd = watcher->GetFileDescriptor(); 385 if (FD_ISSET(fd, &rs) || FD_ISSET(fd, &ws)) { 386 watcher->Watch(FD_ISSET(fd, &rs), FD_ISSET(fd, &ws)); 387 FD_CLR(fd, &rs); 388 FD_CLR(fd, &ws); 389 } else { 390 watcher.reset(); 391 } 392 } 393 394 watchers_.erase(std::remove(watchers_.begin(), watchers_.end(), nullptr), 395 watchers_.end()); 396 397 for (int fd = 0; fd <= max_fd; fd++) { 398 // libmicrohttpd is not using exception FDs, so lets put our expectations 399 // upfront. 400 CHECK(!FD_ISSET(fd, &es)); 401 if (FD_ISSET(fd, &rs) || FD_ISSET(fd, &ws)) { 402 // libmicrohttpd should never use any of stdin/stdout/stderr descriptors. 403 CHECK_GT(fd, STDERR_FILENO); 404 std::unique_ptr<Watcher> watcher{new Watcher{this, fd}}; 405 watcher->Watch(FD_ISSET(fd, &rs), FD_ISSET(fd, &ws)); 406 watchers_.push_back(std::move(watcher)); 407 } 408 } 409 410 // Schedule a time-out timer, if asked by libmicrohttpd. 411 MHD_UNSIGNED_LONG_LONG mhd_timeout = 0; 412 if (MHD_get_timeout(server_, &mhd_timeout) == MHD_YES) { 413 base::MessageLoopForIO::current()->PostDelayedTask( 414 FROM_HERE, 415 base::Bind(&ProtocolHandler::DoWork, weak_ptr_factory_.GetWeakPtr()), 416 base::TimeDelta::FromMilliseconds(mhd_timeout)); 417 } 418 } 419 420 } // namespace webservd 421