Home | History | Annotate | Download | only in server
      1 // Copyright 2013 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include <vector>
      6 
      7 #include "base/bind.h"
      8 #include "base/bind_helpers.h"
      9 #include "base/compiler_specific.h"
     10 #include "base/format_macros.h"
     11 #include "base/memory/ref_counted.h"
     12 #include "base/memory/scoped_ptr.h"
     13 #include "base/memory/weak_ptr.h"
     14 #include "base/message_loop/message_loop.h"
     15 #include "base/message_loop/message_loop_proxy.h"
     16 #include "base/run_loop.h"
     17 #include "base/strings/string_split.h"
     18 #include "base/strings/string_util.h"
     19 #include "base/strings/stringprintf.h"
     20 #include "base/time/time.h"
     21 #include "net/base/address_list.h"
     22 #include "net/base/io_buffer.h"
     23 #include "net/base/ip_endpoint.h"
     24 #include "net/base/net_errors.h"
     25 #include "net/base/net_log.h"
     26 #include "net/server/http_server.h"
     27 #include "net/server/http_server_request_info.h"
     28 #include "net/socket/tcp_client_socket.h"
     29 #include "net/socket/tcp_listen_socket.h"
     30 #include "net/url_request/url_fetcher.h"
     31 #include "net/url_request/url_fetcher_delegate.h"
     32 #include "net/url_request/url_request_context.h"
     33 #include "net/url_request/url_request_context_getter.h"
     34 #include "net/url_request/url_request_test_util.h"
     35 #include "testing/gtest/include/gtest/gtest.h"
     36 
     37 namespace net {
     38 
     39 namespace {
     40 
     41 void SetTimedOutAndQuitLoop(const base::WeakPtr<bool> timed_out,
     42                             const base::Closure& quit_loop_func) {
     43   if (timed_out) {
     44     *timed_out = true;
     45     quit_loop_func.Run();
     46   }
     47 }
     48 
     49 bool RunLoopWithTimeout(base::RunLoop* run_loop) {
     50   bool timed_out = false;
     51   base::WeakPtrFactory<bool> timed_out_weak_factory(&timed_out);
     52   base::MessageLoop::current()->PostDelayedTask(
     53       FROM_HERE,
     54       base::Bind(&SetTimedOutAndQuitLoop,
     55                  timed_out_weak_factory.GetWeakPtr(),
     56                  run_loop->QuitClosure()),
     57       base::TimeDelta::FromSeconds(1));
     58   run_loop->Run();
     59   return !timed_out;
     60 }
     61 
     62 class TestHttpClient {
     63  public:
     64   TestHttpClient() : connect_result_(OK) {}
     65 
     66   int ConnectAndWait(const IPEndPoint& address) {
     67     AddressList addresses(address);
     68     NetLog::Source source;
     69     socket_.reset(new TCPClientSocket(addresses, NULL, source));
     70 
     71     base::RunLoop run_loop;
     72     connect_result_ = socket_->Connect(base::Bind(&TestHttpClient::OnConnect,
     73                                                   base::Unretained(this),
     74                                                   run_loop.QuitClosure()));
     75     if (connect_result_ != OK && connect_result_ != ERR_IO_PENDING)
     76       return connect_result_;
     77 
     78     if (!RunLoopWithTimeout(&run_loop))
     79       return ERR_TIMED_OUT;
     80     return connect_result_;
     81   }
     82 
     83   void Send(const std::string& data) {
     84     write_buffer_ =
     85         new DrainableIOBuffer(new StringIOBuffer(data), data.length());
     86     Write();
     87   }
     88 
     89  private:
     90   void OnConnect(const base::Closure& quit_loop, int result) {
     91     connect_result_ = result;
     92     quit_loop.Run();
     93   }
     94 
     95   void Write() {
     96     int result = socket_->Write(
     97         write_buffer_.get(),
     98         write_buffer_->BytesRemaining(),
     99         base::Bind(&TestHttpClient::OnWrite, base::Unretained(this)));
    100     if (result != ERR_IO_PENDING)
    101       OnWrite(result);
    102   }
    103 
    104   void OnWrite(int result) {
    105     ASSERT_GT(result, 0);
    106     write_buffer_->DidConsume(result);
    107     if (write_buffer_->BytesRemaining())
    108       Write();
    109   }
    110 
    111   scoped_refptr<DrainableIOBuffer> write_buffer_;
    112   scoped_ptr<TCPClientSocket> socket_;
    113   int connect_result_;
    114 };
    115 
    116 }  // namespace
    117 
    118 class HttpServerTest : public testing::Test,
    119                        public HttpServer::Delegate {
    120  public:
    121   HttpServerTest() : quit_after_request_count_(0) {}
    122 
    123   virtual void SetUp() OVERRIDE {
    124     TCPListenSocketFactory socket_factory("127.0.0.1", 0);
    125     server_ = new HttpServer(socket_factory, this);
    126     ASSERT_EQ(OK, server_->GetLocalAddress(&server_address_));
    127   }
    128 
    129   virtual void OnHttpRequest(int connection_id,
    130                              const HttpServerRequestInfo& info) OVERRIDE {
    131     requests_.push_back(info);
    132     if (requests_.size() == quit_after_request_count_)
    133       run_loop_quit_func_.Run();
    134   }
    135 
    136   virtual void OnWebSocketRequest(int connection_id,
    137                                   const HttpServerRequestInfo& info) OVERRIDE {
    138     NOTREACHED();
    139   }
    140 
    141   virtual void OnWebSocketMessage(int connection_id,
    142                                   const std::string& data) OVERRIDE {
    143     NOTREACHED();
    144   }
    145 
    146   virtual void OnClose(int connection_id) OVERRIDE {}
    147 
    148   bool RunUntilRequestsReceived(size_t count) {
    149     quit_after_request_count_ = count;
    150     if (requests_.size() == count)
    151       return true;
    152 
    153     base::RunLoop run_loop;
    154     run_loop_quit_func_ = run_loop.QuitClosure();
    155     bool success = RunLoopWithTimeout(&run_loop);
    156     run_loop_quit_func_.Reset();
    157     return success;
    158   }
    159 
    160  protected:
    161   scoped_refptr<HttpServer> server_;
    162   IPEndPoint server_address_;
    163   base::Closure run_loop_quit_func_;
    164   std::vector<HttpServerRequestInfo> requests_;
    165 
    166  private:
    167   size_t quit_after_request_count_;
    168 };
    169 
    170 TEST_F(HttpServerTest, Request) {
    171   TestHttpClient client;
    172   ASSERT_EQ(OK, client.ConnectAndWait(server_address_));
    173   client.Send("GET /test HTTP/1.1\r\n\r\n");
    174   ASSERT_TRUE(RunUntilRequestsReceived(1));
    175   ASSERT_EQ("GET", requests_[0].method);
    176   ASSERT_EQ("/test", requests_[0].path);
    177   ASSERT_EQ("", requests_[0].data);
    178   ASSERT_EQ(0u, requests_[0].headers.size());
    179 }
    180 
    181 TEST_F(HttpServerTest, RequestWithHeaders) {
    182   TestHttpClient client;
    183   ASSERT_EQ(OK, client.ConnectAndWait(server_address_));
    184   const char* kHeaders[][3] = {
    185       {"Header", ": ", "1"},
    186       {"HeaderWithNoWhitespace", ":", "1"},
    187       {"HeaderWithWhitespace", "   :  \t   ", "1 1 1 \t  "},
    188       {"HeaderWithColon", ": ", "1:1"},
    189       {"EmptyHeader", ":", ""},
    190       {"EmptyHeaderWithWhitespace", ":  \t  ", ""},
    191       {"HeaderWithNonASCII", ":  ", "\u00f7"},
    192   };
    193   std::string headers;
    194   for (size_t i = 0; i < arraysize(kHeaders); ++i) {
    195     headers +=
    196         std::string(kHeaders[i][0]) + kHeaders[i][1] + kHeaders[i][2] + "\r\n";
    197   }
    198 
    199   client.Send("GET /test HTTP/1.1\r\n" + headers + "\r\n");
    200   ASSERT_TRUE(RunUntilRequestsReceived(1));
    201   ASSERT_EQ("", requests_[0].data);
    202 
    203   for (size_t i = 0; i < arraysize(kHeaders); ++i) {
    204     std::string field = StringToLowerASCII(std::string(kHeaders[i][0]));
    205     std::string value = kHeaders[i][2];
    206     ASSERT_EQ(1u, requests_[0].headers.count(field)) << field;
    207     ASSERT_EQ(value, requests_[0].headers[field]) << kHeaders[i][0];
    208   }
    209 }
    210 
    211 TEST_F(HttpServerTest, RequestWithBody) {
    212   TestHttpClient client;
    213   ASSERT_EQ(OK, client.ConnectAndWait(server_address_));
    214   std::string body = "a" + std::string(1 << 10, 'b') + "c";
    215   client.Send(base::StringPrintf(
    216       "GET /test HTTP/1.1\r\n"
    217       "SomeHeader: 1\r\n"
    218       "Content-Length: %" PRIuS "\r\n\r\n%s",
    219       body.length(),
    220       body.c_str()));
    221   ASSERT_TRUE(RunUntilRequestsReceived(1));
    222   ASSERT_EQ(2u, requests_[0].headers.size());
    223   ASSERT_EQ(body.length(), requests_[0].data.length());
    224   ASSERT_EQ('a', body[0]);
    225   ASSERT_EQ('c', *body.rbegin());
    226 }
    227 
    228 TEST_F(HttpServerTest, RequestWithTooLargeBody) {
    229   class TestURLFetcherDelegate : public URLFetcherDelegate {
    230    public:
    231     TestURLFetcherDelegate(const base::Closure& quit_loop_func)
    232         : quit_loop_func_(quit_loop_func) {}
    233     virtual ~TestURLFetcherDelegate() {}
    234 
    235     virtual void OnURLFetchComplete(const URLFetcher* source) OVERRIDE {
    236       EXPECT_EQ(HTTP_INTERNAL_SERVER_ERROR, source->GetResponseCode());
    237       quit_loop_func_.Run();
    238     }
    239 
    240    private:
    241     base::Closure quit_loop_func_;
    242   };
    243 
    244   base::RunLoop run_loop;
    245   TestURLFetcherDelegate delegate(run_loop.QuitClosure());
    246 
    247   scoped_refptr<URLRequestContextGetter> request_context_getter(
    248       new TestURLRequestContextGetter(base::MessageLoopProxy::current()));
    249   scoped_ptr<URLFetcher> fetcher(
    250       URLFetcher::Create(GURL(base::StringPrintf("http://127.0.0.1:%d/test",
    251                                                  server_address_.port())),
    252                          URLFetcher::GET,
    253                          &delegate));
    254   fetcher->SetRequestContext(request_context_getter.get());
    255   fetcher->AddExtraRequestHeader(
    256       base::StringPrintf("content-length:%d", 1 << 30));
    257   fetcher->Start();
    258 
    259   ASSERT_TRUE(RunLoopWithTimeout(&run_loop));
    260   ASSERT_EQ(0u, requests_.size());
    261 }
    262 
    263 namespace {
    264 
    265 class MockStreamListenSocket : public StreamListenSocket {
    266  public:
    267   MockStreamListenSocket(StreamListenSocket::Delegate* delegate)
    268       : StreamListenSocket(kInvalidSocket, delegate) {}
    269 
    270   virtual void Accept() OVERRIDE { NOTREACHED(); }
    271 
    272  private:
    273   virtual ~MockStreamListenSocket() {}
    274 };
    275 
    276 }  // namespace
    277 
    278 TEST_F(HttpServerTest, RequestWithBodySplitAcrossPackets) {
    279   scoped_refptr<StreamListenSocket> socket(
    280       new MockStreamListenSocket(server_.get()));
    281   server_->DidAccept(NULL, socket.get());
    282   std::string body("body");
    283   std::string request = base::StringPrintf(
    284       "GET /test HTTP/1.1\r\n"
    285       "SomeHeader: 1\r\n"
    286       "Content-Length: %" PRIuS "\r\n\r\n%s",
    287       body.length(),
    288       body.c_str());
    289   server_->DidRead(socket.get(), request.c_str(), request.length() - 2);
    290   ASSERT_EQ(0u, requests_.size());
    291   server_->DidRead(socket.get(), request.c_str() + request.length() - 2, 2);
    292   ASSERT_EQ(1u, requests_.size());
    293   ASSERT_EQ(body, requests_[0].data);
    294 }
    295 
    296 TEST_F(HttpServerTest, MultipleRequestsOnSameConnection) {
    297   // The idea behind this test is that requests with or without bodies should
    298   // not break parsing of the next request.
    299   TestHttpClient client;
    300   ASSERT_EQ(OK, client.ConnectAndWait(server_address_));
    301   std::string body = "body";
    302   client.Send(base::StringPrintf(
    303       "GET /test HTTP/1.1\r\n"
    304       "Content-Length: %" PRIuS "\r\n\r\n%s",
    305       body.length(),
    306       body.c_str()));
    307   ASSERT_TRUE(RunUntilRequestsReceived(1));
    308   ASSERT_EQ(body, requests_[0].data);
    309 
    310   client.Send("GET /test2 HTTP/1.1\r\n\r\n");
    311   ASSERT_TRUE(RunUntilRequestsReceived(2));
    312   ASSERT_EQ("/test2", requests_[1].path);
    313 
    314   client.Send("GET /test3 HTTP/1.1\r\n\r\n");
    315   ASSERT_TRUE(RunUntilRequestsReceived(3));
    316   ASSERT_EQ("/test3", requests_[2].path);
    317 }
    318 
    319 }  // namespace net
    320