1 // Copyright 2013 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include <vector> 6 7 #include "base/bind.h" 8 #include "base/bind_helpers.h" 9 #include "base/compiler_specific.h" 10 #include "base/format_macros.h" 11 #include "base/memory/ref_counted.h" 12 #include "base/memory/scoped_ptr.h" 13 #include "base/memory/weak_ptr.h" 14 #include "base/message_loop/message_loop.h" 15 #include "base/message_loop/message_loop_proxy.h" 16 #include "base/run_loop.h" 17 #include "base/strings/string_split.h" 18 #include "base/strings/string_util.h" 19 #include "base/strings/stringprintf.h" 20 #include "base/time/time.h" 21 #include "net/base/address_list.h" 22 #include "net/base/io_buffer.h" 23 #include "net/base/ip_endpoint.h" 24 #include "net/base/net_errors.h" 25 #include "net/base/net_log.h" 26 #include "net/server/http_server.h" 27 #include "net/server/http_server_request_info.h" 28 #include "net/socket/tcp_client_socket.h" 29 #include "net/socket/tcp_listen_socket.h" 30 #include "net/url_request/url_fetcher.h" 31 #include "net/url_request/url_fetcher_delegate.h" 32 #include "net/url_request/url_request_context.h" 33 #include "net/url_request/url_request_context_getter.h" 34 #include "net/url_request/url_request_test_util.h" 35 #include "testing/gtest/include/gtest/gtest.h" 36 37 namespace net { 38 39 namespace { 40 41 void SetTimedOutAndQuitLoop(const base::WeakPtr<bool> timed_out, 42 const base::Closure& quit_loop_func) { 43 if (timed_out) { 44 *timed_out = true; 45 quit_loop_func.Run(); 46 } 47 } 48 49 bool RunLoopWithTimeout(base::RunLoop* run_loop) { 50 bool timed_out = false; 51 base::WeakPtrFactory<bool> timed_out_weak_factory(&timed_out); 52 base::MessageLoop::current()->PostDelayedTask( 53 FROM_HERE, 54 base::Bind(&SetTimedOutAndQuitLoop, 55 timed_out_weak_factory.GetWeakPtr(), 56 run_loop->QuitClosure()), 57 base::TimeDelta::FromSeconds(1)); 58 run_loop->Run(); 59 return !timed_out; 60 } 61 62 class TestHttpClient { 63 public: 64 TestHttpClient() : connect_result_(OK) {} 65 66 int ConnectAndWait(const IPEndPoint& address) { 67 AddressList addresses(address); 68 NetLog::Source source; 69 socket_.reset(new TCPClientSocket(addresses, NULL, source)); 70 71 base::RunLoop run_loop; 72 connect_result_ = socket_->Connect(base::Bind(&TestHttpClient::OnConnect, 73 base::Unretained(this), 74 run_loop.QuitClosure())); 75 if (connect_result_ != OK && connect_result_ != ERR_IO_PENDING) 76 return connect_result_; 77 78 if (!RunLoopWithTimeout(&run_loop)) 79 return ERR_TIMED_OUT; 80 return connect_result_; 81 } 82 83 void Send(const std::string& data) { 84 write_buffer_ = 85 new DrainableIOBuffer(new StringIOBuffer(data), data.length()); 86 Write(); 87 } 88 89 private: 90 void OnConnect(const base::Closure& quit_loop, int result) { 91 connect_result_ = result; 92 quit_loop.Run(); 93 } 94 95 void Write() { 96 int result = socket_->Write( 97 write_buffer_.get(), 98 write_buffer_->BytesRemaining(), 99 base::Bind(&TestHttpClient::OnWrite, base::Unretained(this))); 100 if (result != ERR_IO_PENDING) 101 OnWrite(result); 102 } 103 104 void OnWrite(int result) { 105 ASSERT_GT(result, 0); 106 write_buffer_->DidConsume(result); 107 if (write_buffer_->BytesRemaining()) 108 Write(); 109 } 110 111 scoped_refptr<DrainableIOBuffer> write_buffer_; 112 scoped_ptr<TCPClientSocket> socket_; 113 int connect_result_; 114 }; 115 116 } // namespace 117 118 class HttpServerTest : public testing::Test, 119 public HttpServer::Delegate { 120 public: 121 HttpServerTest() : quit_after_request_count_(0) {} 122 123 virtual void SetUp() OVERRIDE { 124 TCPListenSocketFactory socket_factory("127.0.0.1", 0); 125 server_ = new HttpServer(socket_factory, this); 126 ASSERT_EQ(OK, server_->GetLocalAddress(&server_address_)); 127 } 128 129 virtual void OnHttpRequest(int connection_id, 130 const HttpServerRequestInfo& info) OVERRIDE { 131 requests_.push_back(info); 132 if (requests_.size() == quit_after_request_count_) 133 run_loop_quit_func_.Run(); 134 } 135 136 virtual void OnWebSocketRequest(int connection_id, 137 const HttpServerRequestInfo& info) OVERRIDE { 138 NOTREACHED(); 139 } 140 141 virtual void OnWebSocketMessage(int connection_id, 142 const std::string& data) OVERRIDE { 143 NOTREACHED(); 144 } 145 146 virtual void OnClose(int connection_id) OVERRIDE {} 147 148 bool RunUntilRequestsReceived(size_t count) { 149 quit_after_request_count_ = count; 150 if (requests_.size() == count) 151 return true; 152 153 base::RunLoop run_loop; 154 run_loop_quit_func_ = run_loop.QuitClosure(); 155 bool success = RunLoopWithTimeout(&run_loop); 156 run_loop_quit_func_.Reset(); 157 return success; 158 } 159 160 protected: 161 scoped_refptr<HttpServer> server_; 162 IPEndPoint server_address_; 163 base::Closure run_loop_quit_func_; 164 std::vector<HttpServerRequestInfo> requests_; 165 166 private: 167 size_t quit_after_request_count_; 168 }; 169 170 TEST_F(HttpServerTest, Request) { 171 TestHttpClient client; 172 ASSERT_EQ(OK, client.ConnectAndWait(server_address_)); 173 client.Send("GET /test HTTP/1.1\r\n\r\n"); 174 ASSERT_TRUE(RunUntilRequestsReceived(1)); 175 ASSERT_EQ("GET", requests_[0].method); 176 ASSERT_EQ("/test", requests_[0].path); 177 ASSERT_EQ("", requests_[0].data); 178 ASSERT_EQ(0u, requests_[0].headers.size()); 179 } 180 181 TEST_F(HttpServerTest, RequestWithHeaders) { 182 TestHttpClient client; 183 ASSERT_EQ(OK, client.ConnectAndWait(server_address_)); 184 const char* kHeaders[][3] = { 185 {"Header", ": ", "1"}, 186 {"HeaderWithNoWhitespace", ":", "1"}, 187 {"HeaderWithWhitespace", " : \t ", "1 1 1 \t "}, 188 {"HeaderWithColon", ": ", "1:1"}, 189 {"EmptyHeader", ":", ""}, 190 {"EmptyHeaderWithWhitespace", ": \t ", ""}, 191 {"HeaderWithNonASCII", ": ", "\u00f7"}, 192 }; 193 std::string headers; 194 for (size_t i = 0; i < arraysize(kHeaders); ++i) { 195 headers += 196 std::string(kHeaders[i][0]) + kHeaders[i][1] + kHeaders[i][2] + "\r\n"; 197 } 198 199 client.Send("GET /test HTTP/1.1\r\n" + headers + "\r\n"); 200 ASSERT_TRUE(RunUntilRequestsReceived(1)); 201 ASSERT_EQ("", requests_[0].data); 202 203 for (size_t i = 0; i < arraysize(kHeaders); ++i) { 204 std::string field = StringToLowerASCII(std::string(kHeaders[i][0])); 205 std::string value = kHeaders[i][2]; 206 ASSERT_EQ(1u, requests_[0].headers.count(field)) << field; 207 ASSERT_EQ(value, requests_[0].headers[field]) << kHeaders[i][0]; 208 } 209 } 210 211 TEST_F(HttpServerTest, RequestWithBody) { 212 TestHttpClient client; 213 ASSERT_EQ(OK, client.ConnectAndWait(server_address_)); 214 std::string body = "a" + std::string(1 << 10, 'b') + "c"; 215 client.Send(base::StringPrintf( 216 "GET /test HTTP/1.1\r\n" 217 "SomeHeader: 1\r\n" 218 "Content-Length: %" PRIuS "\r\n\r\n%s", 219 body.length(), 220 body.c_str())); 221 ASSERT_TRUE(RunUntilRequestsReceived(1)); 222 ASSERT_EQ(2u, requests_[0].headers.size()); 223 ASSERT_EQ(body.length(), requests_[0].data.length()); 224 ASSERT_EQ('a', body[0]); 225 ASSERT_EQ('c', *body.rbegin()); 226 } 227 228 TEST_F(HttpServerTest, RequestWithTooLargeBody) { 229 class TestURLFetcherDelegate : public URLFetcherDelegate { 230 public: 231 TestURLFetcherDelegate(const base::Closure& quit_loop_func) 232 : quit_loop_func_(quit_loop_func) {} 233 virtual ~TestURLFetcherDelegate() {} 234 235 virtual void OnURLFetchComplete(const URLFetcher* source) OVERRIDE { 236 EXPECT_EQ(HTTP_INTERNAL_SERVER_ERROR, source->GetResponseCode()); 237 quit_loop_func_.Run(); 238 } 239 240 private: 241 base::Closure quit_loop_func_; 242 }; 243 244 base::RunLoop run_loop; 245 TestURLFetcherDelegate delegate(run_loop.QuitClosure()); 246 247 scoped_refptr<URLRequestContextGetter> request_context_getter( 248 new TestURLRequestContextGetter(base::MessageLoopProxy::current())); 249 scoped_ptr<URLFetcher> fetcher( 250 URLFetcher::Create(GURL(base::StringPrintf("http://127.0.0.1:%d/test", 251 server_address_.port())), 252 URLFetcher::GET, 253 &delegate)); 254 fetcher->SetRequestContext(request_context_getter.get()); 255 fetcher->AddExtraRequestHeader( 256 base::StringPrintf("content-length:%d", 1 << 30)); 257 fetcher->Start(); 258 259 ASSERT_TRUE(RunLoopWithTimeout(&run_loop)); 260 ASSERT_EQ(0u, requests_.size()); 261 } 262 263 namespace { 264 265 class MockStreamListenSocket : public StreamListenSocket { 266 public: 267 MockStreamListenSocket(StreamListenSocket::Delegate* delegate) 268 : StreamListenSocket(kInvalidSocket, delegate) {} 269 270 virtual void Accept() OVERRIDE { NOTREACHED(); } 271 272 private: 273 virtual ~MockStreamListenSocket() {} 274 }; 275 276 } // namespace 277 278 TEST_F(HttpServerTest, RequestWithBodySplitAcrossPackets) { 279 scoped_refptr<StreamListenSocket> socket( 280 new MockStreamListenSocket(server_.get())); 281 server_->DidAccept(NULL, socket.get()); 282 std::string body("body"); 283 std::string request = base::StringPrintf( 284 "GET /test HTTP/1.1\r\n" 285 "SomeHeader: 1\r\n" 286 "Content-Length: %" PRIuS "\r\n\r\n%s", 287 body.length(), 288 body.c_str()); 289 server_->DidRead(socket.get(), request.c_str(), request.length() - 2); 290 ASSERT_EQ(0u, requests_.size()); 291 server_->DidRead(socket.get(), request.c_str() + request.length() - 2, 2); 292 ASSERT_EQ(1u, requests_.size()); 293 ASSERT_EQ(body, requests_[0].data); 294 } 295 296 TEST_F(HttpServerTest, MultipleRequestsOnSameConnection) { 297 // The idea behind this test is that requests with or without bodies should 298 // not break parsing of the next request. 299 TestHttpClient client; 300 ASSERT_EQ(OK, client.ConnectAndWait(server_address_)); 301 std::string body = "body"; 302 client.Send(base::StringPrintf( 303 "GET /test HTTP/1.1\r\n" 304 "Content-Length: %" PRIuS "\r\n\r\n%s", 305 body.length(), 306 body.c_str())); 307 ASSERT_TRUE(RunUntilRequestsReceived(1)); 308 ASSERT_EQ(body, requests_[0].data); 309 310 client.Send("GET /test2 HTTP/1.1\r\n\r\n"); 311 ASSERT_TRUE(RunUntilRequestsReceived(2)); 312 ASSERT_EQ("/test2", requests_[1].path); 313 314 client.Send("GET /test3 HTTP/1.1\r\n\r\n"); 315 ASSERT_TRUE(RunUntilRequestsReceived(3)); 316 ASSERT_EQ("/test3", requests_[2].path); 317 } 318 319 } // namespace net 320