Home | History | Annotate | Download | only in cloud
      1 // Copyright (c) 2013 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "chrome/browser/policy/cloud/test_request_interceptor.h"
      6 
      7 #include <limits>
      8 #include <queue>
      9 
     10 #include "base/bind.h"
     11 #include "base/bind_helpers.h"
     12 #include "base/memory/scoped_ptr.h"
     13 #include "base/run_loop.h"
     14 #include "base/sequenced_task_runner.h"
     15 #include "content/test/net/url_request_mock_http_job.h"
     16 #include "net/base/net_errors.h"
     17 #include "net/base/upload_bytes_element_reader.h"
     18 #include "net/base/upload_data_stream.h"
     19 #include "net/base/upload_element_reader.h"
     20 #include "net/url_request/url_request_error_job.h"
     21 #include "net/url_request/url_request_filter.h"
     22 #include "net/url_request/url_request_job_factory.h"
     23 #include "net/url_request/url_request_test_job.h"
     24 #include "url/gurl.h"
     25 
     26 namespace em = enterprise_management;
     27 
     28 namespace policy {
     29 
     30 namespace {
     31 
     32 // Helper callback for jobs that should fail with a network |error|.
     33 net::URLRequestJob* ErrorJobCallback(int error,
     34                                      net::URLRequest* request,
     35                                      net::NetworkDelegate* network_delegate) {
     36   return new net::URLRequestErrorJob(request, network_delegate, error);
     37 }
     38 
     39 // Helper callback for jobs that should fail with a 400 HTTP error.
     40 net::URLRequestJob* BadRequestJobCallback(
     41     net::URLRequest* request,
     42     net::NetworkDelegate* network_delegate) {
     43   static const char kBadHeaders[] =
     44       "HTTP/1.1 400 Bad request\0"
     45       "Content-type: application/protobuf\0"
     46       "\0";
     47   std::string headers(kBadHeaders, arraysize(kBadHeaders));
     48   return new net::URLRequestTestJob(
     49       request, network_delegate, headers, std::string(), true);
     50 }
     51 
     52 net::URLRequestJob* FileJobCallback(const base::FilePath& file_path,
     53                                     net::URLRequest* request,
     54                                     net::NetworkDelegate* network_delegate) {
     55   return new content::URLRequestMockHTTPJob(
     56       request,
     57       network_delegate,
     58       file_path);
     59 }
     60 
     61 // Parses the upload data in |request| into |request_msg|, and validates the
     62 // request. The query string in the URL must contain the |expected_type| for
     63 // the "request" parameter. Returns true if all checks succeeded, and the
     64 // request data has been parsed into |request_msg|.
     65 bool ValidRequest(net::URLRequest* request,
     66                   const std::string& expected_type,
     67                   em::DeviceManagementRequest* request_msg) {
     68   if (request->method() != "POST")
     69     return false;
     70   std::string spec = request->url().spec();
     71   if (spec.find("request=" + expected_type) == std::string::npos)
     72     return false;
     73 
     74   // This assumes that the payload data was set from a single string. In that
     75   // case the UploadDataStream has a single UploadBytesElementReader with the
     76   // data in memory.
     77   const net::UploadDataStream* stream = request->get_upload();
     78   if (!stream)
     79     return false;
     80   const ScopedVector<net::UploadElementReader>& readers =
     81       stream->element_readers();
     82   if (readers.size() != 1u)
     83     return false;
     84   const net::UploadBytesElementReader* reader = readers[0]->AsBytesReader();
     85   if (!reader)
     86     return false;
     87   std::string data(reader->bytes(), reader->length());
     88   if (!request_msg->ParseFromString(data))
     89     return false;
     90 
     91   return true;
     92 }
     93 
     94 // Helper callback for register jobs that should suceed. Validates the request
     95 // parameters and returns an appropriate response job. If |expect_reregister|
     96 // is true then the reregister flag must be set in the DeviceRegisterRequest
     97 // protobuf.
     98 net::URLRequestJob* RegisterJobCallback(
     99     em::DeviceRegisterRequest::Type expected_type,
    100     bool expect_reregister,
    101     net::URLRequest* request,
    102     net::NetworkDelegate* network_delegate) {
    103   em::DeviceManagementRequest request_msg;
    104   if (!ValidRequest(request, "register", &request_msg))
    105     return BadRequestJobCallback(request, network_delegate);
    106 
    107   if (!request_msg.has_register_request() ||
    108       request_msg.has_unregister_request() ||
    109       request_msg.has_policy_request() ||
    110       request_msg.has_device_status_report_request() ||
    111       request_msg.has_session_status_report_request() ||
    112       request_msg.has_auto_enrollment_request()) {
    113     return BadRequestJobCallback(request, network_delegate);
    114   }
    115 
    116   const em::DeviceRegisterRequest& register_request =
    117       request_msg.register_request();
    118   if (expect_reregister &&
    119       (!register_request.has_reregister() || !register_request.reregister())) {
    120     return BadRequestJobCallback(request, network_delegate);
    121   } else if (!expect_reregister &&
    122              register_request.has_reregister() &&
    123              register_request.reregister()) {
    124     return BadRequestJobCallback(request, network_delegate);
    125   }
    126 
    127   if (!register_request.has_type() || register_request.type() != expected_type)
    128     return BadRequestJobCallback(request, network_delegate);
    129 
    130   em::DeviceManagementResponse response;
    131   em::DeviceRegisterResponse* register_response =
    132       response.mutable_register_response();
    133   register_response->set_device_management_token("s3cr3t70k3n");
    134   std::string data;
    135   response.SerializeToString(&data);
    136 
    137   static const char kGoodHeaders[] =
    138       "HTTP/1.1 200 OK\0"
    139       "Content-type: application/protobuf\0"
    140       "\0";
    141   std::string headers(kGoodHeaders, arraysize(kGoodHeaders));
    142   return new net::URLRequestTestJob(
    143       request, network_delegate, headers, data, true);
    144 }
    145 
    146 }  // namespace
    147 
    148 class TestRequestInterceptor::Delegate
    149     : public net::URLRequestJobFactory::ProtocolHandler {
    150  public:
    151   Delegate(const std::string& hostname,
    152            scoped_refptr<base::SequencedTaskRunner> io_task_runner);
    153   virtual ~Delegate();
    154 
    155   // ProtocolHandler implementation:
    156   virtual net::URLRequestJob* MaybeCreateJob(
    157       net::URLRequest* request,
    158       net::NetworkDelegate* network_delegate) const OVERRIDE;
    159 
    160   void GetPendingSize(size_t* pending_size) const;
    161   void PushJobCallback(const JobCallback& callback);
    162 
    163  private:
    164   const std::string hostname_;
    165   scoped_refptr<base::SequencedTaskRunner> io_task_runner_;
    166 
    167   // The queue of pending callbacks. 'mutable' because MaybeCreateJob() is a
    168   // const method; it can't reenter though, because it runs exclusively on
    169   // the IO thread.
    170   mutable std::queue<JobCallback> pending_job_callbacks_;
    171 };
    172 
    173 TestRequestInterceptor::Delegate::Delegate(
    174     const std::string& hostname,
    175     scoped_refptr<base::SequencedTaskRunner> io_task_runner)
    176     : hostname_(hostname), io_task_runner_(io_task_runner) {}
    177 
    178 TestRequestInterceptor::Delegate::~Delegate() {}
    179 
    180 net::URLRequestJob* TestRequestInterceptor::Delegate::MaybeCreateJob(
    181     net::URLRequest* request,
    182     net::NetworkDelegate* network_delegate) const {
    183   CHECK(io_task_runner_->RunsTasksOnCurrentThread());
    184 
    185   if (request->url().host() != hostname_) {
    186     // Reject requests to other servers.
    187     return ErrorJobCallback(
    188         net::ERR_CONNECTION_REFUSED, request, network_delegate);
    189   }
    190 
    191   if (pending_job_callbacks_.empty()) {
    192     // Reject dmserver requests by default.
    193     return BadRequestJobCallback(request, network_delegate);
    194   }
    195 
    196   JobCallback callback = pending_job_callbacks_.front();
    197   pending_job_callbacks_.pop();
    198   return callback.Run(request, network_delegate);
    199 }
    200 
    201 void TestRequestInterceptor::Delegate::GetPendingSize(
    202     size_t* pending_size) const {
    203   CHECK(io_task_runner_->RunsTasksOnCurrentThread());
    204   *pending_size = pending_job_callbacks_.size();
    205 }
    206 
    207 void TestRequestInterceptor::Delegate::PushJobCallback(
    208     const JobCallback& callback) {
    209   CHECK(io_task_runner_->RunsTasksOnCurrentThread());
    210   pending_job_callbacks_.push(callback);
    211 }
    212 
    213 TestRequestInterceptor::TestRequestInterceptor(const std::string& hostname,
    214     scoped_refptr<base::SequencedTaskRunner> io_task_runner)
    215     : hostname_(hostname),
    216       io_task_runner_(io_task_runner) {
    217   delegate_ = new Delegate(hostname_, io_task_runner_);
    218   scoped_ptr<net::URLRequestJobFactory::ProtocolHandler> handler(delegate_);
    219   PostToIOAndWait(
    220       base::Bind(&net::URLRequestFilter::AddHostnameProtocolHandler,
    221                  base::Unretained(net::URLRequestFilter::GetInstance()),
    222                  "http", hostname_, base::Passed(&handler)));
    223 }
    224 
    225 TestRequestInterceptor::~TestRequestInterceptor() {
    226   // RemoveHostnameHandler() destroys the |delegate_|, which is owned by
    227   // the URLRequestFilter.
    228   delegate_ = NULL;
    229   PostToIOAndWait(
    230       base::Bind(&net::URLRequestFilter::RemoveHostnameHandler,
    231                  base::Unretained(net::URLRequestFilter::GetInstance()),
    232                  "http", hostname_));
    233 }
    234 
    235 size_t TestRequestInterceptor::GetPendingSize() {
    236   size_t pending_size = std::numeric_limits<size_t>::max();
    237   PostToIOAndWait(base::Bind(&Delegate::GetPendingSize,
    238                              base::Unretained(delegate_),
    239                              &pending_size));
    240   return pending_size;
    241 }
    242 
    243 void TestRequestInterceptor::PushJobCallback(const JobCallback& callback) {
    244   PostToIOAndWait(base::Bind(&Delegate::PushJobCallback,
    245                              base::Unretained(delegate_),
    246                              callback));
    247 }
    248 
    249 // static
    250 TestRequestInterceptor::JobCallback TestRequestInterceptor::ErrorJob(
    251     int error) {
    252   return base::Bind(&ErrorJobCallback, error);
    253 }
    254 
    255 // static
    256 TestRequestInterceptor::JobCallback TestRequestInterceptor::BadRequestJob() {
    257   return base::Bind(&BadRequestJobCallback);
    258 }
    259 
    260 // static
    261 TestRequestInterceptor::JobCallback TestRequestInterceptor::RegisterJob(
    262     em::DeviceRegisterRequest::Type expected_type,
    263     bool expect_reregister) {
    264   return base::Bind(&RegisterJobCallback, expected_type, expect_reregister);
    265 }
    266 
    267 // static
    268 TestRequestInterceptor::JobCallback TestRequestInterceptor::FileJob(
    269     const base::FilePath& file_path) {
    270   return base::Bind(&FileJobCallback, file_path);
    271 }
    272 
    273 void TestRequestInterceptor::PostToIOAndWait(const base::Closure& task) {
    274   io_task_runner_->PostTask(FROM_HERE, task);
    275   base::RunLoop run_loop;
    276   io_task_runner_->PostTask(
    277       FROM_HERE,
    278       base::Bind(
    279           base::IgnoreResult(&base::MessageLoopProxy::PostTask),
    280           base::MessageLoopProxy::current(),
    281           FROM_HERE,
    282           run_loop.QuitClosure()));
    283   run_loop.Run();
    284 }
    285 
    286 }  // namespace policy
    287