Home | History | Annotate | Download | only in http
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/http/http_response_body_drainer.h"
      6 
      7 #include <cstring>
      8 
      9 #include "base/bind.h"
     10 #include "base/compiler_specific.h"
     11 #include "base/memory/weak_ptr.h"
     12 #include "base/message_loop/message_loop.h"
     13 #include "net/base/io_buffer.h"
     14 #include "net/base/net_errors.h"
     15 #include "net/base/test_completion_callback.h"
     16 #include "net/http/http_network_session.h"
     17 #include "net/http/http_server_properties_impl.h"
     18 #include "net/http/http_stream.h"
     19 #include "net/http/transport_security_state.h"
     20 #include "net/proxy/proxy_service.h"
     21 #include "net/ssl/ssl_config_service_defaults.h"
     22 #include "testing/gtest/include/gtest/gtest.h"
     23 
     24 namespace net {
     25 
     26 namespace {
     27 
     28 const int kMagicChunkSize = 1024;
     29 COMPILE_ASSERT(
     30     (HttpResponseBodyDrainer::kDrainBodyBufferSize % kMagicChunkSize) == 0,
     31     chunk_size_needs_to_divide_evenly_into_buffer_size);
     32 
     33 class CloseResultWaiter {
     34  public:
     35   CloseResultWaiter()
     36       : result_(false),
     37         have_result_(false),
     38         waiting_for_result_(false) {}
     39 
     40   int WaitForResult() {
     41     CHECK(!waiting_for_result_);
     42     while (!have_result_) {
     43       waiting_for_result_ = true;
     44       base::MessageLoop::current()->Run();
     45       waiting_for_result_ = false;
     46     }
     47     return result_;
     48   }
     49 
     50   void set_result(bool result) {
     51     result_ = result;
     52     have_result_ = true;
     53     if (waiting_for_result_)
     54       base::MessageLoop::current()->Quit();
     55   }
     56 
     57  private:
     58   int result_;
     59   bool have_result_;
     60   bool waiting_for_result_;
     61 
     62   DISALLOW_COPY_AND_ASSIGN(CloseResultWaiter);
     63 };
     64 
     65 class MockHttpStream : public HttpStream {
     66  public:
     67   MockHttpStream(CloseResultWaiter* result_waiter)
     68       : result_waiter_(result_waiter),
     69         buf_len_(0),
     70         closed_(false),
     71         stall_reads_forever_(false),
     72         num_chunks_(0),
     73         is_sync_(false),
     74         is_last_chunk_zero_size_(false),
     75         is_complete_(false),
     76         weak_factory_(this) {}
     77   virtual ~MockHttpStream() {}
     78 
     79   // HttpStream implementation.
     80   virtual int InitializeStream(const HttpRequestInfo* request_info,
     81                                RequestPriority priority,
     82                                const BoundNetLog& net_log,
     83                                const CompletionCallback& callback) OVERRIDE {
     84     return ERR_UNEXPECTED;
     85   }
     86   virtual int SendRequest(const HttpRequestHeaders& request_headers,
     87                           HttpResponseInfo* response,
     88                           const CompletionCallback& callback) OVERRIDE {
     89     return ERR_UNEXPECTED;
     90   }
     91   virtual UploadProgress GetUploadProgress() const OVERRIDE {
     92     return UploadProgress();
     93   }
     94   virtual int ReadResponseHeaders(const CompletionCallback& callback) OVERRIDE {
     95     return ERR_UNEXPECTED;
     96   }
     97 
     98   virtual bool CanFindEndOfResponse() const OVERRIDE { return true; }
     99   virtual bool IsConnectionReused() const OVERRIDE { return false; }
    100   virtual void SetConnectionReused() OVERRIDE {}
    101   virtual bool IsConnectionReusable() const OVERRIDE { return false; }
    102   virtual int64 GetTotalReceivedBytes() const OVERRIDE { return 0; }
    103   virtual void GetSSLInfo(SSLInfo* ssl_info) OVERRIDE {}
    104   virtual void GetSSLCertRequestInfo(
    105       SSLCertRequestInfo* cert_request_info) OVERRIDE {}
    106 
    107   // Mocked API
    108   virtual int ReadResponseBody(IOBuffer* buf, int buf_len,
    109                                const CompletionCallback& callback) OVERRIDE;
    110   virtual void Close(bool not_reusable) OVERRIDE {
    111     CHECK(!closed_);
    112     closed_ = true;
    113     result_waiter_->set_result(not_reusable);
    114   }
    115 
    116   virtual HttpStream* RenewStreamForAuth() OVERRIDE {
    117     return NULL;
    118   }
    119 
    120   virtual bool IsResponseBodyComplete() const OVERRIDE { return is_complete_; }
    121 
    122   virtual bool IsSpdyHttpStream() const OVERRIDE { return false; }
    123 
    124   virtual bool GetLoadTimingInfo(
    125       LoadTimingInfo* load_timing_info) const OVERRIDE { return false; }
    126 
    127   virtual void Drain(HttpNetworkSession*) OVERRIDE {}
    128 
    129   virtual void SetPriority(RequestPriority priority) OVERRIDE {}
    130 
    131   // Methods to tweak/observer mock behavior:
    132   void set_stall_reads_forever() { stall_reads_forever_ = true; }
    133 
    134   void set_num_chunks(int num_chunks) { num_chunks_ = num_chunks; }
    135 
    136   void set_sync() { is_sync_ = true; }
    137 
    138   void set_is_last_chunk_zero_size() { is_last_chunk_zero_size_ = true; }
    139 
    140  private:
    141   int ReadResponseBodyImpl(IOBuffer* buf, int buf_len);
    142   void CompleteRead();
    143 
    144   bool closed() const { return closed_; }
    145 
    146   CloseResultWaiter* const result_waiter_;
    147   scoped_refptr<IOBuffer> user_buf_;
    148   CompletionCallback callback_;
    149   int buf_len_;
    150   bool closed_;
    151   bool stall_reads_forever_;
    152   int num_chunks_;
    153   bool is_sync_;
    154   bool is_last_chunk_zero_size_;
    155   bool is_complete_;
    156   base::WeakPtrFactory<MockHttpStream> weak_factory_;
    157 };
    158 
    159 int MockHttpStream::ReadResponseBody(IOBuffer* buf,
    160                                      int buf_len,
    161                                      const CompletionCallback& callback) {
    162   CHECK(!callback.is_null());
    163   CHECK(callback_.is_null());
    164   CHECK(buf);
    165 
    166   if (stall_reads_forever_)
    167     return ERR_IO_PENDING;
    168 
    169   if (is_complete_)
    170     return ERR_UNEXPECTED;
    171 
    172   if (!is_sync_) {
    173     user_buf_ = buf;
    174     buf_len_ = buf_len;
    175     callback_ = callback;
    176     base::MessageLoop::current()->PostTask(
    177         FROM_HERE,
    178         base::Bind(&MockHttpStream::CompleteRead, weak_factory_.GetWeakPtr()));
    179     return ERR_IO_PENDING;
    180   } else {
    181     return ReadResponseBodyImpl(buf, buf_len);
    182   }
    183 }
    184 
    185 int MockHttpStream::ReadResponseBodyImpl(IOBuffer* buf, int buf_len) {
    186   if (is_last_chunk_zero_size_ && num_chunks_ == 1) {
    187     buf_len = 0;
    188   } else {
    189     if (buf_len > kMagicChunkSize)
    190       buf_len = kMagicChunkSize;
    191     std::memset(buf->data(), 1, buf_len);
    192   }
    193   num_chunks_--;
    194   if (!num_chunks_)
    195     is_complete_ = true;
    196 
    197   return buf_len;
    198 }
    199 
    200 void MockHttpStream::CompleteRead() {
    201   int result = ReadResponseBodyImpl(user_buf_.get(), buf_len_);
    202   user_buf_ = NULL;
    203   CompletionCallback callback = callback_;
    204   callback_.Reset();
    205   callback.Run(result);
    206 }
    207 
    208 class HttpResponseBodyDrainerTest : public testing::Test {
    209  protected:
    210   HttpResponseBodyDrainerTest()
    211       : proxy_service_(ProxyService::CreateDirect()),
    212         ssl_config_service_(new SSLConfigServiceDefaults),
    213         http_server_properties_(new HttpServerPropertiesImpl()),
    214         transport_security_state_(new TransportSecurityState()),
    215         session_(CreateNetworkSession()),
    216         mock_stream_(new MockHttpStream(&result_waiter_)),
    217         drainer_(new HttpResponseBodyDrainer(mock_stream_)) {}
    218 
    219   virtual ~HttpResponseBodyDrainerTest() {}
    220 
    221   HttpNetworkSession* CreateNetworkSession() const {
    222     HttpNetworkSession::Params params;
    223     params.proxy_service = proxy_service_.get();
    224     params.ssl_config_service = ssl_config_service_.get();
    225     params.http_server_properties = http_server_properties_->GetWeakPtr();
    226     params.transport_security_state = transport_security_state_.get();
    227     return new HttpNetworkSession(params);
    228   }
    229 
    230   scoped_ptr<ProxyService> proxy_service_;
    231   scoped_refptr<SSLConfigService> ssl_config_service_;
    232   scoped_ptr<HttpServerPropertiesImpl> http_server_properties_;
    233   scoped_ptr<TransportSecurityState> transport_security_state_;
    234   const scoped_refptr<HttpNetworkSession> session_;
    235   CloseResultWaiter result_waiter_;
    236   MockHttpStream* const mock_stream_;  // Owned by |drainer_|.
    237   HttpResponseBodyDrainer* const drainer_;  // Deletes itself.
    238 };
    239 
    240 TEST_F(HttpResponseBodyDrainerTest, DrainBodySyncSingleOK) {
    241   mock_stream_->set_num_chunks(1);
    242   mock_stream_->set_sync();
    243   drainer_->Start(session_.get());
    244   EXPECT_FALSE(result_waiter_.WaitForResult());
    245 }
    246 
    247 TEST_F(HttpResponseBodyDrainerTest, DrainBodySyncOK) {
    248   mock_stream_->set_num_chunks(3);
    249   mock_stream_->set_sync();
    250   drainer_->Start(session_.get());
    251   EXPECT_FALSE(result_waiter_.WaitForResult());
    252 }
    253 
    254 TEST_F(HttpResponseBodyDrainerTest, DrainBodyAsyncOK) {
    255   mock_stream_->set_num_chunks(3);
    256   drainer_->Start(session_.get());
    257   EXPECT_FALSE(result_waiter_.WaitForResult());
    258 }
    259 
    260 // Test the case when the final chunk is 0 bytes. This can happen when
    261 // the final 0-byte chunk of a chunk-encoded http response is read in a last
    262 // call to ReadResponseBody, after all data were returned from HttpStream.
    263 TEST_F(HttpResponseBodyDrainerTest, DrainBodyAsyncEmptyChunk) {
    264   mock_stream_->set_num_chunks(4);
    265   mock_stream_->set_is_last_chunk_zero_size();
    266   drainer_->Start(session_.get());
    267   EXPECT_FALSE(result_waiter_.WaitForResult());
    268 }
    269 
    270 TEST_F(HttpResponseBodyDrainerTest, DrainBodySyncEmptyChunk) {
    271   mock_stream_->set_num_chunks(4);
    272   mock_stream_->set_sync();
    273   mock_stream_->set_is_last_chunk_zero_size();
    274   drainer_->Start(session_.get());
    275   EXPECT_FALSE(result_waiter_.WaitForResult());
    276 }
    277 
    278 TEST_F(HttpResponseBodyDrainerTest, DrainBodySizeEqualsDrainBuffer) {
    279   mock_stream_->set_num_chunks(
    280       HttpResponseBodyDrainer::kDrainBodyBufferSize / kMagicChunkSize);
    281   drainer_->Start(session_.get());
    282   EXPECT_FALSE(result_waiter_.WaitForResult());
    283 }
    284 
    285 TEST_F(HttpResponseBodyDrainerTest, DrainBodyTimeOut) {
    286   mock_stream_->set_num_chunks(2);
    287   mock_stream_->set_stall_reads_forever();
    288   drainer_->Start(session_.get());
    289   EXPECT_TRUE(result_waiter_.WaitForResult());
    290 }
    291 
    292 TEST_F(HttpResponseBodyDrainerTest, CancelledBySession) {
    293   mock_stream_->set_num_chunks(2);
    294   mock_stream_->set_stall_reads_forever();
    295   drainer_->Start(session_.get());
    296   // HttpNetworkSession should delete |drainer_|.
    297 }
    298 
    299 TEST_F(HttpResponseBodyDrainerTest, DrainBodyTooLarge) {
    300   int too_many_chunks =
    301       HttpResponseBodyDrainer::kDrainBodyBufferSize / kMagicChunkSize;
    302   too_many_chunks += 1;  // Now it's too large.
    303 
    304   mock_stream_->set_num_chunks(too_many_chunks);
    305   drainer_->Start(session_.get());
    306   EXPECT_TRUE(result_waiter_.WaitForResult());
    307 }
    308 
    309 }  // namespace
    310 
    311 }  // namespace net
    312