Home | History | Annotate | Download | only in renderer_host
      1 // Copyright 2013 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "content/browser/renderer_host/websocket_host.h"
      6 
      7 #include "base/basictypes.h"
      8 #include "base/memory/weak_ptr.h"
      9 #include "base/strings/string_util.h"
     10 #include "content/browser/renderer_host/websocket_dispatcher_host.h"
     11 #include "content/browser/ssl/ssl_error_handler.h"
     12 #include "content/browser/ssl/ssl_manager.h"
     13 #include "content/common/websocket_messages.h"
     14 #include "ipc/ipc_message_macros.h"
     15 #include "net/http/http_request_headers.h"
     16 #include "net/http/http_response_headers.h"
     17 #include "net/http/http_util.h"
     18 #include "net/ssl/ssl_info.h"
     19 #include "net/websockets/websocket_channel.h"
     20 #include "net/websockets/websocket_event_interface.h"
     21 #include "net/websockets/websocket_frame.h"  // for WebSocketFrameHeader::OpCode
     22 #include "net/websockets/websocket_handshake_request_info.h"
     23 #include "net/websockets/websocket_handshake_response_info.h"
     24 #include "url/origin.h"
     25 
     26 namespace content {
     27 
     28 namespace {
     29 
     30 typedef net::WebSocketEventInterface::ChannelState ChannelState;
     31 
     32 // Convert a content::WebSocketMessageType to a
     33 // net::WebSocketFrameHeader::OpCode
     34 net::WebSocketFrameHeader::OpCode MessageTypeToOpCode(
     35     WebSocketMessageType type) {
     36   DCHECK(type == WEB_SOCKET_MESSAGE_TYPE_CONTINUATION ||
     37          type == WEB_SOCKET_MESSAGE_TYPE_TEXT ||
     38          type == WEB_SOCKET_MESSAGE_TYPE_BINARY);
     39   typedef net::WebSocketFrameHeader::OpCode OpCode;
     40   // These compile asserts verify that the same underlying values are used for
     41   // both types, so we can simply cast between them.
     42   COMPILE_ASSERT(static_cast<OpCode>(WEB_SOCKET_MESSAGE_TYPE_CONTINUATION) ==
     43                      net::WebSocketFrameHeader::kOpCodeContinuation,
     44                  enum_values_must_match_for_opcode_continuation);
     45   COMPILE_ASSERT(static_cast<OpCode>(WEB_SOCKET_MESSAGE_TYPE_TEXT) ==
     46                      net::WebSocketFrameHeader::kOpCodeText,
     47                  enum_values_must_match_for_opcode_text);
     48   COMPILE_ASSERT(static_cast<OpCode>(WEB_SOCKET_MESSAGE_TYPE_BINARY) ==
     49                      net::WebSocketFrameHeader::kOpCodeBinary,
     50                  enum_values_must_match_for_opcode_binary);
     51   return static_cast<OpCode>(type);
     52 }
     53 
     54 WebSocketMessageType OpCodeToMessageType(
     55     net::WebSocketFrameHeader::OpCode opCode) {
     56   DCHECK(opCode == net::WebSocketFrameHeader::kOpCodeContinuation ||
     57          opCode == net::WebSocketFrameHeader::kOpCodeText ||
     58          opCode == net::WebSocketFrameHeader::kOpCodeBinary);
     59   // This cast is guaranteed valid by the COMPILE_ASSERT() statements above.
     60   return static_cast<WebSocketMessageType>(opCode);
     61 }
     62 
     63 ChannelState StateCast(WebSocketDispatcherHost::WebSocketHostState host_state) {
     64   const WebSocketDispatcherHost::WebSocketHostState WEBSOCKET_HOST_ALIVE =
     65       WebSocketDispatcherHost::WEBSOCKET_HOST_ALIVE;
     66   const WebSocketDispatcherHost::WebSocketHostState WEBSOCKET_HOST_DELETED =
     67       WebSocketDispatcherHost::WEBSOCKET_HOST_DELETED;
     68 
     69   DCHECK(host_state == WEBSOCKET_HOST_ALIVE ||
     70          host_state == WEBSOCKET_HOST_DELETED);
     71   // These compile asserts verify that we can get away with using static_cast<>
     72   // for the conversion.
     73   COMPILE_ASSERT(static_cast<ChannelState>(WEBSOCKET_HOST_ALIVE) ==
     74                      net::WebSocketEventInterface::CHANNEL_ALIVE,
     75                  enum_values_must_match_for_state_alive);
     76   COMPILE_ASSERT(static_cast<ChannelState>(WEBSOCKET_HOST_DELETED) ==
     77                      net::WebSocketEventInterface::CHANNEL_DELETED,
     78                  enum_values_must_match_for_state_deleted);
     79   return static_cast<ChannelState>(host_state);
     80 }
     81 
     82 // Implementation of net::WebSocketEventInterface. Receives events from our
     83 // WebSocketChannel object. Each event is translated to an IPC and sent to the
     84 // renderer or child process via WebSocketDispatcherHost.
     85 class WebSocketEventHandler : public net::WebSocketEventInterface {
     86  public:
     87   WebSocketEventHandler(WebSocketDispatcherHost* dispatcher,
     88                         int routing_id,
     89                         int render_frame_id);
     90   virtual ~WebSocketEventHandler();
     91 
     92   // net::WebSocketEventInterface implementation
     93 
     94   virtual ChannelState OnAddChannelResponse(
     95       bool fail,
     96       const std::string& selected_subprotocol,
     97       const std::string& extensions) OVERRIDE;
     98   virtual ChannelState OnDataFrame(bool fin,
     99                                    WebSocketMessageType type,
    100                                    const std::vector<char>& data) OVERRIDE;
    101   virtual ChannelState OnClosingHandshake() OVERRIDE;
    102   virtual ChannelState OnFlowControl(int64 quota) OVERRIDE;
    103   virtual ChannelState OnDropChannel(bool was_clean,
    104                                      uint16 code,
    105                                      const std::string& reason) OVERRIDE;
    106   virtual ChannelState OnFailChannel(const std::string& message) OVERRIDE;
    107   virtual ChannelState OnStartOpeningHandshake(
    108       scoped_ptr<net::WebSocketHandshakeRequestInfo> request) OVERRIDE;
    109   virtual ChannelState OnFinishOpeningHandshake(
    110       scoped_ptr<net::WebSocketHandshakeResponseInfo> response) OVERRIDE;
    111   virtual ChannelState OnSSLCertificateError(
    112       scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks,
    113       const GURL& url,
    114       const net::SSLInfo& ssl_info,
    115       bool fatal) OVERRIDE;
    116 
    117  private:
    118   class SSLErrorHandlerDelegate : public SSLErrorHandler::Delegate {
    119    public:
    120     SSLErrorHandlerDelegate(
    121         scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks);
    122     virtual ~SSLErrorHandlerDelegate();
    123 
    124     base::WeakPtr<SSLErrorHandler::Delegate> GetWeakPtr();
    125 
    126     // SSLErrorHandler::Delegate methods
    127     virtual void CancelSSLRequest(const GlobalRequestID& id,
    128                                   int error,
    129                                   const net::SSLInfo* ssl_info) OVERRIDE;
    130     virtual void ContinueSSLRequest(const GlobalRequestID& id) OVERRIDE;
    131 
    132    private:
    133     scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks_;
    134     base::WeakPtrFactory<SSLErrorHandlerDelegate> weak_ptr_factory_;
    135 
    136     DISALLOW_COPY_AND_ASSIGN(SSLErrorHandlerDelegate);
    137   };
    138 
    139   WebSocketDispatcherHost* const dispatcher_;
    140   const int routing_id_;
    141   const int render_frame_id_;
    142   scoped_ptr<SSLErrorHandlerDelegate> ssl_error_handler_delegate_;
    143 
    144   DISALLOW_COPY_AND_ASSIGN(WebSocketEventHandler);
    145 };
    146 
    147 WebSocketEventHandler::WebSocketEventHandler(
    148     WebSocketDispatcherHost* dispatcher,
    149     int routing_id,
    150     int render_frame_id)
    151     : dispatcher_(dispatcher),
    152       routing_id_(routing_id),
    153       render_frame_id_(render_frame_id) {
    154 }
    155 
    156 WebSocketEventHandler::~WebSocketEventHandler() {
    157   DVLOG(1) << "WebSocketEventHandler destroyed routing_id=" << routing_id_;
    158 }
    159 
    160 ChannelState WebSocketEventHandler::OnAddChannelResponse(
    161     bool fail,
    162     const std::string& selected_protocol,
    163     const std::string& extensions) {
    164   DVLOG(3) << "WebSocketEventHandler::OnAddChannelResponse"
    165            << " routing_id=" << routing_id_ << " fail=" << fail
    166            << " selected_protocol=\"" << selected_protocol << "\""
    167            << " extensions=\"" << extensions << "\"";
    168 
    169   return StateCast(dispatcher_->SendAddChannelResponse(
    170       routing_id_, fail, selected_protocol, extensions));
    171 }
    172 
    173 ChannelState WebSocketEventHandler::OnDataFrame(
    174     bool fin,
    175     net::WebSocketFrameHeader::OpCode type,
    176     const std::vector<char>& data) {
    177   DVLOG(3) << "WebSocketEventHandler::OnDataFrame"
    178            << " routing_id=" << routing_id_ << " fin=" << fin
    179            << " type=" << type << " data is " << data.size() << " bytes";
    180 
    181   return StateCast(dispatcher_->SendFrame(
    182       routing_id_, fin, OpCodeToMessageType(type), data));
    183 }
    184 
    185 ChannelState WebSocketEventHandler::OnClosingHandshake() {
    186   DVLOG(3) << "WebSocketEventHandler::OnClosingHandshake"
    187            << " routing_id=" << routing_id_;
    188 
    189   return StateCast(dispatcher_->NotifyClosingHandshake(routing_id_));
    190 }
    191 
    192 ChannelState WebSocketEventHandler::OnFlowControl(int64 quota) {
    193   DVLOG(3) << "WebSocketEventHandler::OnFlowControl"
    194            << " routing_id=" << routing_id_ << " quota=" << quota;
    195 
    196   return StateCast(dispatcher_->SendFlowControl(routing_id_, quota));
    197 }
    198 
    199 ChannelState WebSocketEventHandler::OnDropChannel(bool was_clean,
    200                                                   uint16 code,
    201                                                   const std::string& reason) {
    202   DVLOG(3) << "WebSocketEventHandler::OnDropChannel"
    203            << " routing_id=" << routing_id_ << " was_clean=" << was_clean
    204            << " code=" << code << " reason=\"" << reason << "\"";
    205 
    206   return StateCast(
    207       dispatcher_->DoDropChannel(routing_id_, was_clean, code, reason));
    208 }
    209 
    210 ChannelState WebSocketEventHandler::OnFailChannel(const std::string& message) {
    211   DVLOG(3) << "WebSocketEventHandler::OnFailChannel"
    212            << " routing_id=" << routing_id_
    213            << " message=\"" << message << "\"";
    214 
    215   return StateCast(dispatcher_->NotifyFailure(routing_id_, message));
    216 }
    217 
    218 ChannelState WebSocketEventHandler::OnStartOpeningHandshake(
    219     scoped_ptr<net::WebSocketHandshakeRequestInfo> request) {
    220   bool should_send = dispatcher_->CanReadRawCookies();
    221   DVLOG(3) << "WebSocketEventHandler::OnStartOpeningHandshake "
    222            << "should_send=" << should_send;
    223 
    224   if (!should_send)
    225     return WebSocketEventInterface::CHANNEL_ALIVE;
    226 
    227   WebSocketHandshakeRequest request_to_pass;
    228   request_to_pass.url.Swap(&request->url);
    229   net::HttpRequestHeaders::Iterator it(request->headers);
    230   while (it.GetNext())
    231     request_to_pass.headers.push_back(std::make_pair(it.name(), it.value()));
    232   request_to_pass.headers_text =
    233       base::StringPrintf("GET %s HTTP/1.1\r\n",
    234                          request_to_pass.url.spec().c_str()) +
    235       request->headers.ToString();
    236   request_to_pass.request_time = request->request_time;
    237 
    238   return StateCast(dispatcher_->NotifyStartOpeningHandshake(routing_id_,
    239                                                             request_to_pass));
    240 }
    241 
    242 ChannelState WebSocketEventHandler::OnFinishOpeningHandshake(
    243     scoped_ptr<net::WebSocketHandshakeResponseInfo> response) {
    244   bool should_send = dispatcher_->CanReadRawCookies();
    245   DVLOG(3) << "WebSocketEventHandler::OnFinishOpeningHandshake "
    246            << "should_send=" << should_send;
    247 
    248   if (!should_send)
    249     return WebSocketEventInterface::CHANNEL_ALIVE;
    250 
    251   WebSocketHandshakeResponse response_to_pass;
    252   response_to_pass.url.Swap(&response->url);
    253   response_to_pass.status_code = response->status_code;
    254   response_to_pass.status_text.swap(response->status_text);
    255   void* iter = NULL;
    256   std::string name, value;
    257   while (response->headers->EnumerateHeaderLines(&iter, &name, &value))
    258     response_to_pass.headers.push_back(std::make_pair(name, value));
    259   response_to_pass.headers_text =
    260       net::HttpUtil::ConvertHeadersBackToHTTPResponse(
    261           response->headers->raw_headers());
    262   response_to_pass.response_time = response->response_time;
    263 
    264   return StateCast(dispatcher_->NotifyFinishOpeningHandshake(routing_id_,
    265                                                              response_to_pass));
    266 }
    267 
    268 ChannelState WebSocketEventHandler::OnSSLCertificateError(
    269     scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks,
    270     const GURL& url,
    271     const net::SSLInfo& ssl_info,
    272     bool fatal) {
    273   DVLOG(3) << "WebSocketEventHandler::OnSSLCertificateError"
    274            << " routing_id=" << routing_id_ << " url=" << url.spec()
    275            << " cert_status=" << ssl_info.cert_status << " fatal=" << fatal;
    276   ssl_error_handler_delegate_.reset(
    277       new SSLErrorHandlerDelegate(callbacks.Pass()));
    278   // We don't need request_id to be unique so just make a fake one.
    279   GlobalRequestID request_id(-1, -1);
    280   SSLManager::OnSSLCertificateError(ssl_error_handler_delegate_->GetWeakPtr(),
    281                                     request_id,
    282                                     ResourceType::SUB_RESOURCE,
    283                                     url,
    284                                     dispatcher_->render_process_id(),
    285                                     render_frame_id_,
    286                                     ssl_info,
    287                                     fatal);
    288   // The above method is always asynchronous.
    289   return WebSocketEventInterface::CHANNEL_ALIVE;
    290 }
    291 
    292 WebSocketEventHandler::SSLErrorHandlerDelegate::SSLErrorHandlerDelegate(
    293     scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks)
    294     : callbacks_(callbacks.Pass()), weak_ptr_factory_(this) {}
    295 
    296 WebSocketEventHandler::SSLErrorHandlerDelegate::~SSLErrorHandlerDelegate() {}
    297 
    298 base::WeakPtr<SSLErrorHandler::Delegate>
    299 WebSocketEventHandler::SSLErrorHandlerDelegate::GetWeakPtr() {
    300   return weak_ptr_factory_.GetWeakPtr();
    301 }
    302 
    303 void WebSocketEventHandler::SSLErrorHandlerDelegate::CancelSSLRequest(
    304     const GlobalRequestID& id,
    305     int error,
    306     const net::SSLInfo* ssl_info) {
    307   DVLOG(3) << "SSLErrorHandlerDelegate::CancelSSLRequest"
    308            << " error=" << error
    309            << " cert_status=" << (ssl_info ? ssl_info->cert_status
    310                                            : static_cast<net::CertStatus>(-1));
    311   callbacks_->CancelSSLRequest(error, ssl_info);
    312 }
    313 
    314 void WebSocketEventHandler::SSLErrorHandlerDelegate::ContinueSSLRequest(
    315     const GlobalRequestID& id) {
    316   DVLOG(3) << "SSLErrorHandlerDelegate::ContinueSSLRequest";
    317   callbacks_->ContinueSSLRequest();
    318 }
    319 
    320 }  // namespace
    321 
    322 WebSocketHost::WebSocketHost(int routing_id,
    323                              WebSocketDispatcherHost* dispatcher,
    324                              net::URLRequestContext* url_request_context)
    325     : dispatcher_(dispatcher),
    326       url_request_context_(url_request_context),
    327       routing_id_(routing_id) {
    328   DVLOG(1) << "WebSocketHost: created routing_id=" << routing_id;
    329 }
    330 
    331 WebSocketHost::~WebSocketHost() {}
    332 
    333 bool WebSocketHost::OnMessageReceived(const IPC::Message& message) {
    334   bool handled = true;
    335   IPC_BEGIN_MESSAGE_MAP(WebSocketHost, message)
    336     IPC_MESSAGE_HANDLER(WebSocketHostMsg_AddChannelRequest, OnAddChannelRequest)
    337     IPC_MESSAGE_HANDLER(WebSocketMsg_SendFrame, OnSendFrame)
    338     IPC_MESSAGE_HANDLER(WebSocketMsg_FlowControl, OnFlowControl)
    339     IPC_MESSAGE_HANDLER(WebSocketMsg_DropChannel, OnDropChannel)
    340     IPC_MESSAGE_UNHANDLED(handled = false)
    341   IPC_END_MESSAGE_MAP()
    342   return handled;
    343 }
    344 
    345 void WebSocketHost::OnAddChannelRequest(
    346     const GURL& socket_url,
    347     const std::vector<std::string>& requested_protocols,
    348     const url::Origin& origin,
    349     int render_frame_id) {
    350   DVLOG(3) << "WebSocketHost::OnAddChannelRequest"
    351            << " routing_id=" << routing_id_ << " socket_url=\"" << socket_url
    352            << "\" requested_protocols=\""
    353            << JoinString(requested_protocols, ", ") << "\" origin=\""
    354            << origin.string() << "\"";
    355 
    356   DCHECK(!channel_);
    357   scoped_ptr<net::WebSocketEventInterface> event_interface(
    358       new WebSocketEventHandler(dispatcher_, routing_id_, render_frame_id));
    359   channel_.reset(
    360       new net::WebSocketChannel(event_interface.Pass(), url_request_context_));
    361   channel_->SendAddChannelRequest(socket_url, requested_protocols, origin);
    362 }
    363 
    364 void WebSocketHost::OnSendFrame(bool fin,
    365                                 WebSocketMessageType type,
    366                                 const std::vector<char>& data) {
    367   DVLOG(3) << "WebSocketHost::OnSendFrame"
    368            << " routing_id=" << routing_id_ << " fin=" << fin
    369            << " type=" << type << " data is " << data.size() << " bytes";
    370 
    371   DCHECK(channel_);
    372   channel_->SendFrame(fin, MessageTypeToOpCode(type), data);
    373 }
    374 
    375 void WebSocketHost::OnFlowControl(int64 quota) {
    376   DVLOG(3) << "WebSocketHost::OnFlowControl"
    377            << " routing_id=" << routing_id_ << " quota=" << quota;
    378 
    379   DCHECK(channel_);
    380   channel_->SendFlowControl(quota);
    381 }
    382 
    383 void WebSocketHost::OnDropChannel(bool was_clean,
    384                                   uint16 code,
    385                                   const std::string& reason) {
    386   DVLOG(3) << "WebSocketHost::OnDropChannel"
    387            << " routing_id=" << routing_id_ << " was_clean=" << was_clean
    388            << " code=" << code << " reason=\"" << reason << "\"";
    389 
    390   DCHECK(channel_);
    391   // TODO(yhirano): Handle |was_clean| appropriately.
    392   channel_->StartClosingHandshake(code, reason);
    393 }
    394 
    395 }  // namespace content
    396