1 // Copyright 2013 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "content/browser/renderer_host/websocket_host.h" 6 7 #include "base/basictypes.h" 8 #include "base/memory/weak_ptr.h" 9 #include "base/strings/string_util.h" 10 #include "content/browser/renderer_host/websocket_dispatcher_host.h" 11 #include "content/browser/ssl/ssl_error_handler.h" 12 #include "content/browser/ssl/ssl_manager.h" 13 #include "content/common/websocket_messages.h" 14 #include "ipc/ipc_message_macros.h" 15 #include "net/http/http_request_headers.h" 16 #include "net/http/http_response_headers.h" 17 #include "net/http/http_util.h" 18 #include "net/ssl/ssl_info.h" 19 #include "net/websockets/websocket_channel.h" 20 #include "net/websockets/websocket_event_interface.h" 21 #include "net/websockets/websocket_frame.h" // for WebSocketFrameHeader::OpCode 22 #include "net/websockets/websocket_handshake_request_info.h" 23 #include "net/websockets/websocket_handshake_response_info.h" 24 #include "url/origin.h" 25 26 namespace content { 27 28 namespace { 29 30 typedef net::WebSocketEventInterface::ChannelState ChannelState; 31 32 // Convert a content::WebSocketMessageType to a 33 // net::WebSocketFrameHeader::OpCode 34 net::WebSocketFrameHeader::OpCode MessageTypeToOpCode( 35 WebSocketMessageType type) { 36 DCHECK(type == WEB_SOCKET_MESSAGE_TYPE_CONTINUATION || 37 type == WEB_SOCKET_MESSAGE_TYPE_TEXT || 38 type == WEB_SOCKET_MESSAGE_TYPE_BINARY); 39 typedef net::WebSocketFrameHeader::OpCode OpCode; 40 // These compile asserts verify that the same underlying values are used for 41 // both types, so we can simply cast between them. 42 COMPILE_ASSERT(static_cast<OpCode>(WEB_SOCKET_MESSAGE_TYPE_CONTINUATION) == 43 net::WebSocketFrameHeader::kOpCodeContinuation, 44 enum_values_must_match_for_opcode_continuation); 45 COMPILE_ASSERT(static_cast<OpCode>(WEB_SOCKET_MESSAGE_TYPE_TEXT) == 46 net::WebSocketFrameHeader::kOpCodeText, 47 enum_values_must_match_for_opcode_text); 48 COMPILE_ASSERT(static_cast<OpCode>(WEB_SOCKET_MESSAGE_TYPE_BINARY) == 49 net::WebSocketFrameHeader::kOpCodeBinary, 50 enum_values_must_match_for_opcode_binary); 51 return static_cast<OpCode>(type); 52 } 53 54 WebSocketMessageType OpCodeToMessageType( 55 net::WebSocketFrameHeader::OpCode opCode) { 56 DCHECK(opCode == net::WebSocketFrameHeader::kOpCodeContinuation || 57 opCode == net::WebSocketFrameHeader::kOpCodeText || 58 opCode == net::WebSocketFrameHeader::kOpCodeBinary); 59 // This cast is guaranteed valid by the COMPILE_ASSERT() statements above. 60 return static_cast<WebSocketMessageType>(opCode); 61 } 62 63 ChannelState StateCast(WebSocketDispatcherHost::WebSocketHostState host_state) { 64 const WebSocketDispatcherHost::WebSocketHostState WEBSOCKET_HOST_ALIVE = 65 WebSocketDispatcherHost::WEBSOCKET_HOST_ALIVE; 66 const WebSocketDispatcherHost::WebSocketHostState WEBSOCKET_HOST_DELETED = 67 WebSocketDispatcherHost::WEBSOCKET_HOST_DELETED; 68 69 DCHECK(host_state == WEBSOCKET_HOST_ALIVE || 70 host_state == WEBSOCKET_HOST_DELETED); 71 // These compile asserts verify that we can get away with using static_cast<> 72 // for the conversion. 73 COMPILE_ASSERT(static_cast<ChannelState>(WEBSOCKET_HOST_ALIVE) == 74 net::WebSocketEventInterface::CHANNEL_ALIVE, 75 enum_values_must_match_for_state_alive); 76 COMPILE_ASSERT(static_cast<ChannelState>(WEBSOCKET_HOST_DELETED) == 77 net::WebSocketEventInterface::CHANNEL_DELETED, 78 enum_values_must_match_for_state_deleted); 79 return static_cast<ChannelState>(host_state); 80 } 81 82 // Implementation of net::WebSocketEventInterface. Receives events from our 83 // WebSocketChannel object. Each event is translated to an IPC and sent to the 84 // renderer or child process via WebSocketDispatcherHost. 85 class WebSocketEventHandler : public net::WebSocketEventInterface { 86 public: 87 WebSocketEventHandler(WebSocketDispatcherHost* dispatcher, 88 int routing_id, 89 int render_frame_id); 90 virtual ~WebSocketEventHandler(); 91 92 // net::WebSocketEventInterface implementation 93 94 virtual ChannelState OnAddChannelResponse( 95 bool fail, 96 const std::string& selected_subprotocol, 97 const std::string& extensions) OVERRIDE; 98 virtual ChannelState OnDataFrame(bool fin, 99 WebSocketMessageType type, 100 const std::vector<char>& data) OVERRIDE; 101 virtual ChannelState OnClosingHandshake() OVERRIDE; 102 virtual ChannelState OnFlowControl(int64 quota) OVERRIDE; 103 virtual ChannelState OnDropChannel(bool was_clean, 104 uint16 code, 105 const std::string& reason) OVERRIDE; 106 virtual ChannelState OnFailChannel(const std::string& message) OVERRIDE; 107 virtual ChannelState OnStartOpeningHandshake( 108 scoped_ptr<net::WebSocketHandshakeRequestInfo> request) OVERRIDE; 109 virtual ChannelState OnFinishOpeningHandshake( 110 scoped_ptr<net::WebSocketHandshakeResponseInfo> response) OVERRIDE; 111 virtual ChannelState OnSSLCertificateError( 112 scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks, 113 const GURL& url, 114 const net::SSLInfo& ssl_info, 115 bool fatal) OVERRIDE; 116 117 private: 118 class SSLErrorHandlerDelegate : public SSLErrorHandler::Delegate { 119 public: 120 SSLErrorHandlerDelegate( 121 scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks); 122 virtual ~SSLErrorHandlerDelegate(); 123 124 base::WeakPtr<SSLErrorHandler::Delegate> GetWeakPtr(); 125 126 // SSLErrorHandler::Delegate methods 127 virtual void CancelSSLRequest(const GlobalRequestID& id, 128 int error, 129 const net::SSLInfo* ssl_info) OVERRIDE; 130 virtual void ContinueSSLRequest(const GlobalRequestID& id) OVERRIDE; 131 132 private: 133 scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks_; 134 base::WeakPtrFactory<SSLErrorHandlerDelegate> weak_ptr_factory_; 135 136 DISALLOW_COPY_AND_ASSIGN(SSLErrorHandlerDelegate); 137 }; 138 139 WebSocketDispatcherHost* const dispatcher_; 140 const int routing_id_; 141 const int render_frame_id_; 142 scoped_ptr<SSLErrorHandlerDelegate> ssl_error_handler_delegate_; 143 144 DISALLOW_COPY_AND_ASSIGN(WebSocketEventHandler); 145 }; 146 147 WebSocketEventHandler::WebSocketEventHandler( 148 WebSocketDispatcherHost* dispatcher, 149 int routing_id, 150 int render_frame_id) 151 : dispatcher_(dispatcher), 152 routing_id_(routing_id), 153 render_frame_id_(render_frame_id) { 154 } 155 156 WebSocketEventHandler::~WebSocketEventHandler() { 157 DVLOG(1) << "WebSocketEventHandler destroyed routing_id=" << routing_id_; 158 } 159 160 ChannelState WebSocketEventHandler::OnAddChannelResponse( 161 bool fail, 162 const std::string& selected_protocol, 163 const std::string& extensions) { 164 DVLOG(3) << "WebSocketEventHandler::OnAddChannelResponse" 165 << " routing_id=" << routing_id_ << " fail=" << fail 166 << " selected_protocol=\"" << selected_protocol << "\"" 167 << " extensions=\"" << extensions << "\""; 168 169 return StateCast(dispatcher_->SendAddChannelResponse( 170 routing_id_, fail, selected_protocol, extensions)); 171 } 172 173 ChannelState WebSocketEventHandler::OnDataFrame( 174 bool fin, 175 net::WebSocketFrameHeader::OpCode type, 176 const std::vector<char>& data) { 177 DVLOG(3) << "WebSocketEventHandler::OnDataFrame" 178 << " routing_id=" << routing_id_ << " fin=" << fin 179 << " type=" << type << " data is " << data.size() << " bytes"; 180 181 return StateCast(dispatcher_->SendFrame( 182 routing_id_, fin, OpCodeToMessageType(type), data)); 183 } 184 185 ChannelState WebSocketEventHandler::OnClosingHandshake() { 186 DVLOG(3) << "WebSocketEventHandler::OnClosingHandshake" 187 << " routing_id=" << routing_id_; 188 189 return StateCast(dispatcher_->NotifyClosingHandshake(routing_id_)); 190 } 191 192 ChannelState WebSocketEventHandler::OnFlowControl(int64 quota) { 193 DVLOG(3) << "WebSocketEventHandler::OnFlowControl" 194 << " routing_id=" << routing_id_ << " quota=" << quota; 195 196 return StateCast(dispatcher_->SendFlowControl(routing_id_, quota)); 197 } 198 199 ChannelState WebSocketEventHandler::OnDropChannel(bool was_clean, 200 uint16 code, 201 const std::string& reason) { 202 DVLOG(3) << "WebSocketEventHandler::OnDropChannel" 203 << " routing_id=" << routing_id_ << " was_clean=" << was_clean 204 << " code=" << code << " reason=\"" << reason << "\""; 205 206 return StateCast( 207 dispatcher_->DoDropChannel(routing_id_, was_clean, code, reason)); 208 } 209 210 ChannelState WebSocketEventHandler::OnFailChannel(const std::string& message) { 211 DVLOG(3) << "WebSocketEventHandler::OnFailChannel" 212 << " routing_id=" << routing_id_ 213 << " message=\"" << message << "\""; 214 215 return StateCast(dispatcher_->NotifyFailure(routing_id_, message)); 216 } 217 218 ChannelState WebSocketEventHandler::OnStartOpeningHandshake( 219 scoped_ptr<net::WebSocketHandshakeRequestInfo> request) { 220 bool should_send = dispatcher_->CanReadRawCookies(); 221 DVLOG(3) << "WebSocketEventHandler::OnStartOpeningHandshake " 222 << "should_send=" << should_send; 223 224 if (!should_send) 225 return WebSocketEventInterface::CHANNEL_ALIVE; 226 227 WebSocketHandshakeRequest request_to_pass; 228 request_to_pass.url.Swap(&request->url); 229 net::HttpRequestHeaders::Iterator it(request->headers); 230 while (it.GetNext()) 231 request_to_pass.headers.push_back(std::make_pair(it.name(), it.value())); 232 request_to_pass.headers_text = 233 base::StringPrintf("GET %s HTTP/1.1\r\n", 234 request_to_pass.url.spec().c_str()) + 235 request->headers.ToString(); 236 request_to_pass.request_time = request->request_time; 237 238 return StateCast(dispatcher_->NotifyStartOpeningHandshake(routing_id_, 239 request_to_pass)); 240 } 241 242 ChannelState WebSocketEventHandler::OnFinishOpeningHandshake( 243 scoped_ptr<net::WebSocketHandshakeResponseInfo> response) { 244 bool should_send = dispatcher_->CanReadRawCookies(); 245 DVLOG(3) << "WebSocketEventHandler::OnFinishOpeningHandshake " 246 << "should_send=" << should_send; 247 248 if (!should_send) 249 return WebSocketEventInterface::CHANNEL_ALIVE; 250 251 WebSocketHandshakeResponse response_to_pass; 252 response_to_pass.url.Swap(&response->url); 253 response_to_pass.status_code = response->status_code; 254 response_to_pass.status_text.swap(response->status_text); 255 void* iter = NULL; 256 std::string name, value; 257 while (response->headers->EnumerateHeaderLines(&iter, &name, &value)) 258 response_to_pass.headers.push_back(std::make_pair(name, value)); 259 response_to_pass.headers_text = 260 net::HttpUtil::ConvertHeadersBackToHTTPResponse( 261 response->headers->raw_headers()); 262 response_to_pass.response_time = response->response_time; 263 264 return StateCast(dispatcher_->NotifyFinishOpeningHandshake(routing_id_, 265 response_to_pass)); 266 } 267 268 ChannelState WebSocketEventHandler::OnSSLCertificateError( 269 scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks, 270 const GURL& url, 271 const net::SSLInfo& ssl_info, 272 bool fatal) { 273 DVLOG(3) << "WebSocketEventHandler::OnSSLCertificateError" 274 << " routing_id=" << routing_id_ << " url=" << url.spec() 275 << " cert_status=" << ssl_info.cert_status << " fatal=" << fatal; 276 ssl_error_handler_delegate_.reset( 277 new SSLErrorHandlerDelegate(callbacks.Pass())); 278 // We don't need request_id to be unique so just make a fake one. 279 GlobalRequestID request_id(-1, -1); 280 SSLManager::OnSSLCertificateError(ssl_error_handler_delegate_->GetWeakPtr(), 281 request_id, 282 ResourceType::SUB_RESOURCE, 283 url, 284 dispatcher_->render_process_id(), 285 render_frame_id_, 286 ssl_info, 287 fatal); 288 // The above method is always asynchronous. 289 return WebSocketEventInterface::CHANNEL_ALIVE; 290 } 291 292 WebSocketEventHandler::SSLErrorHandlerDelegate::SSLErrorHandlerDelegate( 293 scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks) 294 : callbacks_(callbacks.Pass()), weak_ptr_factory_(this) {} 295 296 WebSocketEventHandler::SSLErrorHandlerDelegate::~SSLErrorHandlerDelegate() {} 297 298 base::WeakPtr<SSLErrorHandler::Delegate> 299 WebSocketEventHandler::SSLErrorHandlerDelegate::GetWeakPtr() { 300 return weak_ptr_factory_.GetWeakPtr(); 301 } 302 303 void WebSocketEventHandler::SSLErrorHandlerDelegate::CancelSSLRequest( 304 const GlobalRequestID& id, 305 int error, 306 const net::SSLInfo* ssl_info) { 307 DVLOG(3) << "SSLErrorHandlerDelegate::CancelSSLRequest" 308 << " error=" << error 309 << " cert_status=" << (ssl_info ? ssl_info->cert_status 310 : static_cast<net::CertStatus>(-1)); 311 callbacks_->CancelSSLRequest(error, ssl_info); 312 } 313 314 void WebSocketEventHandler::SSLErrorHandlerDelegate::ContinueSSLRequest( 315 const GlobalRequestID& id) { 316 DVLOG(3) << "SSLErrorHandlerDelegate::ContinueSSLRequest"; 317 callbacks_->ContinueSSLRequest(); 318 } 319 320 } // namespace 321 322 WebSocketHost::WebSocketHost(int routing_id, 323 WebSocketDispatcherHost* dispatcher, 324 net::URLRequestContext* url_request_context) 325 : dispatcher_(dispatcher), 326 url_request_context_(url_request_context), 327 routing_id_(routing_id) { 328 DVLOG(1) << "WebSocketHost: created routing_id=" << routing_id; 329 } 330 331 WebSocketHost::~WebSocketHost() {} 332 333 bool WebSocketHost::OnMessageReceived(const IPC::Message& message) { 334 bool handled = true; 335 IPC_BEGIN_MESSAGE_MAP(WebSocketHost, message) 336 IPC_MESSAGE_HANDLER(WebSocketHostMsg_AddChannelRequest, OnAddChannelRequest) 337 IPC_MESSAGE_HANDLER(WebSocketMsg_SendFrame, OnSendFrame) 338 IPC_MESSAGE_HANDLER(WebSocketMsg_FlowControl, OnFlowControl) 339 IPC_MESSAGE_HANDLER(WebSocketMsg_DropChannel, OnDropChannel) 340 IPC_MESSAGE_UNHANDLED(handled = false) 341 IPC_END_MESSAGE_MAP() 342 return handled; 343 } 344 345 void WebSocketHost::OnAddChannelRequest( 346 const GURL& socket_url, 347 const std::vector<std::string>& requested_protocols, 348 const url::Origin& origin, 349 int render_frame_id) { 350 DVLOG(3) << "WebSocketHost::OnAddChannelRequest" 351 << " routing_id=" << routing_id_ << " socket_url=\"" << socket_url 352 << "\" requested_protocols=\"" 353 << JoinString(requested_protocols, ", ") << "\" origin=\"" 354 << origin.string() << "\""; 355 356 DCHECK(!channel_); 357 scoped_ptr<net::WebSocketEventInterface> event_interface( 358 new WebSocketEventHandler(dispatcher_, routing_id_, render_frame_id)); 359 channel_.reset( 360 new net::WebSocketChannel(event_interface.Pass(), url_request_context_)); 361 channel_->SendAddChannelRequest(socket_url, requested_protocols, origin); 362 } 363 364 void WebSocketHost::OnSendFrame(bool fin, 365 WebSocketMessageType type, 366 const std::vector<char>& data) { 367 DVLOG(3) << "WebSocketHost::OnSendFrame" 368 << " routing_id=" << routing_id_ << " fin=" << fin 369 << " type=" << type << " data is " << data.size() << " bytes"; 370 371 DCHECK(channel_); 372 channel_->SendFrame(fin, MessageTypeToOpCode(type), data); 373 } 374 375 void WebSocketHost::OnFlowControl(int64 quota) { 376 DVLOG(3) << "WebSocketHost::OnFlowControl" 377 << " routing_id=" << routing_id_ << " quota=" << quota; 378 379 DCHECK(channel_); 380 channel_->SendFlowControl(quota); 381 } 382 383 void WebSocketHost::OnDropChannel(bool was_clean, 384 uint16 code, 385 const std::string& reason) { 386 DVLOG(3) << "WebSocketHost::OnDropChannel" 387 << " routing_id=" << routing_id_ << " was_clean=" << was_clean 388 << " code=" << code << " reason=\"" << reason << "\""; 389 390 DCHECK(channel_); 391 // TODO(yhirano): Handle |was_clean| appropriately. 392 channel_->StartClosingHandshake(code, reason); 393 } 394 395 } // namespace content 396