1 // Copyright (c) 2013 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "chrome/browser/policy/cloud/test_request_interceptor.h" 6 7 #include <limits> 8 #include <queue> 9 10 #include "base/bind.h" 11 #include "base/bind_helpers.h" 12 #include "base/memory/scoped_ptr.h" 13 #include "base/run_loop.h" 14 #include "base/sequenced_task_runner.h" 15 #include "content/test/net/url_request_mock_http_job.h" 16 #include "net/base/net_errors.h" 17 #include "net/base/upload_bytes_element_reader.h" 18 #include "net/base/upload_data_stream.h" 19 #include "net/base/upload_element_reader.h" 20 #include "net/url_request/url_request_error_job.h" 21 #include "net/url_request/url_request_filter.h" 22 #include "net/url_request/url_request_job_factory.h" 23 #include "net/url_request/url_request_test_job.h" 24 #include "url/gurl.h" 25 26 namespace em = enterprise_management; 27 28 namespace policy { 29 30 namespace { 31 32 // Helper callback for jobs that should fail with a network |error|. 33 net::URLRequestJob* ErrorJobCallback(int error, 34 net::URLRequest* request, 35 net::NetworkDelegate* network_delegate) { 36 return new net::URLRequestErrorJob(request, network_delegate, error); 37 } 38 39 // Helper callback for jobs that should fail with a 400 HTTP error. 40 net::URLRequestJob* BadRequestJobCallback( 41 net::URLRequest* request, 42 net::NetworkDelegate* network_delegate) { 43 static const char kBadHeaders[] = 44 "HTTP/1.1 400 Bad request\0" 45 "Content-type: application/protobuf\0" 46 "\0"; 47 std::string headers(kBadHeaders, arraysize(kBadHeaders)); 48 return new net::URLRequestTestJob( 49 request, network_delegate, headers, std::string(), true); 50 } 51 52 net::URLRequestJob* FileJobCallback(const base::FilePath& file_path, 53 net::URLRequest* request, 54 net::NetworkDelegate* network_delegate) { 55 return new content::URLRequestMockHTTPJob( 56 request, 57 network_delegate, 58 file_path); 59 } 60 61 // Parses the upload data in |request| into |request_msg|, and validates the 62 // request. The query string in the URL must contain the |expected_type| for 63 // the "request" parameter. Returns true if all checks succeeded, and the 64 // request data has been parsed into |request_msg|. 65 bool ValidRequest(net::URLRequest* request, 66 const std::string& expected_type, 67 em::DeviceManagementRequest* request_msg) { 68 if (request->method() != "POST") 69 return false; 70 std::string spec = request->url().spec(); 71 if (spec.find("request=" + expected_type) == std::string::npos) 72 return false; 73 74 // This assumes that the payload data was set from a single string. In that 75 // case the UploadDataStream has a single UploadBytesElementReader with the 76 // data in memory. 77 const net::UploadDataStream* stream = request->get_upload(); 78 if (!stream) 79 return false; 80 const ScopedVector<net::UploadElementReader>& readers = 81 stream->element_readers(); 82 if (readers.size() != 1u) 83 return false; 84 const net::UploadBytesElementReader* reader = readers[0]->AsBytesReader(); 85 if (!reader) 86 return false; 87 std::string data(reader->bytes(), reader->length()); 88 if (!request_msg->ParseFromString(data)) 89 return false; 90 91 return true; 92 } 93 94 // Helper callback for register jobs that should suceed. Validates the request 95 // parameters and returns an appropriate response job. If |expect_reregister| 96 // is true then the reregister flag must be set in the DeviceRegisterRequest 97 // protobuf. 98 net::URLRequestJob* RegisterJobCallback( 99 em::DeviceRegisterRequest::Type expected_type, 100 bool expect_reregister, 101 net::URLRequest* request, 102 net::NetworkDelegate* network_delegate) { 103 em::DeviceManagementRequest request_msg; 104 if (!ValidRequest(request, "register", &request_msg)) 105 return BadRequestJobCallback(request, network_delegate); 106 107 if (!request_msg.has_register_request() || 108 request_msg.has_unregister_request() || 109 request_msg.has_policy_request() || 110 request_msg.has_device_status_report_request() || 111 request_msg.has_session_status_report_request() || 112 request_msg.has_auto_enrollment_request()) { 113 return BadRequestJobCallback(request, network_delegate); 114 } 115 116 const em::DeviceRegisterRequest& register_request = 117 request_msg.register_request(); 118 if (expect_reregister && 119 (!register_request.has_reregister() || !register_request.reregister())) { 120 return BadRequestJobCallback(request, network_delegate); 121 } else if (!expect_reregister && 122 register_request.has_reregister() && 123 register_request.reregister()) { 124 return BadRequestJobCallback(request, network_delegate); 125 } 126 127 if (!register_request.has_type() || register_request.type() != expected_type) 128 return BadRequestJobCallback(request, network_delegate); 129 130 em::DeviceManagementResponse response; 131 em::DeviceRegisterResponse* register_response = 132 response.mutable_register_response(); 133 register_response->set_device_management_token("s3cr3t70k3n"); 134 std::string data; 135 response.SerializeToString(&data); 136 137 static const char kGoodHeaders[] = 138 "HTTP/1.1 200 OK\0" 139 "Content-type: application/protobuf\0" 140 "\0"; 141 std::string headers(kGoodHeaders, arraysize(kGoodHeaders)); 142 return new net::URLRequestTestJob( 143 request, network_delegate, headers, data, true); 144 } 145 146 } // namespace 147 148 class TestRequestInterceptor::Delegate 149 : public net::URLRequestJobFactory::ProtocolHandler { 150 public: 151 Delegate(const std::string& hostname, 152 scoped_refptr<base::SequencedTaskRunner> io_task_runner); 153 virtual ~Delegate(); 154 155 // ProtocolHandler implementation: 156 virtual net::URLRequestJob* MaybeCreateJob( 157 net::URLRequest* request, 158 net::NetworkDelegate* network_delegate) const OVERRIDE; 159 160 void GetPendingSize(size_t* pending_size) const; 161 void PushJobCallback(const JobCallback& callback); 162 163 private: 164 const std::string hostname_; 165 scoped_refptr<base::SequencedTaskRunner> io_task_runner_; 166 167 // The queue of pending callbacks. 'mutable' because MaybeCreateJob() is a 168 // const method; it can't reenter though, because it runs exclusively on 169 // the IO thread. 170 mutable std::queue<JobCallback> pending_job_callbacks_; 171 }; 172 173 TestRequestInterceptor::Delegate::Delegate( 174 const std::string& hostname, 175 scoped_refptr<base::SequencedTaskRunner> io_task_runner) 176 : hostname_(hostname), io_task_runner_(io_task_runner) {} 177 178 TestRequestInterceptor::Delegate::~Delegate() {} 179 180 net::URLRequestJob* TestRequestInterceptor::Delegate::MaybeCreateJob( 181 net::URLRequest* request, 182 net::NetworkDelegate* network_delegate) const { 183 CHECK(io_task_runner_->RunsTasksOnCurrentThread()); 184 185 if (request->url().host() != hostname_) { 186 // Reject requests to other servers. 187 return ErrorJobCallback( 188 net::ERR_CONNECTION_REFUSED, request, network_delegate); 189 } 190 191 if (pending_job_callbacks_.empty()) { 192 // Reject dmserver requests by default. 193 return BadRequestJobCallback(request, network_delegate); 194 } 195 196 JobCallback callback = pending_job_callbacks_.front(); 197 pending_job_callbacks_.pop(); 198 return callback.Run(request, network_delegate); 199 } 200 201 void TestRequestInterceptor::Delegate::GetPendingSize( 202 size_t* pending_size) const { 203 CHECK(io_task_runner_->RunsTasksOnCurrentThread()); 204 *pending_size = pending_job_callbacks_.size(); 205 } 206 207 void TestRequestInterceptor::Delegate::PushJobCallback( 208 const JobCallback& callback) { 209 CHECK(io_task_runner_->RunsTasksOnCurrentThread()); 210 pending_job_callbacks_.push(callback); 211 } 212 213 TestRequestInterceptor::TestRequestInterceptor(const std::string& hostname, 214 scoped_refptr<base::SequencedTaskRunner> io_task_runner) 215 : hostname_(hostname), 216 io_task_runner_(io_task_runner) { 217 delegate_ = new Delegate(hostname_, io_task_runner_); 218 scoped_ptr<net::URLRequestJobFactory::ProtocolHandler> handler(delegate_); 219 PostToIOAndWait( 220 base::Bind(&net::URLRequestFilter::AddHostnameProtocolHandler, 221 base::Unretained(net::URLRequestFilter::GetInstance()), 222 "http", hostname_, base::Passed(&handler))); 223 } 224 225 TestRequestInterceptor::~TestRequestInterceptor() { 226 // RemoveHostnameHandler() destroys the |delegate_|, which is owned by 227 // the URLRequestFilter. 228 delegate_ = NULL; 229 PostToIOAndWait( 230 base::Bind(&net::URLRequestFilter::RemoveHostnameHandler, 231 base::Unretained(net::URLRequestFilter::GetInstance()), 232 "http", hostname_)); 233 } 234 235 size_t TestRequestInterceptor::GetPendingSize() { 236 size_t pending_size = std::numeric_limits<size_t>::max(); 237 PostToIOAndWait(base::Bind(&Delegate::GetPendingSize, 238 base::Unretained(delegate_), 239 &pending_size)); 240 return pending_size; 241 } 242 243 void TestRequestInterceptor::PushJobCallback(const JobCallback& callback) { 244 PostToIOAndWait(base::Bind(&Delegate::PushJobCallback, 245 base::Unretained(delegate_), 246 callback)); 247 } 248 249 // static 250 TestRequestInterceptor::JobCallback TestRequestInterceptor::ErrorJob( 251 int error) { 252 return base::Bind(&ErrorJobCallback, error); 253 } 254 255 // static 256 TestRequestInterceptor::JobCallback TestRequestInterceptor::BadRequestJob() { 257 return base::Bind(&BadRequestJobCallback); 258 } 259 260 // static 261 TestRequestInterceptor::JobCallback TestRequestInterceptor::RegisterJob( 262 em::DeviceRegisterRequest::Type expected_type, 263 bool expect_reregister) { 264 return base::Bind(&RegisterJobCallback, expected_type, expect_reregister); 265 } 266 267 // static 268 TestRequestInterceptor::JobCallback TestRequestInterceptor::FileJob( 269 const base::FilePath& file_path) { 270 return base::Bind(&FileJobCallback, file_path); 271 } 272 273 void TestRequestInterceptor::PostToIOAndWait(const base::Closure& task) { 274 io_task_runner_->PostTask(FROM_HERE, task); 275 base::RunLoop run_loop; 276 io_task_runner_->PostTask( 277 FROM_HERE, 278 base::Bind( 279 base::IgnoreResult(&base::MessageLoopProxy::PostTask), 280 base::MessageLoopProxy::current(), 281 FROM_HERE, 282 run_loop.QuitClosure())); 283 run_loop.Run(); 284 } 285 286 } // namespace policy 287