Home | History | Annotate | Download | only in shill
      1 //
      2 // Copyright (C) 2012 The Android Open Source Project
      3 //
      4 // Licensed under the Apache License, Version 2.0 (the "License");
      5 // you may not use this file except in compliance with the License.
      6 // You may obtain a copy of the License at
      7 //
      8 //      http://www.apache.org/licenses/LICENSE-2.0
      9 //
     10 // Unless required by applicable law or agreed to in writing, software
     11 // distributed under the License is distributed on an "AS IS" BASIS,
     12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13 // See the License for the specific language governing permissions and
     14 // limitations under the License.
     15 //
     16 
     17 #include "shill/http_request.h"
     18 
     19 #include <netinet/in.h>
     20 
     21 #include <memory>
     22 #include <string>
     23 #include <vector>
     24 
     25 #include <base/bind.h>
     26 #include <base/strings/stringprintf.h>
     27 #include <gtest/gtest.h>
     28 
     29 #include "shill/http_url.h"
     30 #include "shill/mock_async_connection.h"
     31 #include "shill/mock_connection.h"
     32 #include "shill/mock_control.h"
     33 #include "shill/mock_device_info.h"
     34 #include "shill/mock_dns_client.h"
     35 #include "shill/mock_event_dispatcher.h"
     36 #include "shill/net/ip_address.h"
     37 #include "shill/net/mock_sockets.h"
     38 
     39 using base::Bind;
     40 using base::Callback;
     41 using base::StringPrintf;
     42 using base::Unretained;
     43 using std::string;
     44 using std::vector;
     45 using ::testing::_;
     46 using ::testing::AtLeast;
     47 using ::testing::DoAll;
     48 using ::testing::Invoke;
     49 using ::testing::NiceMock;
     50 using ::testing::Return;
     51 using ::testing::ReturnArg;
     52 using ::testing::ReturnNew;
     53 using ::testing::ReturnRef;
     54 using ::testing::SetArgumentPointee;
     55 using ::testing::StrEq;
     56 using ::testing::StrictMock;
     57 using ::testing::Test;
     58 
     59 namespace shill {
     60 
     61 namespace {
     62 const char kTextSiteName[] = "www.chromium.org";
     63 const char kTextURL[] = "http://www.chromium.org/path/to/resource";
     64 const char kNumericURL[] = "http://10.1.1.1";
     65 const char kPath[] = "/path/to/resource";
     66 const char kInterfaceName[] = "int0";
     67 const char kDNSServer0[] = "8.8.8.8";
     68 const char kDNSServer1[] = "8.8.4.4";
     69 const char* kDNSServers[] = { kDNSServer0, kDNSServer1 };
     70 const char kServerAddress[] = "10.1.1.1";
     71 const int kServerFD = 10203;
     72 const int kServerPort = 80;
     73 }  // namespace
     74 
     75 MATCHER_P(IsIPAddress, address, "") {
     76   IPAddress ip_address(IPAddress::kFamilyIPv4);
     77   EXPECT_TRUE(ip_address.SetAddressFromString(address));
     78   return ip_address.Equals(arg);
     79 }
     80 
     81 MATCHER_P(ByteStringMatches, byte_string, "") {
     82   return byte_string.Equals(arg);
     83 }
     84 
     85 MATCHER_P(CallbackEq, callback, "") {
     86   return arg.Equals(callback);
     87 }
     88 
     89 class HTTPRequestTest : public Test {
     90  public:
     91   HTTPRequestTest()
     92       : interface_name_(kInterfaceName),
     93         server_async_connection_(new StrictMock<MockAsyncConnection>()),
     94         dns_servers_(kDNSServers, kDNSServers + 2),
     95         dns_client_(new StrictMock<MockDNSClient>()),
     96         device_info_(
     97             new NiceMock<MockDeviceInfo>(&control_, nullptr, nullptr, nullptr)),
     98         connection_(new StrictMock<MockConnection>(device_info_.get())) {}
     99 
    100  protected:
    101   class CallbackTarget {
    102    public:
    103     CallbackTarget()
    104         : read_event_callback_(
    105               Bind(&CallbackTarget::ReadEventCallTarget, Unretained(this))),
    106           result_callback_(
    107               Bind(&CallbackTarget::ResultCallTarget, Unretained(this))) {}
    108 
    109     MOCK_METHOD1(ReadEventCallTarget, void(const ByteString& response_data));
    110     MOCK_METHOD2(ResultCallTarget, void(HTTPRequest::Result result,
    111                                         const ByteString& response_data));
    112     const Callback<void(const ByteString&)>& read_event_callback() {
    113       return read_event_callback_;
    114     }
    115     const Callback<void(HTTPRequest::Result,
    116                         const ByteString&)>& result_callback() {
    117       return result_callback_;
    118     }
    119 
    120    private:
    121     Callback<void(const ByteString&)> read_event_callback_;
    122     Callback<void(HTTPRequest::Result, const ByteString&)> result_callback_;
    123   };
    124 
    125   virtual void SetUp() {
    126     EXPECT_CALL(*connection_.get(), IsIPv6())
    127         .WillRepeatedly(Return(false));
    128     EXPECT_CALL(*connection_.get(), interface_name())
    129         .WillRepeatedly(ReturnRef(interface_name_));
    130     EXPECT_CALL(*connection_.get(), dns_servers())
    131         .WillRepeatedly(ReturnRef(dns_servers_));
    132 
    133     request_.reset(new HTTPRequest(connection_, &dispatcher_, &sockets_));
    134     // Passes ownership.
    135     request_->dns_client_.reset(dns_client_);
    136     // Passes ownership.
    137     request_->server_async_connection_.reset(server_async_connection_);
    138   }
    139   virtual void TearDown() {
    140     if (request_->is_running_) {
    141       ExpectStop();
    142 
    143       // Subtle: Make sure the finalization of the request happens while our
    144       // expectations are still active.
    145       request_.reset();
    146     }
    147   }
    148   size_t FindInRequestData(const string& find_string) {
    149     string request_string(
    150         reinterpret_cast<char*>(request_->request_data_.GetData()),
    151         request_->request_data_.GetLength());
    152     return request_string.find(find_string);
    153   }
    154   // Accessors
    155   const ByteString& GetRequestData() {
    156     return request_->request_data_;
    157   }
    158   HTTPRequest* request() { return request_.get(); }
    159   MockSockets& sockets() { return sockets_; }
    160 
    161   // Expectations
    162   void ExpectReset() {
    163     EXPECT_EQ(connection_.get(), request_->connection_.get());
    164     EXPECT_EQ(&dispatcher_, request_->dispatcher_);
    165     EXPECT_EQ(&sockets_, request_->sockets_);
    166     EXPECT_TRUE(request_->result_callback_.is_null());
    167     EXPECT_TRUE(request_->read_event_callback_.is_null());
    168     EXPECT_FALSE(request_->connect_completion_callback_.is_null());
    169     EXPECT_FALSE(request_->dns_client_callback_.is_null());
    170     EXPECT_FALSE(request_->read_server_callback_.is_null());
    171     EXPECT_FALSE(request_->write_server_callback_.is_null());
    172     EXPECT_FALSE(request_->read_server_handler_.get());
    173     EXPECT_FALSE(request_->write_server_handler_.get());
    174     EXPECT_EQ(dns_client_, request_->dns_client_.get());
    175     EXPECT_EQ(server_async_connection_,
    176               request_->server_async_connection_.get());
    177     EXPECT_TRUE(request_->server_hostname_.empty());
    178     EXPECT_EQ(-1, request_->server_port_);
    179     EXPECT_EQ(-1, request_->server_socket_);
    180     EXPECT_EQ(HTTPRequest::kResultUnknown, request_->timeout_result_);
    181     EXPECT_TRUE(request_->request_data_.IsEmpty());
    182     EXPECT_TRUE(request_->response_data_.IsEmpty());
    183     EXPECT_FALSE(request_->is_running_);
    184   }
    185   void ExpectStop() {
    186     if (request_->server_socket_ != -1) {
    187       EXPECT_CALL(sockets(), Close(kServerFD))
    188           .WillOnce(Return(0));
    189     }
    190     EXPECT_CALL(*dns_client_, Stop())
    191         .Times(AtLeast(1));
    192     EXPECT_CALL(*server_async_connection_, Stop())
    193         .Times(AtLeast(1));
    194     EXPECT_CALL(*connection_.get(), ReleaseRouting());
    195   }
    196   void ExpectSetTimeout(int timeout) {
    197     EXPECT_CALL(dispatcher_, PostDelayedTask(_, timeout * 1000));
    198   }
    199   void ExpectSetConnectTimeout() {
    200     ExpectSetTimeout(HTTPRequest::kConnectTimeoutSeconds);
    201   }
    202   void ExpectSetInputTimeout() {
    203     ExpectSetTimeout(HTTPRequest::kInputTimeoutSeconds);
    204   }
    205   void ExpectInResponse(const string& expected_response_data) {
    206     string response_string(
    207         reinterpret_cast<char*>(request_->response_data_.GetData()),
    208         request_->response_data_.GetLength());
    209     EXPECT_NE(string::npos, response_string.find(expected_response_data));
    210   }
    211   void ExpectDNSRequest(const string& host, bool return_value) {
    212     EXPECT_CALL(*dns_client_, Start(StrEq(host), _))
    213         .WillOnce(Return(return_value));
    214   }
    215   void ExpectAsyncConnect(const string& address, int port,
    216                           bool return_value) {
    217     EXPECT_CALL(*server_async_connection_, Start(IsIPAddress(address), port))
    218         .WillOnce(Return(return_value));
    219     if (return_value) {
    220       ExpectSetConnectTimeout();
    221     }
    222   }
    223   void  InvokeSyncConnect(const IPAddress& /*address*/, int /*port*/) {
    224     CallConnectCompletion(true, kServerFD);
    225   }
    226   void CallConnectCompletion(bool success, int fd) {
    227     request_->OnConnectCompletion(success, fd);
    228   }
    229   void ExpectSyncConnect(const string& address, int port) {
    230     EXPECT_CALL(*server_async_connection_, Start(IsIPAddress(address), port))
    231         .WillOnce(DoAll(Invoke(this, &HTTPRequestTest::InvokeSyncConnect),
    232                         Return(true)));
    233   }
    234   void ExpectConnectFailure() {
    235     EXPECT_CALL(*server_async_connection_, Start(_, _))
    236         .WillOnce(Return(false));
    237   }
    238   void ExpectMonitorServerInput() {
    239     EXPECT_CALL(dispatcher_,
    240                 CreateInputHandler(kServerFD,
    241                                    CallbackEq(request_->read_server_callback_),
    242                                    _))
    243         .WillOnce(ReturnNew<IOHandler>());
    244     ExpectSetInputTimeout();
    245   }
    246   void ExpectMonitorServerOutput() {
    247     EXPECT_CALL(dispatcher_,
    248                 CreateReadyHandler(
    249                     kServerFD, IOHandler::kModeOutput,
    250                     CallbackEq(request_->write_server_callback_)))
    251         .WillOnce(ReturnNew<IOHandler>());
    252     ExpectSetInputTimeout();
    253   }
    254   void ExpectRouteRequest() {
    255     EXPECT_CALL(*connection_.get(), RequestRouting());
    256   }
    257   void ExpectRouteRelease() {
    258     EXPECT_CALL(*connection_.get(), ReleaseRouting());
    259   }
    260   void ExpectResultCallback(HTTPRequest::Result result) {
    261     EXPECT_CALL(target_, ResultCallTarget(result, _));
    262   }
    263   void InvokeResultVerify(HTTPRequest::Result result,
    264                           const ByteString& response_data) {
    265     EXPECT_EQ(HTTPRequest::kResultSuccess, result);
    266     EXPECT_TRUE(expected_response_.Equals(response_data));
    267   }
    268   void ExpectResultCallbackWithResponse(const string& response) {
    269     expected_response_ = ByteString(response, false);
    270     EXPECT_CALL(target_, ResultCallTarget(HTTPRequest::kResultSuccess, _))
    271         .WillOnce(Invoke(this, &HTTPRequestTest::InvokeResultVerify));
    272   }
    273   void ExpectReadEventCallback(const string& response) {
    274     ByteString response_data(response, false);
    275     EXPECT_CALL(target_, ReadEventCallTarget(ByteStringMatches(response_data)));
    276   }
    277   void GetDNSResultFailure(const string& error_msg) {
    278     Error error(Error::kOperationFailed, error_msg);
    279     IPAddress address(IPAddress::kFamilyUnknown);
    280     request_->GetDNSResult(error, address);
    281   }
    282   void GetDNSResultSuccess(const IPAddress& address) {
    283     Error error;
    284     request_->GetDNSResult(error, address);
    285   }
    286   void OnConnectCompletion(bool result, int sockfd) {
    287     request_->OnConnectCompletion(result, sockfd);
    288   }
    289   void ReadFromServer(const string& data) {
    290     const unsigned char* ptr =
    291         reinterpret_cast<const unsigned char*>(data.c_str());
    292     vector<unsigned char> data_writable(ptr, ptr + data.length());
    293     InputData server_data(data_writable.data(), data_writable.size());
    294     request_->ReadFromServer(&server_data);
    295   }
    296   void WriteToServer(int fd) {
    297     request_->WriteToServer(fd);
    298   }
    299   HTTPRequest::Result StartRequest(const string& url) {
    300     HTTPURL http_url;
    301     EXPECT_TRUE(http_url.ParseFromString(url));
    302     return request_->Start(http_url,
    303                            target_.read_event_callback(),
    304                            target_.result_callback());
    305   }
    306   void SetupConnectWithURL(const string& url, const string& expected_hostname) {
    307     ExpectRouteRequest();
    308     ExpectDNSRequest(expected_hostname, true);
    309     EXPECT_EQ(HTTPRequest::kResultInProgress, StartRequest(url));
    310     IPAddress addr(IPAddress::kFamilyIPv4);
    311     EXPECT_TRUE(addr.SetAddressFromString(kServerAddress));
    312     GetDNSResultSuccess(addr);
    313   }
    314   void SetupConnect() {
    315     SetupConnectWithURL(kTextURL, kTextSiteName);
    316   }
    317   void SetupConnectAsync() {
    318     ExpectAsyncConnect(kServerAddress, kServerPort, true);
    319     SetupConnect();
    320   }
    321   void SetupConnectComplete() {
    322     SetupConnectAsync();
    323     ExpectMonitorServerOutput();
    324     OnConnectCompletion(true, kServerFD);
    325   }
    326   void CallTimeoutTask() {
    327     request_->TimeoutTask();
    328   }
    329   void CallServerErrorCallback() {
    330     request_->OnServerReadError(string());
    331   }
    332 
    333  private:
    334   const string interface_name_;
    335   // Owned by the HTTPRequest, but tracked here for EXPECT().
    336   StrictMock<MockAsyncConnection>* server_async_connection_;
    337   vector<string> dns_servers_;
    338   // Owned by the HTTPRequest, but tracked here for EXPECT().
    339   StrictMock<MockDNSClient>* dns_client_;
    340   StrictMock<MockEventDispatcher> dispatcher_;
    341   MockControl control_;
    342   std::unique_ptr<MockDeviceInfo> device_info_;
    343   scoped_refptr<MockConnection> connection_;
    344   std::unique_ptr<HTTPRequest> request_;
    345   StrictMock<MockSockets> sockets_;
    346   StrictMock<CallbackTarget> target_;
    347   ByteString expected_response_;
    348 };
    349 
    350 TEST_F(HTTPRequestTest, Constructor) {
    351   ExpectReset();
    352 }
    353 
    354 
    355 TEST_F(HTTPRequestTest, FailConnectNumericSynchronous) {
    356   ExpectRouteRequest();
    357   ExpectConnectFailure();
    358   ExpectStop();
    359   EXPECT_EQ(HTTPRequest::kResultConnectionFailure, StartRequest(kNumericURL));
    360   ExpectReset();
    361 }
    362 
    363 TEST_F(HTTPRequestTest, FailConnectNumericAsynchronous) {
    364   ExpectRouteRequest();
    365   ExpectAsyncConnect(kServerAddress, HTTPURL::kDefaultHTTPPort, true);
    366   EXPECT_EQ(HTTPRequest::kResultInProgress, StartRequest(kNumericURL));
    367   ExpectResultCallback(HTTPRequest::kResultConnectionFailure);
    368   ExpectStop();
    369   CallConnectCompletion(false, -1);
    370   ExpectReset();
    371 }
    372 
    373 TEST_F(HTTPRequestTest, FailConnectNumericTimeout) {
    374   ExpectRouteRequest();
    375   ExpectAsyncConnect(kServerAddress, HTTPURL::kDefaultHTTPPort, true);
    376   EXPECT_EQ(HTTPRequest::kResultInProgress, StartRequest(kNumericURL));
    377   ExpectResultCallback(HTTPRequest::kResultConnectionTimeout);
    378   ExpectStop();
    379   CallTimeoutTask();
    380   ExpectReset();
    381 }
    382 
    383 TEST_F(HTTPRequestTest, SyncConnectNumeric) {
    384   ExpectRouteRequest();
    385   ExpectSyncConnect(kServerAddress, HTTPURL::kDefaultHTTPPort);
    386   ExpectMonitorServerOutput();
    387   EXPECT_EQ(HTTPRequest::kResultInProgress, StartRequest(kNumericURL));
    388 }
    389 
    390 TEST_F(HTTPRequestTest, FailDNSStart) {
    391   ExpectRouteRequest();
    392   ExpectDNSRequest(kTextSiteName, false);
    393   ExpectStop();
    394   EXPECT_EQ(HTTPRequest::kResultDNSFailure, StartRequest(kTextURL));
    395   ExpectReset();
    396 }
    397 
    398 TEST_F(HTTPRequestTest, FailDNSFailure) {
    399   ExpectRouteRequest();
    400   ExpectDNSRequest(kTextSiteName, true);
    401   EXPECT_EQ(HTTPRequest::kResultInProgress, StartRequest(kTextURL));
    402   ExpectResultCallback(HTTPRequest::kResultDNSFailure);
    403   ExpectStop();
    404   GetDNSResultFailure(DNSClient::kErrorNoData);
    405   ExpectReset();
    406 }
    407 
    408 TEST_F(HTTPRequestTest, FailDNSTimeout) {
    409   ExpectRouteRequest();
    410   ExpectDNSRequest(kTextSiteName, true);
    411   EXPECT_EQ(HTTPRequest::kResultInProgress, StartRequest(kTextURL));
    412   ExpectResultCallback(HTTPRequest::kResultDNSTimeout);
    413   ExpectStop();
    414   const string error(DNSClient::kErrorTimedOut);
    415   GetDNSResultFailure(error);
    416   ExpectReset();
    417 }
    418 
    419 TEST_F(HTTPRequestTest, FailConnectText) {
    420   ExpectConnectFailure();
    421   ExpectResultCallback(HTTPRequest::kResultConnectionFailure);
    422   ExpectStop();
    423   SetupConnect();
    424   ExpectReset();
    425 }
    426 
    427 TEST_F(HTTPRequestTest, ConnectComplete) {
    428   SetupConnectComplete();
    429 }
    430 
    431 TEST_F(HTTPRequestTest, RequestTimeout) {
    432   SetupConnectComplete();
    433   ExpectResultCallback(HTTPRequest::kResultRequestTimeout);
    434   ExpectStop();
    435   CallTimeoutTask();
    436 }
    437 
    438 TEST_F(HTTPRequestTest, RequestData) {
    439   SetupConnectComplete();
    440   EXPECT_EQ(0, FindInRequestData(string("GET ") + kPath));
    441   EXPECT_NE(string::npos,
    442             FindInRequestData(string("\r\nHost: ") + kTextSiteName));
    443   ByteString request_data = GetRequestData();
    444   EXPECT_CALL(sockets(), Send(kServerFD, _, request_data.GetLength(), 0))
    445       .WillOnce(Return(request_data.GetLength() - 1));
    446   ExpectSetInputTimeout();
    447   WriteToServer(kServerFD);
    448   EXPECT_CALL(sockets(), Send(kServerFD, _, 1, 0))
    449       .WillOnce(Return(1));
    450   ExpectMonitorServerInput();
    451   WriteToServer(kServerFD);
    452 }
    453 
    454 TEST_F(HTTPRequestTest, ResponseTimeout) {
    455   SetupConnectComplete();
    456   ByteString request_data = GetRequestData();
    457   EXPECT_CALL(sockets(), Send(kServerFD, _, request_data.GetLength(), 0))
    458       .WillOnce(Return(request_data.GetLength()));
    459   ExpectMonitorServerInput();
    460   WriteToServer(kServerFD);
    461   ExpectResultCallback(HTTPRequest::kResultResponseTimeout);
    462   ExpectStop();
    463   CallTimeoutTask();
    464 }
    465 
    466 TEST_F(HTTPRequestTest, ResponseInputError) {
    467   SetupConnectComplete();
    468   ByteString request_data = GetRequestData();
    469   EXPECT_CALL(sockets(), Send(kServerFD, _, request_data.GetLength(), 0))
    470       .WillOnce(Return(request_data.GetLength()));
    471   ExpectMonitorServerInput();
    472   WriteToServer(kServerFD);
    473   ExpectResultCallback(HTTPRequest::kResultResponseFailure);
    474   ExpectStop();
    475   CallServerErrorCallback();
    476 }
    477 
    478 TEST_F(HTTPRequestTest, ResponseData) {
    479   SetupConnectComplete();
    480   const string response0("hello");
    481   ExpectReadEventCallback(response0);
    482   ExpectSetInputTimeout();
    483   ReadFromServer(response0);
    484   ExpectInResponse(response0);
    485 
    486   const string response1(" to you");
    487   ExpectReadEventCallback(response0 + response1);
    488   ExpectSetInputTimeout();
    489   ReadFromServer(response1);
    490   ExpectInResponse(response1);
    491 
    492   ExpectResultCallbackWithResponse(response0 + response1);
    493   ExpectStop();
    494   ReadFromServer("");
    495   ExpectReset();
    496 }
    497 
    498 }  // namespace shill
    499