1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "chrome_frame/test/test_server.h" 6 7 #include <windows.h> 8 #include <objbase.h> 9 #include <urlmon.h> 10 11 #include "base/bind.h" 12 #include "base/logging.h" 13 #include "base/strings/string_number_conversions.h" 14 #include "base/strings/string_piece.h" 15 #include "base/strings/string_util.h" 16 #include "base/strings/stringprintf.h" 17 #include "base/strings/utf_string_conversions.h" 18 #include "chrome_frame/test/chrome_frame_test_utils.h" 19 #include "net/base/winsock_init.h" 20 #include "net/http/http_util.h" 21 #include "net/socket/tcp_listen_socket.h" 22 23 namespace test_server { 24 const char kDefaultHeaderTemplate[] = 25 "HTTP/1.1 %hs\r\n" 26 "Connection: close\r\n" 27 "Content-Type: %hs\r\n" 28 "Content-Length: %i\r\n\r\n"; 29 const char kStatusOk[] = "200 OK"; 30 const char kStatusNotFound[] = "404 Not Found"; 31 const char kDefaultContentType[] = "text/html; charset=UTF-8"; 32 33 void Request::ParseHeaders(const std::string& headers) { 34 DCHECK(method_.length() == 0); 35 36 size_t pos = headers.find("\r\n"); 37 DCHECK(pos != std::string::npos); 38 if (pos != std::string::npos) { 39 headers_ = headers.substr(pos + 2); 40 41 base::StringTokenizer tokenizer( 42 headers.begin(), headers.begin() + pos, " "); 43 std::string* parse[] = { &method_, &path_, &version_ }; 44 int field = 0; 45 while (tokenizer.GetNext() && field < arraysize(parse)) { 46 parse[field++]->assign(tokenizer.token_begin(), 47 tokenizer.token_end()); 48 } 49 } 50 51 // Check for content-length in case we're being sent some data. 52 net::HttpUtil::HeadersIterator it(headers_.begin(), headers_.end(), 53 "\r\n"); 54 while (it.GetNext()) { 55 if (LowerCaseEqualsASCII(it.name(), "content-length")) { 56 int int_content_length; 57 base::StringToInt(base::StringPiece(it.values_begin(), 58 it.values_end()), 59 &int_content_length); 60 content_length_ = int_content_length; 61 break; 62 } 63 } 64 } 65 66 void Request::OnDataReceived(const std::string& data) { 67 content_ += data; 68 69 if (method_.length() == 0) { 70 size_t index = content_.find("\r\n\r\n"); 71 if (index != std::string::npos) { 72 // Parse the headers before returning and chop them of the 73 // data buffer we've already received. 74 std::string headers(content_.substr(0, index + 2)); 75 ParseHeaders(headers); 76 content_.erase(0, index + 4); 77 } 78 } 79 } 80 81 ResponseForPath::~ResponseForPath() { 82 } 83 84 SimpleResponse::~SimpleResponse() { 85 } 86 87 bool FileResponse::GetContentType(std::string* content_type) const { 88 size_t length = ContentLength(); 89 char buffer[4096]; 90 void* data = NULL; 91 92 if (length) { 93 // Create a copy of the first few bytes of the file. 94 // If we try and use the mapped file directly, FindMimeFromData will crash 95 // 'cause it cheats and temporarily tries to write to the buffer! 96 length = std::min(arraysize(buffer), length); 97 memcpy(buffer, file_->data(), length); 98 data = buffer; 99 } 100 101 LPOLESTR mime_type = NULL; 102 FindMimeFromData(NULL, file_path_.value().c_str(), data, length, NULL, 103 FMFD_DEFAULT, &mime_type, 0); 104 if (mime_type) { 105 *content_type = WideToASCII(mime_type); 106 ::CoTaskMemFree(mime_type); 107 } 108 109 return content_type->length() > 0; 110 } 111 112 void FileResponse::WriteContents(net::StreamListenSocket* socket) const { 113 DCHECK(file_.get()); 114 if (file_.get()) { 115 socket->Send(reinterpret_cast<const char*>(file_->data()), 116 file_->length(), false); 117 } 118 } 119 120 size_t FileResponse::ContentLength() const { 121 if (file_.get() == NULL) { 122 file_.reset(new base::MemoryMappedFile()); 123 if (!file_->Initialize(file_path_)) { 124 NOTREACHED(); 125 file_.reset(); 126 } 127 } 128 return file_.get() ? file_->length() : 0; 129 } 130 131 bool RedirectResponse::GetCustomHeaders(std::string* headers) const { 132 *headers = base::StringPrintf("HTTP/1.1 302 Found\r\n" 133 "Connection: close\r\n" 134 "Content-Length: 0\r\n" 135 "Content-Type: text/html\r\n" 136 "Location: %hs\r\n\r\n", 137 redirect_url_.c_str()); 138 return true; 139 } 140 141 SimpleWebServer::SimpleWebServer(int port) { 142 Construct(chrome_frame_test::GetLocalIPv4Address(), port); 143 } 144 145 SimpleWebServer::SimpleWebServer(const std::string& address, int port) { 146 Construct(address, port); 147 } 148 149 SimpleWebServer::~SimpleWebServer() { 150 ConnectionList::const_iterator it; 151 for (it = connections_.begin(); it != connections_.end(); ++it) 152 delete (*it); 153 connections_.clear(); 154 } 155 156 void SimpleWebServer::Construct(const std::string& address, int port) { 157 CHECK(base::MessageLoop::current()) 158 << "SimpleWebServer requires a message loop"; 159 net::EnsureWinsockInit(); 160 AddResponse(&quit_); 161 host_ = address; 162 server_ = net::TCPListenSocket::CreateAndListen(address, port, this); 163 LOG_IF(DFATAL, !server_.get()) 164 << "Failed to create listener socket at " << address << ":" << port; 165 } 166 167 void SimpleWebServer::AddResponse(Response* response) { 168 responses_.push_back(response); 169 } 170 171 void SimpleWebServer::DeleteAllResponses() { 172 std::list<Response*>::const_iterator it; 173 for (it = responses_.begin(); it != responses_.end(); ++it) { 174 if ((*it) != &quit_) 175 delete (*it); 176 } 177 } 178 179 Response* SimpleWebServer::FindResponse(const Request& request) const { 180 std::list<Response*>::const_iterator it; 181 for (it = responses_.begin(); it != responses_.end(); it++) { 182 Response* response = (*it); 183 if (response->Matches(request)) { 184 return response; 185 } 186 } 187 return NULL; 188 } 189 190 Connection* SimpleWebServer::FindConnection( 191 const net::StreamListenSocket* socket) const { 192 ConnectionList::const_iterator it; 193 for (it = connections_.begin(); it != connections_.end(); it++) { 194 if ((*it)->IsSame(socket)) { 195 return (*it); 196 } 197 } 198 return NULL; 199 } 200 201 void SimpleWebServer::DidAccept( 202 net::StreamListenSocket* server, 203 scoped_ptr<net::StreamListenSocket> connection) { 204 connections_.push_back(new Connection(connection.Pass())); 205 } 206 207 void SimpleWebServer::DidRead(net::StreamListenSocket* connection, 208 const char* data, 209 int len) { 210 Connection* c = FindConnection(connection); 211 DCHECK(c); 212 Request& r = c->request(); 213 std::string str(data, len); 214 r.OnDataReceived(str); 215 if (r.AllContentReceived()) { 216 const Request& request = c->request(); 217 Response* response = FindResponse(request); 218 if (response) { 219 std::string headers; 220 if (!response->GetCustomHeaders(&headers)) { 221 std::string content_type; 222 if (!response->GetContentType(&content_type)) 223 content_type = kDefaultContentType; 224 headers = base::StringPrintf(kDefaultHeaderTemplate, kStatusOk, 225 content_type.c_str(), 226 response->ContentLength()); 227 } 228 229 connection->Send(headers, false); 230 response->WriteContents(connection); 231 response->IncrementAccessCounter(); 232 } else { 233 std::string payload = "sorry, I can't find " + request.path(); 234 std::string headers(base::StringPrintf(kDefaultHeaderTemplate, 235 kStatusNotFound, 236 kDefaultContentType, 237 payload.length())); 238 connection->Send(headers, false); 239 connection->Send(payload, false); 240 } 241 } 242 } 243 244 void SimpleWebServer::DidClose(net::StreamListenSocket* sock) { 245 // To keep the historical list of connections reasonably tidy, we delete 246 // 404's when the connection ends. 247 Connection* c = FindConnection(sock); 248 DCHECK(c); 249 c->OnSocketClosed(); 250 if (!FindResponse(c->request())) { 251 // extremely inefficient, but in one line and not that common... :) 252 connections_.erase(std::find(connections_.begin(), connections_.end(), c)); 253 delete c; 254 } 255 } 256 257 HTTPTestServer::HTTPTestServer(int port, const std::wstring& address, 258 base::FilePath root_dir) 259 : port_(port), address_(address), root_dir_(root_dir) { 260 net::EnsureWinsockInit(); 261 server_ = 262 net::TCPListenSocket::CreateAndListen(WideToUTF8(address), port, this); 263 } 264 265 HTTPTestServer::~HTTPTestServer() { 266 } 267 268 std::list<scoped_refptr<ConfigurableConnection>>::iterator 269 HTTPTestServer::FindConnection(const net::StreamListenSocket* socket) { 270 ConnectionList::iterator it; 271 // Scan through the list searching for the desired socket. Along the way, 272 // erase any connections for which the corresponding socket has already been 273 // forgotten about as a result of all data having been sent. 274 for (it = connection_list_.begin(); it != connection_list_.end(); ) { 275 ConfigurableConnection* connection = it->get(); 276 if (connection->socket_ == NULL) { 277 connection_list_.erase(it++); 278 continue; 279 } 280 if (connection->socket_ == socket) 281 break; 282 ++it; 283 } 284 285 return it; 286 } 287 288 scoped_refptr<ConfigurableConnection> HTTPTestServer::ConnectionFromSocket( 289 const net::StreamListenSocket* socket) { 290 ConnectionList::iterator it = FindConnection(socket); 291 if (it != connection_list_.end()) 292 return *it; 293 return NULL; 294 } 295 296 void HTTPTestServer::DidAccept(net::StreamListenSocket* server, 297 scoped_ptr<net::StreamListenSocket> socket) { 298 connection_list_.push_back(new ConfigurableConnection(socket.Pass())); 299 } 300 301 void HTTPTestServer::DidRead(net::StreamListenSocket* socket, 302 const char* data, 303 int len) { 304 scoped_refptr<ConfigurableConnection> connection = 305 ConnectionFromSocket(socket); 306 if (connection) { 307 std::string str(data, len); 308 connection->r_.OnDataReceived(str); 309 if (connection->r_.AllContentReceived()) { 310 VLOG(1) << __FUNCTION__ << ": " << connection->r_.method() << " " 311 << connection->r_.path(); 312 std::wstring path = UTF8ToWide(connection->r_.path()); 313 if (LowerCaseEqualsASCII(connection->r_.method(), "post")) 314 this->Post(connection, path, connection->r_); 315 else 316 this->Get(connection, path, connection->r_); 317 } 318 } 319 } 320 321 void HTTPTestServer::DidClose(net::StreamListenSocket* socket) { 322 ConnectionList::iterator it = FindConnection(socket); 323 if (it != connection_list_.end()) 324 connection_list_.erase(it); 325 } 326 327 std::wstring HTTPTestServer::Resolve(const std::wstring& path) { 328 // Remove the first '/' if needed. 329 std::wstring stripped_path = path; 330 if (path.size() && path[0] == L'/') 331 stripped_path = path.substr(1); 332 333 if (port_ == 80) { 334 if (stripped_path.empty()) { 335 return base::StringPrintf(L"http://%ls", address_.c_str()); 336 } else { 337 return base::StringPrintf(L"http://%ls/%ls", address_.c_str(), 338 stripped_path.c_str()); 339 } 340 } else { 341 if (stripped_path.empty()) { 342 return base::StringPrintf(L"http://%ls:%d", address_.c_str(), port_); 343 } else { 344 return base::StringPrintf(L"http://%ls:%d/%ls", address_.c_str(), port_, 345 stripped_path.c_str()); 346 } 347 } 348 } 349 350 void ConfigurableConnection::SendChunk() { 351 int size = (int)data_.size(); 352 const char* chunk_ptr = data_.c_str() + cur_pos_; 353 int bytes_to_send = std::min(options_.chunk_size_, size - cur_pos_); 354 355 socket_->Send(chunk_ptr, bytes_to_send); 356 VLOG(1) << "Sent(" << cur_pos_ << "," << bytes_to_send << "): " 357 << base::StringPiece(chunk_ptr, bytes_to_send); 358 359 cur_pos_ += bytes_to_send; 360 if (cur_pos_ < size) { 361 base::MessageLoop::current()->PostDelayedTask( 362 FROM_HERE, base::Bind(&ConfigurableConnection::SendChunk, this), 363 base::TimeDelta::FromMilliseconds(options_.timeout_)); 364 } else { 365 Close(); 366 } 367 } 368 369 void ConfigurableConnection::Close() { 370 socket_.reset(); 371 } 372 373 void ConfigurableConnection::Send(const std::string& headers, 374 const std::string& content) { 375 SendOptions options(SendOptions::IMMEDIATE, 0, 0); 376 SendWithOptions(headers, content, options); 377 } 378 379 void ConfigurableConnection::SendWithOptions(const std::string& headers, 380 const std::string& content, 381 const SendOptions& options) { 382 std::string content_length_header; 383 if (!content.empty() && 384 std::string::npos == headers.find("Context-Length:")) { 385 content_length_header = base::StringPrintf("Content-Length: %u\r\n", 386 content.size()); 387 } 388 389 // Save the options. 390 options_ = options; 391 392 if (options_.speed_ == SendOptions::IMMEDIATE) { 393 socket_->Send(headers); 394 socket_->Send(content_length_header, true); 395 socket_->Send(content); 396 // Post a task to close the socket since StreamListenSocket doesn't like 397 // instances to go away from within its callbacks. 398 base::MessageLoop::current()->PostTask( 399 FROM_HERE, base::Bind(&ConfigurableConnection::Close, this)); 400 401 return; 402 } 403 404 if (options_.speed_ == SendOptions::IMMEDIATE_HEADERS_DELAYED_CONTENT) { 405 socket_->Send(headers); 406 socket_->Send(content_length_header, true); 407 VLOG(1) << "Headers sent: " << headers << content_length_header; 408 data_.append(content); 409 } 410 411 if (options_.speed_ == SendOptions::DELAYED) { 412 data_ = headers; 413 data_.append(content_length_header); 414 data_.append("\r\n"); 415 } 416 417 base::MessageLoop::current()->PostDelayedTask( 418 FROM_HERE, base::Bind(&ConfigurableConnection::SendChunk, this), 419 base::TimeDelta::FromMilliseconds(options.timeout_)); 420 } 421 422 } // namespace test_server 423