Home | History | Annotate | Download | only in test
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "chrome_frame/test/test_server.h"
      6 
      7 #include <windows.h>
      8 #include <objbase.h>
      9 #include <urlmon.h>
     10 
     11 #include "base/bind.h"
     12 #include "base/logging.h"
     13 #include "base/strings/string_number_conversions.h"
     14 #include "base/strings/string_piece.h"
     15 #include "base/strings/string_util.h"
     16 #include "base/strings/stringprintf.h"
     17 #include "base/strings/utf_string_conversions.h"
     18 #include "chrome_frame/test/chrome_frame_test_utils.h"
     19 #include "net/base/winsock_init.h"
     20 #include "net/http/http_util.h"
     21 #include "net/socket/tcp_listen_socket.h"
     22 
     23 namespace test_server {
     24 const char kDefaultHeaderTemplate[] =
     25     "HTTP/1.1 %hs\r\n"
     26     "Connection: close\r\n"
     27     "Content-Type: %hs\r\n"
     28     "Content-Length: %i\r\n\r\n";
     29 const char kStatusOk[] = "200 OK";
     30 const char kStatusNotFound[] = "404 Not Found";
     31 const char kDefaultContentType[] = "text/html; charset=UTF-8";
     32 
     33 void Request::ParseHeaders(const std::string& headers) {
     34   DCHECK(method_.length() == 0);
     35 
     36   size_t pos = headers.find("\r\n");
     37   DCHECK(pos != std::string::npos);
     38   if (pos != std::string::npos) {
     39     headers_ = headers.substr(pos + 2);
     40 
     41     base::StringTokenizer tokenizer(
     42         headers.begin(), headers.begin() + pos, " ");
     43     std::string* parse[] = { &method_, &path_, &version_ };
     44     int field = 0;
     45     while (tokenizer.GetNext() && field < arraysize(parse)) {
     46       parse[field++]->assign(tokenizer.token_begin(),
     47                              tokenizer.token_end());
     48     }
     49   }
     50 
     51   // Check for content-length in case we're being sent some data.
     52   net::HttpUtil::HeadersIterator it(headers_.begin(), headers_.end(),
     53                                     "\r\n");
     54   while (it.GetNext()) {
     55     if (LowerCaseEqualsASCII(it.name(), "content-length")) {
     56       int int_content_length;
     57       base::StringToInt(base::StringPiece(it.values_begin(),
     58                                           it.values_end()),
     59                         &int_content_length);
     60       content_length_ = int_content_length;
     61       break;
     62     }
     63   }
     64 }
     65 
     66 void Request::OnDataReceived(const std::string& data) {
     67   content_ += data;
     68 
     69   if (method_.length() == 0) {
     70     size_t index = content_.find("\r\n\r\n");
     71     if (index != std::string::npos) {
     72       // Parse the headers before returning and chop them of the
     73       // data buffer we've already received.
     74       std::string headers(content_.substr(0, index + 2));
     75       ParseHeaders(headers);
     76       content_.erase(0, index + 4);
     77     }
     78   }
     79 }
     80 
     81 ResponseForPath::~ResponseForPath() {
     82 }
     83 
     84 SimpleResponse::~SimpleResponse() {
     85 }
     86 
     87 bool FileResponse::GetContentType(std::string* content_type) const {
     88   size_t length = ContentLength();
     89   char buffer[4096];
     90   void* data = NULL;
     91 
     92   if (length) {
     93     // Create a copy of the first few bytes of the file.
     94     // If we try and use the mapped file directly, FindMimeFromData will crash
     95     // 'cause it cheats and temporarily tries to write to the buffer!
     96     length = std::min(arraysize(buffer), length);
     97     memcpy(buffer, file_->data(), length);
     98     data = buffer;
     99   }
    100 
    101   LPOLESTR mime_type = NULL;
    102   FindMimeFromData(NULL, file_path_.value().c_str(), data, length, NULL,
    103                    FMFD_DEFAULT, &mime_type, 0);
    104   if (mime_type) {
    105     *content_type = WideToASCII(mime_type);
    106     ::CoTaskMemFree(mime_type);
    107   }
    108 
    109   return content_type->length() > 0;
    110 }
    111 
    112 void FileResponse::WriteContents(net::StreamListenSocket* socket) const {
    113   DCHECK(file_.get());
    114   if (file_.get()) {
    115     socket->Send(reinterpret_cast<const char*>(file_->data()),
    116                  file_->length(), false);
    117   }
    118 }
    119 
    120 size_t FileResponse::ContentLength() const {
    121   if (file_.get() == NULL) {
    122     file_.reset(new base::MemoryMappedFile());
    123     if (!file_->Initialize(file_path_)) {
    124       NOTREACHED();
    125       file_.reset();
    126     }
    127   }
    128   return file_.get() ? file_->length() : 0;
    129 }
    130 
    131 bool RedirectResponse::GetCustomHeaders(std::string* headers) const {
    132   *headers = base::StringPrintf("HTTP/1.1 302 Found\r\n"
    133                                 "Connection: close\r\n"
    134                                 "Content-Length: 0\r\n"
    135                                 "Content-Type: text/html\r\n"
    136                                 "Location: %hs\r\n\r\n",
    137                                 redirect_url_.c_str());
    138   return true;
    139 }
    140 
    141 SimpleWebServer::SimpleWebServer(int port) {
    142   Construct(chrome_frame_test::GetLocalIPv4Address(), port);
    143 }
    144 
    145 SimpleWebServer::SimpleWebServer(const std::string& address, int port) {
    146   Construct(address, port);
    147 }
    148 
    149 SimpleWebServer::~SimpleWebServer() {
    150   ConnectionList::const_iterator it;
    151   for (it = connections_.begin(); it != connections_.end(); ++it)
    152     delete (*it);
    153   connections_.clear();
    154 }
    155 
    156 void SimpleWebServer::Construct(const std::string& address, int port) {
    157   CHECK(base::MessageLoop::current())
    158       << "SimpleWebServer requires a message loop";
    159   net::EnsureWinsockInit();
    160   AddResponse(&quit_);
    161   host_ = address;
    162   server_ = net::TCPListenSocket::CreateAndListen(address, port, this);
    163   LOG_IF(DFATAL, !server_.get())
    164       << "Failed to create listener socket at " << address << ":" << port;
    165 }
    166 
    167 void SimpleWebServer::AddResponse(Response* response) {
    168   responses_.push_back(response);
    169 }
    170 
    171 void SimpleWebServer::DeleteAllResponses() {
    172   std::list<Response*>::const_iterator it;
    173   for (it = responses_.begin(); it != responses_.end(); ++it) {
    174     if ((*it) != &quit_)
    175       delete (*it);
    176   }
    177 }
    178 
    179 Response* SimpleWebServer::FindResponse(const Request& request) const {
    180   std::list<Response*>::const_iterator it;
    181   for (it = responses_.begin(); it != responses_.end(); it++) {
    182     Response* response = (*it);
    183     if (response->Matches(request)) {
    184       return response;
    185     }
    186   }
    187   return NULL;
    188 }
    189 
    190 Connection* SimpleWebServer::FindConnection(
    191     const net::StreamListenSocket* socket) const {
    192   ConnectionList::const_iterator it;
    193   for (it = connections_.begin(); it != connections_.end(); it++) {
    194     if ((*it)->IsSame(socket)) {
    195       return (*it);
    196     }
    197   }
    198   return NULL;
    199 }
    200 
    201 void SimpleWebServer::DidAccept(
    202     net::StreamListenSocket* server,
    203     scoped_ptr<net::StreamListenSocket> connection) {
    204   connections_.push_back(new Connection(connection.Pass()));
    205 }
    206 
    207 void SimpleWebServer::DidRead(net::StreamListenSocket* connection,
    208                               const char* data,
    209                               int len) {
    210   Connection* c = FindConnection(connection);
    211   DCHECK(c);
    212   Request& r = c->request();
    213   std::string str(data, len);
    214   r.OnDataReceived(str);
    215   if (r.AllContentReceived()) {
    216     const Request& request = c->request();
    217     Response* response = FindResponse(request);
    218     if (response) {
    219       std::string headers;
    220       if (!response->GetCustomHeaders(&headers)) {
    221         std::string content_type;
    222         if (!response->GetContentType(&content_type))
    223           content_type = kDefaultContentType;
    224         headers = base::StringPrintf(kDefaultHeaderTemplate, kStatusOk,
    225                                      content_type.c_str(),
    226                                      response->ContentLength());
    227       }
    228 
    229       connection->Send(headers, false);
    230       response->WriteContents(connection);
    231       response->IncrementAccessCounter();
    232     } else {
    233       std::string payload = "sorry, I can't find " + request.path();
    234       std::string headers(base::StringPrintf(kDefaultHeaderTemplate,
    235                                              kStatusNotFound,
    236                                              kDefaultContentType,
    237                                              payload.length()));
    238       connection->Send(headers, false);
    239       connection->Send(payload, false);
    240     }
    241   }
    242 }
    243 
    244 void SimpleWebServer::DidClose(net::StreamListenSocket* sock) {
    245   // To keep the historical list of connections reasonably tidy, we delete
    246   // 404's when the connection ends.
    247   Connection* c = FindConnection(sock);
    248   DCHECK(c);
    249   c->OnSocketClosed();
    250   if (!FindResponse(c->request())) {
    251     // extremely inefficient, but in one line and not that common... :)
    252     connections_.erase(std::find(connections_.begin(), connections_.end(), c));
    253     delete c;
    254   }
    255 }
    256 
    257 HTTPTestServer::HTTPTestServer(int port, const std::wstring& address,
    258                                base::FilePath root_dir)
    259     : port_(port), address_(address), root_dir_(root_dir) {
    260   net::EnsureWinsockInit();
    261   server_ =
    262       net::TCPListenSocket::CreateAndListen(WideToUTF8(address), port, this);
    263 }
    264 
    265 HTTPTestServer::~HTTPTestServer() {
    266 }
    267 
    268 std::list<scoped_refptr<ConfigurableConnection>>::iterator
    269 HTTPTestServer::FindConnection(const net::StreamListenSocket* socket) {
    270   ConnectionList::iterator it;
    271   // Scan through the list searching for the desired socket. Along the way,
    272   // erase any connections for which the corresponding socket has already been
    273   // forgotten about as a result of all data having been sent.
    274   for (it = connection_list_.begin(); it != connection_list_.end(); ) {
    275     ConfigurableConnection* connection = it->get();
    276     if (connection->socket_ == NULL) {
    277       connection_list_.erase(it++);
    278       continue;
    279     }
    280     if (connection->socket_ == socket)
    281       break;
    282     ++it;
    283   }
    284 
    285   return it;
    286 }
    287 
    288 scoped_refptr<ConfigurableConnection> HTTPTestServer::ConnectionFromSocket(
    289     const net::StreamListenSocket* socket) {
    290   ConnectionList::iterator it = FindConnection(socket);
    291   if (it != connection_list_.end())
    292     return *it;
    293   return NULL;
    294 }
    295 
    296 void HTTPTestServer::DidAccept(net::StreamListenSocket* server,
    297                                scoped_ptr<net::StreamListenSocket> socket) {
    298   connection_list_.push_back(new ConfigurableConnection(socket.Pass()));
    299 }
    300 
    301 void HTTPTestServer::DidRead(net::StreamListenSocket* socket,
    302                              const char* data,
    303                              int len) {
    304   scoped_refptr<ConfigurableConnection> connection =
    305       ConnectionFromSocket(socket);
    306   if (connection) {
    307     std::string str(data, len);
    308     connection->r_.OnDataReceived(str);
    309     if (connection->r_.AllContentReceived()) {
    310       VLOG(1) << __FUNCTION__ << ": " << connection->r_.method() << " "
    311               << connection->r_.path();
    312       std::wstring path = UTF8ToWide(connection->r_.path());
    313       if (LowerCaseEqualsASCII(connection->r_.method(), "post"))
    314         this->Post(connection, path, connection->r_);
    315       else
    316         this->Get(connection, path, connection->r_);
    317     }
    318   }
    319 }
    320 
    321 void HTTPTestServer::DidClose(net::StreamListenSocket* socket) {
    322   ConnectionList::iterator it = FindConnection(socket);
    323   if (it != connection_list_.end())
    324     connection_list_.erase(it);
    325 }
    326 
    327 std::wstring HTTPTestServer::Resolve(const std::wstring& path) {
    328   // Remove the first '/' if needed.
    329   std::wstring stripped_path = path;
    330   if (path.size() && path[0] == L'/')
    331     stripped_path = path.substr(1);
    332 
    333   if (port_ == 80) {
    334     if (stripped_path.empty()) {
    335       return base::StringPrintf(L"http://%ls", address_.c_str());
    336     } else {
    337       return base::StringPrintf(L"http://%ls/%ls", address_.c_str(),
    338                           stripped_path.c_str());
    339     }
    340   } else {
    341     if (stripped_path.empty()) {
    342       return base::StringPrintf(L"http://%ls:%d", address_.c_str(), port_);
    343     } else {
    344       return base::StringPrintf(L"http://%ls:%d/%ls", address_.c_str(), port_,
    345                                 stripped_path.c_str());
    346     }
    347   }
    348 }
    349 
    350 void ConfigurableConnection::SendChunk() {
    351   int size = (int)data_.size();
    352   const char* chunk_ptr = data_.c_str() + cur_pos_;
    353   int bytes_to_send = std::min(options_.chunk_size_, size - cur_pos_);
    354 
    355   socket_->Send(chunk_ptr, bytes_to_send);
    356   VLOG(1) << "Sent(" << cur_pos_ << "," << bytes_to_send << "): "
    357           << base::StringPiece(chunk_ptr, bytes_to_send);
    358 
    359   cur_pos_ += bytes_to_send;
    360   if (cur_pos_ < size) {
    361     base::MessageLoop::current()->PostDelayedTask(
    362         FROM_HERE, base::Bind(&ConfigurableConnection::SendChunk, this),
    363         base::TimeDelta::FromMilliseconds(options_.timeout_));
    364   } else {
    365     Close();
    366   }
    367 }
    368 
    369 void ConfigurableConnection::Close() {
    370   socket_.reset();
    371 }
    372 
    373 void ConfigurableConnection::Send(const std::string& headers,
    374                                   const std::string& content) {
    375   SendOptions options(SendOptions::IMMEDIATE, 0, 0);
    376   SendWithOptions(headers, content, options);
    377 }
    378 
    379 void ConfigurableConnection::SendWithOptions(const std::string& headers,
    380                                              const std::string& content,
    381                                              const SendOptions& options) {
    382   std::string content_length_header;
    383   if (!content.empty() &&
    384       std::string::npos == headers.find("Context-Length:")) {
    385     content_length_header = base::StringPrintf("Content-Length: %u\r\n",
    386                                                content.size());
    387   }
    388 
    389   // Save the options.
    390   options_ = options;
    391 
    392   if (options_.speed_ == SendOptions::IMMEDIATE) {
    393     socket_->Send(headers);
    394     socket_->Send(content_length_header, true);
    395     socket_->Send(content);
    396     // Post a task to close the socket since StreamListenSocket doesn't like
    397     // instances to go away from within its callbacks.
    398     base::MessageLoop::current()->PostTask(
    399         FROM_HERE, base::Bind(&ConfigurableConnection::Close, this));
    400 
    401     return;
    402   }
    403 
    404   if (options_.speed_ == SendOptions::IMMEDIATE_HEADERS_DELAYED_CONTENT) {
    405     socket_->Send(headers);
    406     socket_->Send(content_length_header, true);
    407     VLOG(1) << "Headers sent: " << headers << content_length_header;
    408     data_.append(content);
    409   }
    410 
    411   if (options_.speed_ == SendOptions::DELAYED) {
    412     data_ = headers;
    413     data_.append(content_length_header);
    414     data_.append("\r\n");
    415   }
    416 
    417   base::MessageLoop::current()->PostDelayedTask(
    418       FROM_HERE, base::Bind(&ConfigurableConnection::SendChunk, this),
    419       base::TimeDelta::FromMilliseconds(options.timeout_));
    420 }
    421 
    422 }  // namespace test_server
    423