Home | History | Annotate | Download | only in server
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/server/http_server.h"
      6 
      7 #include "base/compiler_specific.h"
      8 #include "base/logging.h"
      9 #include "base/stl_util.h"
     10 #include "base/strings/string_number_conversions.h"
     11 #include "base/strings/string_util.h"
     12 #include "base/strings/stringprintf.h"
     13 #include "base/sys_byteorder.h"
     14 #include "build/build_config.h"
     15 #include "net/base/net_errors.h"
     16 #include "net/server/http_connection.h"
     17 #include "net/server/http_server_request_info.h"
     18 #include "net/server/http_server_response_info.h"
     19 #include "net/server/web_socket.h"
     20 #include "net/socket/server_socket.h"
     21 #include "net/socket/stream_socket.h"
     22 #include "net/socket/tcp_server_socket.h"
     23 
     24 namespace net {
     25 
     26 HttpServer::HttpServer(scoped_ptr<ServerSocket> server_socket,
     27                        HttpServer::Delegate* delegate)
     28     : server_socket_(server_socket.Pass()),
     29       delegate_(delegate),
     30       last_id_(0),
     31       weak_ptr_factory_(this) {
     32   DCHECK(server_socket_);
     33   DoAcceptLoop();
     34 }
     35 
     36 HttpServer::~HttpServer() {
     37   STLDeleteContainerPairSecondPointers(
     38       id_to_connection_.begin(), id_to_connection_.end());
     39 }
     40 
     41 void HttpServer::AcceptWebSocket(
     42     int connection_id,
     43     const HttpServerRequestInfo& request) {
     44   HttpConnection* connection = FindConnection(connection_id);
     45   if (connection == NULL)
     46     return;
     47   DCHECK(connection->web_socket());
     48   connection->web_socket()->Accept(request);
     49 }
     50 
     51 void HttpServer::SendOverWebSocket(int connection_id,
     52                                    const std::string& data) {
     53   HttpConnection* connection = FindConnection(connection_id);
     54   if (connection == NULL)
     55     return;
     56   DCHECK(connection->web_socket());
     57   connection->web_socket()->Send(data);
     58 }
     59 
     60 void HttpServer::SendRaw(int connection_id, const std::string& data) {
     61   HttpConnection* connection = FindConnection(connection_id);
     62   if (connection == NULL)
     63     return;
     64 
     65   bool writing_in_progress = !connection->write_buf()->IsEmpty();
     66   if (connection->write_buf()->Append(data) && !writing_in_progress)
     67     DoWriteLoop(connection);
     68 }
     69 
     70 void HttpServer::SendResponse(int connection_id,
     71                               const HttpServerResponseInfo& response) {
     72   SendRaw(connection_id, response.Serialize());
     73 }
     74 
     75 void HttpServer::Send(int connection_id,
     76                       HttpStatusCode status_code,
     77                       const std::string& data,
     78                       const std::string& content_type) {
     79   HttpServerResponseInfo response(status_code);
     80   response.SetContentHeaders(data.size(), content_type);
     81   SendResponse(connection_id, response);
     82   SendRaw(connection_id, data);
     83 }
     84 
     85 void HttpServer::Send200(int connection_id,
     86                          const std::string& data,
     87                          const std::string& content_type) {
     88   Send(connection_id, HTTP_OK, data, content_type);
     89 }
     90 
     91 void HttpServer::Send404(int connection_id) {
     92   SendResponse(connection_id, HttpServerResponseInfo::CreateFor404());
     93 }
     94 
     95 void HttpServer::Send500(int connection_id, const std::string& message) {
     96   SendResponse(connection_id, HttpServerResponseInfo::CreateFor500(message));
     97 }
     98 
     99 void HttpServer::Close(int connection_id) {
    100   HttpConnection* connection = FindConnection(connection_id);
    101   if (connection == NULL)
    102     return;
    103 
    104   id_to_connection_.erase(connection_id);
    105   delegate_->OnClose(connection_id);
    106 
    107   // The call stack might have callbacks which still have the pointer of
    108   // connection. Instead of referencing connection with ID all the time,
    109   // destroys the connection in next run loop to make sure any pending
    110   // callbacks in the call stack return.
    111   base::MessageLoopProxy::current()->DeleteSoon(FROM_HERE, connection);
    112 }
    113 
    114 int HttpServer::GetLocalAddress(IPEndPoint* address) {
    115   return server_socket_->GetLocalAddress(address);
    116 }
    117 
    118 void HttpServer::SetReceiveBufferSize(int connection_id, int32 size) {
    119   HttpConnection* connection = FindConnection(connection_id);
    120   DCHECK(connection);
    121   connection->read_buf()->set_max_buffer_size(size);
    122 }
    123 
    124 void HttpServer::SetSendBufferSize(int connection_id, int32 size) {
    125   HttpConnection* connection = FindConnection(connection_id);
    126   DCHECK(connection);
    127   connection->write_buf()->set_max_buffer_size(size);
    128 }
    129 
    130 void HttpServer::DoAcceptLoop() {
    131   int rv;
    132   do {
    133     rv = server_socket_->Accept(&accepted_socket_,
    134                                 base::Bind(&HttpServer::OnAcceptCompleted,
    135                                            weak_ptr_factory_.GetWeakPtr()));
    136     if (rv == ERR_IO_PENDING)
    137       return;
    138     rv = HandleAcceptResult(rv);
    139   } while (rv == OK);
    140 }
    141 
    142 void HttpServer::OnAcceptCompleted(int rv) {
    143   if (HandleAcceptResult(rv) == OK)
    144     DoAcceptLoop();
    145 }
    146 
    147 int HttpServer::HandleAcceptResult(int rv) {
    148   if (rv < 0) {
    149     LOG(ERROR) << "Accept error: rv=" << rv;
    150     return rv;
    151   }
    152 
    153   HttpConnection* connection =
    154       new HttpConnection(++last_id_, accepted_socket_.Pass());
    155   id_to_connection_[connection->id()] = connection;
    156   delegate_->OnConnect(connection->id());
    157   if (!HasClosedConnection(connection))
    158     DoReadLoop(connection);
    159   return OK;
    160 }
    161 
    162 void HttpServer::DoReadLoop(HttpConnection* connection) {
    163   int rv;
    164   do {
    165     HttpConnection::ReadIOBuffer* read_buf = connection->read_buf();
    166     // Increases read buffer size if necessary.
    167     if (read_buf->RemainingCapacity() == 0 && !read_buf->IncreaseCapacity()) {
    168       Close(connection->id());
    169       return;
    170     }
    171 
    172     rv = connection->socket()->Read(
    173         read_buf,
    174         read_buf->RemainingCapacity(),
    175         base::Bind(&HttpServer::OnReadCompleted,
    176                    weak_ptr_factory_.GetWeakPtr(), connection->id()));
    177     if (rv == ERR_IO_PENDING)
    178       return;
    179     rv = HandleReadResult(connection, rv);
    180   } while (rv == OK);
    181 }
    182 
    183 void HttpServer::OnReadCompleted(int connection_id, int rv) {
    184   HttpConnection* connection = FindConnection(connection_id);
    185   if (!connection)  // It might be closed right before by write error.
    186     return;
    187 
    188   if (HandleReadResult(connection, rv) == OK)
    189     DoReadLoop(connection);
    190 }
    191 
    192 int HttpServer::HandleReadResult(HttpConnection* connection, int rv) {
    193   if (rv <= 0) {
    194     Close(connection->id());
    195     return rv == 0 ? ERR_CONNECTION_CLOSED : rv;
    196   }
    197 
    198   HttpConnection::ReadIOBuffer* read_buf = connection->read_buf();
    199   read_buf->DidRead(rv);
    200 
    201   // Handles http requests or websocket messages.
    202   while (read_buf->GetSize() > 0) {
    203     if (connection->web_socket()) {
    204       std::string message;
    205       WebSocket::ParseResult result = connection->web_socket()->Read(&message);
    206       if (result == WebSocket::FRAME_INCOMPLETE)
    207         break;
    208 
    209       if (result == WebSocket::FRAME_CLOSE ||
    210           result == WebSocket::FRAME_ERROR) {
    211         Close(connection->id());
    212         return ERR_CONNECTION_CLOSED;
    213       }
    214       delegate_->OnWebSocketMessage(connection->id(), message);
    215       if (HasClosedConnection(connection))
    216         return ERR_CONNECTION_CLOSED;
    217       continue;
    218     }
    219 
    220     HttpServerRequestInfo request;
    221     size_t pos = 0;
    222     if (!ParseHeaders(read_buf->StartOfBuffer(), read_buf->GetSize(),
    223                       &request, &pos)) {
    224       break;
    225     }
    226 
    227     // Sets peer address if exists.
    228     connection->socket()->GetPeerAddress(&request.peer);
    229 
    230     if (request.HasHeaderValue("connection", "upgrade")) {
    231       scoped_ptr<WebSocket> websocket(
    232           WebSocket::CreateWebSocket(this, connection, request, &pos));
    233       if (!websocket)  // Not enough data was received.
    234         break;
    235       connection->SetWebSocket(websocket.Pass());
    236       read_buf->DidConsume(pos);
    237       delegate_->OnWebSocketRequest(connection->id(), request);
    238       if (HasClosedConnection(connection))
    239         return ERR_CONNECTION_CLOSED;
    240       continue;
    241     }
    242 
    243     const char kContentLength[] = "content-length";
    244     if (request.headers.count(kContentLength) > 0) {
    245       size_t content_length = 0;
    246       const size_t kMaxBodySize = 100 << 20;
    247       if (!base::StringToSizeT(request.GetHeaderValue(kContentLength),
    248                                &content_length) ||
    249           content_length > kMaxBodySize) {
    250         SendResponse(connection->id(),
    251                      HttpServerResponseInfo::CreateFor500(
    252                          "request content-length too big or unknown: " +
    253                          request.GetHeaderValue(kContentLength)));
    254         Close(connection->id());
    255         return ERR_CONNECTION_CLOSED;
    256       }
    257 
    258       if (read_buf->GetSize() - pos < content_length)
    259         break;  // Not enough data was received yet.
    260       request.data.assign(read_buf->StartOfBuffer() + pos, content_length);
    261       pos += content_length;
    262     }
    263 
    264     read_buf->DidConsume(pos);
    265     delegate_->OnHttpRequest(connection->id(), request);
    266     if (HasClosedConnection(connection))
    267       return ERR_CONNECTION_CLOSED;
    268   }
    269 
    270   return OK;
    271 }
    272 
    273 void HttpServer::DoWriteLoop(HttpConnection* connection) {
    274   int rv = OK;
    275   HttpConnection::QueuedWriteIOBuffer* write_buf = connection->write_buf();
    276   while (rv == OK && write_buf->GetSizeToWrite() > 0) {
    277     rv = connection->socket()->Write(
    278         write_buf,
    279         write_buf->GetSizeToWrite(),
    280         base::Bind(&HttpServer::OnWriteCompleted,
    281                    weak_ptr_factory_.GetWeakPtr(), connection->id()));
    282     if (rv == ERR_IO_PENDING || rv == OK)
    283       return;
    284     rv = HandleWriteResult(connection, rv);
    285   }
    286 }
    287 
    288 void HttpServer::OnWriteCompleted(int connection_id, int rv) {
    289   HttpConnection* connection = FindConnection(connection_id);
    290   if (!connection)  // It might be closed right before by read error.
    291     return;
    292 
    293   if (HandleWriteResult(connection, rv) == OK)
    294     DoWriteLoop(connection);
    295 }
    296 
    297 int HttpServer::HandleWriteResult(HttpConnection* connection, int rv) {
    298   if (rv < 0) {
    299     Close(connection->id());
    300     return rv;
    301   }
    302 
    303   connection->write_buf()->DidConsume(rv);
    304   return OK;
    305 }
    306 
    307 namespace {
    308 
    309 //
    310 // HTTP Request Parser
    311 // This HTTP request parser uses a simple state machine to quickly parse
    312 // through the headers.  The parser is not 100% complete, as it is designed
    313 // for use in this simple test driver.
    314 //
    315 // Known issues:
    316 //   - does not handle whitespace on first HTTP line correctly.  Expects
    317 //     a single space between the method/url and url/protocol.
    318 
    319 // Input character types.
    320 enum header_parse_inputs {
    321   INPUT_LWS,
    322   INPUT_CR,
    323   INPUT_LF,
    324   INPUT_COLON,
    325   INPUT_DEFAULT,
    326   MAX_INPUTS,
    327 };
    328 
    329 // Parser states.
    330 enum header_parse_states {
    331   ST_METHOD,     // Receiving the method
    332   ST_URL,        // Receiving the URL
    333   ST_PROTO,      // Receiving the protocol
    334   ST_HEADER,     // Starting a Request Header
    335   ST_NAME,       // Receiving a request header name
    336   ST_SEPARATOR,  // Receiving the separator between header name and value
    337   ST_VALUE,      // Receiving a request header value
    338   ST_DONE,       // Parsing is complete and successful
    339   ST_ERR,        // Parsing encountered invalid syntax.
    340   MAX_STATES
    341 };
    342 
    343 // State transition table
    344 int parser_state[MAX_STATES][MAX_INPUTS] = {
    345 /* METHOD    */ { ST_URL,       ST_ERR,     ST_ERR,   ST_ERR,       ST_METHOD },
    346 /* URL       */ { ST_PROTO,     ST_ERR,     ST_ERR,   ST_URL,       ST_URL },
    347 /* PROTOCOL  */ { ST_ERR,       ST_HEADER,  ST_NAME,  ST_ERR,       ST_PROTO },
    348 /* HEADER    */ { ST_ERR,       ST_ERR,     ST_NAME,  ST_ERR,       ST_ERR },
    349 /* NAME      */ { ST_SEPARATOR, ST_DONE,    ST_ERR,   ST_VALUE,     ST_NAME },
    350 /* SEPARATOR */ { ST_SEPARATOR, ST_ERR,     ST_ERR,   ST_VALUE,     ST_ERR },
    351 /* VALUE     */ { ST_VALUE,     ST_HEADER,  ST_NAME,  ST_VALUE,     ST_VALUE },
    352 /* DONE      */ { ST_DONE,      ST_DONE,    ST_DONE,  ST_DONE,      ST_DONE },
    353 /* ERR       */ { ST_ERR,       ST_ERR,     ST_ERR,   ST_ERR,       ST_ERR }
    354 };
    355 
    356 // Convert an input character to the parser's input token.
    357 int charToInput(char ch) {
    358   switch(ch) {
    359     case ' ':
    360     case '\t':
    361       return INPUT_LWS;
    362     case '\r':
    363       return INPUT_CR;
    364     case '\n':
    365       return INPUT_LF;
    366     case ':':
    367       return INPUT_COLON;
    368   }
    369   return INPUT_DEFAULT;
    370 }
    371 
    372 }  // namespace
    373 
    374 bool HttpServer::ParseHeaders(const char* data,
    375                               size_t data_len,
    376                               HttpServerRequestInfo* info,
    377                               size_t* ppos) {
    378   size_t& pos = *ppos;
    379   int state = ST_METHOD;
    380   std::string buffer;
    381   std::string header_name;
    382   std::string header_value;
    383   while (pos < data_len) {
    384     char ch = data[pos++];
    385     int input = charToInput(ch);
    386     int next_state = parser_state[state][input];
    387 
    388     bool transition = (next_state != state);
    389     HttpServerRequestInfo::HeadersMap::iterator it;
    390     if (transition) {
    391       // Do any actions based on state transitions.
    392       switch (state) {
    393         case ST_METHOD:
    394           info->method = buffer;
    395           buffer.clear();
    396           break;
    397         case ST_URL:
    398           info->path = buffer;
    399           buffer.clear();
    400           break;
    401         case ST_PROTO:
    402           // TODO(mbelshe): Deal better with parsing protocol.
    403           DCHECK(buffer == "HTTP/1.1");
    404           buffer.clear();
    405           break;
    406         case ST_NAME:
    407           header_name = base::StringToLowerASCII(buffer);
    408           buffer.clear();
    409           break;
    410         case ST_VALUE:
    411           base::TrimWhitespaceASCII(buffer, base::TRIM_LEADING, &header_value);
    412           it = info->headers.find(header_name);
    413           // See last paragraph ("Multiple message-header fields...")
    414           // of www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2
    415           if (it == info->headers.end()) {
    416             info->headers[header_name] = header_value;
    417           } else {
    418             it->second.append(",");
    419             it->second.append(header_value);
    420           }
    421           buffer.clear();
    422           break;
    423         case ST_SEPARATOR:
    424           break;
    425       }
    426       state = next_state;
    427     } else {
    428       // Do any actions based on current state
    429       switch (state) {
    430         case ST_METHOD:
    431         case ST_URL:
    432         case ST_PROTO:
    433         case ST_VALUE:
    434         case ST_NAME:
    435           buffer.append(&ch, 1);
    436           break;
    437         case ST_DONE:
    438           DCHECK(input == INPUT_LF);
    439           return true;
    440         case ST_ERR:
    441           return false;
    442       }
    443     }
    444   }
    445   // No more characters, but we haven't finished parsing yet.
    446   return false;
    447 }
    448 
    449 HttpConnection* HttpServer::FindConnection(int connection_id) {
    450   IdToConnectionMap::iterator it = id_to_connection_.find(connection_id);
    451   if (it == id_to_connection_.end())
    452     return NULL;
    453   return it->second;
    454 }
    455 
    456 // This is called after any delegate callbacks are called to check if Close()
    457 // has been called during callback processing. Using the pointer of connection,
    458 // |connection| is safe here because Close() deletes the connection in next run
    459 // loop.
    460 bool HttpServer::HasClosedConnection(HttpConnection* connection) {
    461   return FindConnection(connection->id()) != connection;
    462 }
    463 
    464 }  // namespace net
    465