Home | History | Annotate | Download | only in renderer_host
      1 // Copyright 2013 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
#include "content/browser/renderer_host/websocket_host.h"

#include "base/basictypes.h"
#include "base/memory/weak_ptr.h"
#include "base/strings/string_util.h"
#include "base/strings/stringprintf.h"
#include "content/browser/renderer_host/websocket_dispatcher_host.h"
#include "content/browser/ssl/ssl_error_handler.h"
#include "content/browser/ssl/ssl_manager.h"
#include "content/common/websocket_messages.h"
#include "ipc/ipc_message_macros.h"
#include "net/http/http_request_headers.h"
#include "net/http/http_response_headers.h"
#include "net/http/http_util.h"
#include "net/ssl/ssl_info.h"
#include "net/websockets/websocket_channel.h"
#include "net/websockets/websocket_errors.h"
#include "net/websockets/websocket_event_interface.h"
#include "net/websockets/websocket_frame.h"  // for WebSocketFrameHeader::OpCode
#include "net/websockets/websocket_handshake_request_info.h"
#include "net/websockets/websocket_handshake_response_info.h"
#include "url/origin.h"
     26 
     27 namespace content {
     28 
     29 namespace {
     30 
     31 typedef net::WebSocketEventInterface::ChannelState ChannelState;
     32 
     33 // Convert a content::WebSocketMessageType to a
     34 // net::WebSocketFrameHeader::OpCode
     35 net::WebSocketFrameHeader::OpCode MessageTypeToOpCode(
     36     WebSocketMessageType type) {
     37   DCHECK(type == WEB_SOCKET_MESSAGE_TYPE_CONTINUATION ||
     38          type == WEB_SOCKET_MESSAGE_TYPE_TEXT ||
     39          type == WEB_SOCKET_MESSAGE_TYPE_BINARY);
     40   typedef net::WebSocketFrameHeader::OpCode OpCode;
     41   // These compile asserts verify that the same underlying values are used for
     42   // both types, so we can simply cast between them.
     43   COMPILE_ASSERT(static_cast<OpCode>(WEB_SOCKET_MESSAGE_TYPE_CONTINUATION) ==
     44                      net::WebSocketFrameHeader::kOpCodeContinuation,
     45                  enum_values_must_match_for_opcode_continuation);
     46   COMPILE_ASSERT(static_cast<OpCode>(WEB_SOCKET_MESSAGE_TYPE_TEXT) ==
     47                      net::WebSocketFrameHeader::kOpCodeText,
     48                  enum_values_must_match_for_opcode_text);
     49   COMPILE_ASSERT(static_cast<OpCode>(WEB_SOCKET_MESSAGE_TYPE_BINARY) ==
     50                      net::WebSocketFrameHeader::kOpCodeBinary,
     51                  enum_values_must_match_for_opcode_binary);
     52   return static_cast<OpCode>(type);
     53 }
     54 
     55 WebSocketMessageType OpCodeToMessageType(
     56     net::WebSocketFrameHeader::OpCode opCode) {
     57   DCHECK(opCode == net::WebSocketFrameHeader::kOpCodeContinuation ||
     58          opCode == net::WebSocketFrameHeader::kOpCodeText ||
     59          opCode == net::WebSocketFrameHeader::kOpCodeBinary);
     60   // This cast is guaranteed valid by the COMPILE_ASSERT() statements above.
     61   return static_cast<WebSocketMessageType>(opCode);
     62 }
     63 
     64 ChannelState StateCast(WebSocketDispatcherHost::WebSocketHostState host_state) {
     65   const WebSocketDispatcherHost::WebSocketHostState WEBSOCKET_HOST_ALIVE =
     66       WebSocketDispatcherHost::WEBSOCKET_HOST_ALIVE;
     67   const WebSocketDispatcherHost::WebSocketHostState WEBSOCKET_HOST_DELETED =
     68       WebSocketDispatcherHost::WEBSOCKET_HOST_DELETED;
     69 
     70   DCHECK(host_state == WEBSOCKET_HOST_ALIVE ||
     71          host_state == WEBSOCKET_HOST_DELETED);
     72   // These compile asserts verify that we can get away with using static_cast<>
     73   // for the conversion.
     74   COMPILE_ASSERT(static_cast<ChannelState>(WEBSOCKET_HOST_ALIVE) ==
     75                      net::WebSocketEventInterface::CHANNEL_ALIVE,
     76                  enum_values_must_match_for_state_alive);
     77   COMPILE_ASSERT(static_cast<ChannelState>(WEBSOCKET_HOST_DELETED) ==
     78                      net::WebSocketEventInterface::CHANNEL_DELETED,
     79                  enum_values_must_match_for_state_deleted);
     80   return static_cast<ChannelState>(host_state);
     81 }
     82 
     83 // Implementation of net::WebSocketEventInterface. Receives events from our
     84 // WebSocketChannel object. Each event is translated to an IPC and sent to the
     85 // renderer or child process via WebSocketDispatcherHost.
     86 class WebSocketEventHandler : public net::WebSocketEventInterface {
     87  public:
     88   WebSocketEventHandler(WebSocketDispatcherHost* dispatcher,
     89                         int routing_id,
     90                         int render_frame_id);
     91   virtual ~WebSocketEventHandler();
     92 
     93   // net::WebSocketEventInterface implementation
     94 
     95   virtual ChannelState OnAddChannelResponse(
     96       bool fail,
     97       const std::string& selected_subprotocol,
     98       const std::string& extensions) OVERRIDE;
     99   virtual ChannelState OnDataFrame(bool fin,
    100                                    WebSocketMessageType type,
    101                                    const std::vector<char>& data) OVERRIDE;
    102   virtual ChannelState OnClosingHandshake() OVERRIDE;
    103   virtual ChannelState OnFlowControl(int64 quota) OVERRIDE;
    104   virtual ChannelState OnDropChannel(bool was_clean,
    105                                      uint16 code,
    106                                      const std::string& reason) OVERRIDE;
    107   virtual ChannelState OnFailChannel(const std::string& message) OVERRIDE;
    108   virtual ChannelState OnStartOpeningHandshake(
    109       scoped_ptr<net::WebSocketHandshakeRequestInfo> request) OVERRIDE;
    110   virtual ChannelState OnFinishOpeningHandshake(
    111       scoped_ptr<net::WebSocketHandshakeResponseInfo> response) OVERRIDE;
    112   virtual ChannelState OnSSLCertificateError(
    113       scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks,
    114       const GURL& url,
    115       const net::SSLInfo& ssl_info,
    116       bool fatal) OVERRIDE;
    117 
    118  private:
    119   class SSLErrorHandlerDelegate : public SSLErrorHandler::Delegate {
    120    public:
    121     SSLErrorHandlerDelegate(
    122         scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks);
    123     virtual ~SSLErrorHandlerDelegate();
    124 
    125     base::WeakPtr<SSLErrorHandler::Delegate> GetWeakPtr();
    126 
    127     // SSLErrorHandler::Delegate methods
    128     virtual void CancelSSLRequest(const GlobalRequestID& id,
    129                                   int error,
    130                                   const net::SSLInfo* ssl_info) OVERRIDE;
    131     virtual void ContinueSSLRequest(const GlobalRequestID& id) OVERRIDE;
    132 
    133    private:
    134     scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks_;
    135     base::WeakPtrFactory<SSLErrorHandlerDelegate> weak_ptr_factory_;
    136 
    137     DISALLOW_COPY_AND_ASSIGN(SSLErrorHandlerDelegate);
    138   };
    139 
    140   WebSocketDispatcherHost* const dispatcher_;
    141   const int routing_id_;
    142   const int render_frame_id_;
    143   scoped_ptr<SSLErrorHandlerDelegate> ssl_error_handler_delegate_;
    144 
    145   DISALLOW_COPY_AND_ASSIGN(WebSocketEventHandler);
    146 };
    147 
    148 WebSocketEventHandler::WebSocketEventHandler(
    149     WebSocketDispatcherHost* dispatcher,
    150     int routing_id,
    151     int render_frame_id)
    152     : dispatcher_(dispatcher),
    153       routing_id_(routing_id),
    154       render_frame_id_(render_frame_id) {
    155 }
    156 
    157 WebSocketEventHandler::~WebSocketEventHandler() {
    158   DVLOG(1) << "WebSocketEventHandler destroyed routing_id=" << routing_id_;
    159 }
    160 
    161 ChannelState WebSocketEventHandler::OnAddChannelResponse(
    162     bool fail,
    163     const std::string& selected_protocol,
    164     const std::string& extensions) {
    165   DVLOG(3) << "WebSocketEventHandler::OnAddChannelResponse"
    166            << " routing_id=" << routing_id_ << " fail=" << fail
    167            << " selected_protocol=\"" << selected_protocol << "\""
    168            << " extensions=\"" << extensions << "\"";
    169 
    170   return StateCast(dispatcher_->SendAddChannelResponse(
    171       routing_id_, fail, selected_protocol, extensions));
    172 }
    173 
    174 ChannelState WebSocketEventHandler::OnDataFrame(
    175     bool fin,
    176     net::WebSocketFrameHeader::OpCode type,
    177     const std::vector<char>& data) {
    178   DVLOG(3) << "WebSocketEventHandler::OnDataFrame"
    179            << " routing_id=" << routing_id_ << " fin=" << fin
    180            << " type=" << type << " data is " << data.size() << " bytes";
    181 
    182   return StateCast(dispatcher_->SendFrame(
    183       routing_id_, fin, OpCodeToMessageType(type), data));
    184 }
    185 
    186 ChannelState WebSocketEventHandler::OnClosingHandshake() {
    187   DVLOG(3) << "WebSocketEventHandler::OnClosingHandshake"
    188            << " routing_id=" << routing_id_;
    189 
    190   return StateCast(dispatcher_->NotifyClosingHandshake(routing_id_));
    191 }
    192 
    193 ChannelState WebSocketEventHandler::OnFlowControl(int64 quota) {
    194   DVLOG(3) << "WebSocketEventHandler::OnFlowControl"
    195            << " routing_id=" << routing_id_ << " quota=" << quota;
    196 
    197   return StateCast(dispatcher_->SendFlowControl(routing_id_, quota));
    198 }
    199 
    200 ChannelState WebSocketEventHandler::OnDropChannel(bool was_clean,
    201                                                   uint16 code,
    202                                                   const std::string& reason) {
    203   DVLOG(3) << "WebSocketEventHandler::OnDropChannel"
    204            << " routing_id=" << routing_id_ << " was_clean=" << was_clean
    205            << " code=" << code << " reason=\"" << reason << "\"";
    206 
    207   return StateCast(
    208       dispatcher_->DoDropChannel(routing_id_, was_clean, code, reason));
    209 }
    210 
    211 ChannelState WebSocketEventHandler::OnFailChannel(const std::string& message) {
    212   DVLOG(3) << "WebSocketEventHandler::OnFailChannel"
    213            << " routing_id=" << routing_id_
    214            << " message=\"" << message << "\"";
    215 
    216   return StateCast(dispatcher_->NotifyFailure(routing_id_, message));
    217 }
    218 
    219 ChannelState WebSocketEventHandler::OnStartOpeningHandshake(
    220     scoped_ptr<net::WebSocketHandshakeRequestInfo> request) {
    221   bool should_send = dispatcher_->CanReadRawCookies();
    222   DVLOG(3) << "WebSocketEventHandler::OnStartOpeningHandshake "
    223            << "should_send=" << should_send;
    224 
    225   if (!should_send)
    226     return WebSocketEventInterface::CHANNEL_ALIVE;
    227 
    228   WebSocketHandshakeRequest request_to_pass;
    229   request_to_pass.url.Swap(&request->url);
    230   net::HttpRequestHeaders::Iterator it(request->headers);
    231   while (it.GetNext())
    232     request_to_pass.headers.push_back(std::make_pair(it.name(), it.value()));
    233   request_to_pass.headers_text =
    234       base::StringPrintf("GET %s HTTP/1.1\r\n",
    235                          request_to_pass.url.spec().c_str()) +
    236       request->headers.ToString();
    237   request_to_pass.request_time = request->request_time;
    238 
    239   return StateCast(dispatcher_->NotifyStartOpeningHandshake(routing_id_,
    240                                                             request_to_pass));
    241 }
    242 
    243 ChannelState WebSocketEventHandler::OnFinishOpeningHandshake(
    244     scoped_ptr<net::WebSocketHandshakeResponseInfo> response) {
    245   bool should_send = dispatcher_->CanReadRawCookies();
    246   DVLOG(3) << "WebSocketEventHandler::OnFinishOpeningHandshake "
    247            << "should_send=" << should_send;
    248 
    249   if (!should_send)
    250     return WebSocketEventInterface::CHANNEL_ALIVE;
    251 
    252   WebSocketHandshakeResponse response_to_pass;
    253   response_to_pass.url.Swap(&response->url);
    254   response_to_pass.status_code = response->status_code;
    255   response_to_pass.status_text.swap(response->status_text);
    256   void* iter = NULL;
    257   std::string name, value;
    258   while (response->headers->EnumerateHeaderLines(&iter, &name, &value))
    259     response_to_pass.headers.push_back(std::make_pair(name, value));
    260   response_to_pass.headers_text =
    261       net::HttpUtil::ConvertHeadersBackToHTTPResponse(
    262           response->headers->raw_headers());
    263   response_to_pass.response_time = response->response_time;
    264 
    265   return StateCast(dispatcher_->NotifyFinishOpeningHandshake(routing_id_,
    266                                                              response_to_pass));
    267 }
    268 
    269 ChannelState WebSocketEventHandler::OnSSLCertificateError(
    270     scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks,
    271     const GURL& url,
    272     const net::SSLInfo& ssl_info,
    273     bool fatal) {
    274   DVLOG(3) << "WebSocketEventHandler::OnSSLCertificateError"
    275            << " routing_id=" << routing_id_ << " url=" << url.spec()
    276            << " cert_status=" << ssl_info.cert_status << " fatal=" << fatal;
    277   ssl_error_handler_delegate_.reset(
    278       new SSLErrorHandlerDelegate(callbacks.Pass()));
    279   // We don't need request_id to be unique so just make a fake one.
    280   GlobalRequestID request_id(-1, -1);
    281   SSLManager::OnSSLCertificateError(ssl_error_handler_delegate_->GetWeakPtr(),
    282                                     request_id,
    283                                     RESOURCE_TYPE_SUB_RESOURCE,
    284                                     url,
    285                                     dispatcher_->render_process_id(),
    286                                     render_frame_id_,
    287                                     ssl_info,
    288                                     fatal);
    289   // The above method is always asynchronous.
    290   return WebSocketEventInterface::CHANNEL_ALIVE;
    291 }
    292 
    293 WebSocketEventHandler::SSLErrorHandlerDelegate::SSLErrorHandlerDelegate(
    294     scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks)
    295     : callbacks_(callbacks.Pass()), weak_ptr_factory_(this) {}
    296 
    297 WebSocketEventHandler::SSLErrorHandlerDelegate::~SSLErrorHandlerDelegate() {}
    298 
    299 base::WeakPtr<SSLErrorHandler::Delegate>
    300 WebSocketEventHandler::SSLErrorHandlerDelegate::GetWeakPtr() {
    301   return weak_ptr_factory_.GetWeakPtr();
    302 }
    303 
    304 void WebSocketEventHandler::SSLErrorHandlerDelegate::CancelSSLRequest(
    305     const GlobalRequestID& id,
    306     int error,
    307     const net::SSLInfo* ssl_info) {
    308   DVLOG(3) << "SSLErrorHandlerDelegate::CancelSSLRequest"
    309            << " error=" << error
    310            << " cert_status=" << (ssl_info ? ssl_info->cert_status
    311                                            : static_cast<net::CertStatus>(-1));
    312   callbacks_->CancelSSLRequest(error, ssl_info);
    313 }
    314 
    315 void WebSocketEventHandler::SSLErrorHandlerDelegate::ContinueSSLRequest(
    316     const GlobalRequestID& id) {
    317   DVLOG(3) << "SSLErrorHandlerDelegate::ContinueSSLRequest";
    318   callbacks_->ContinueSSLRequest();
    319 }
    320 
    321 }  // namespace
    322 
    323 WebSocketHost::WebSocketHost(int routing_id,
    324                              WebSocketDispatcherHost* dispatcher,
    325                              net::URLRequestContext* url_request_context)
    326     : dispatcher_(dispatcher),
    327       url_request_context_(url_request_context),
    328       routing_id_(routing_id) {
    329   DVLOG(1) << "WebSocketHost: created routing_id=" << routing_id;
    330 }
    331 
    332 WebSocketHost::~WebSocketHost() {}
    333 
    334 void WebSocketHost::GoAway() {
    335   OnDropChannel(false, static_cast<uint16>(net::kWebSocketErrorGoingAway), "");
    336 }
    337 
    338 bool WebSocketHost::OnMessageReceived(const IPC::Message& message) {
    339   bool handled = true;
    340   IPC_BEGIN_MESSAGE_MAP(WebSocketHost, message)
    341     IPC_MESSAGE_HANDLER(WebSocketHostMsg_AddChannelRequest, OnAddChannelRequest)
    342     IPC_MESSAGE_HANDLER(WebSocketMsg_SendFrame, OnSendFrame)
    343     IPC_MESSAGE_HANDLER(WebSocketMsg_FlowControl, OnFlowControl)
    344     IPC_MESSAGE_HANDLER(WebSocketMsg_DropChannel, OnDropChannel)
    345     IPC_MESSAGE_UNHANDLED(handled = false)
    346   IPC_END_MESSAGE_MAP()
    347   return handled;
    348 }
    349 
    350 void WebSocketHost::OnAddChannelRequest(
    351     const GURL& socket_url,
    352     const std::vector<std::string>& requested_protocols,
    353     const url::Origin& origin,
    354     int render_frame_id) {
    355   DVLOG(3) << "WebSocketHost::OnAddChannelRequest"
    356            << " routing_id=" << routing_id_ << " socket_url=\"" << socket_url
    357            << "\" requested_protocols=\""
    358            << JoinString(requested_protocols, ", ") << "\" origin=\""
    359            << origin.string() << "\"";
    360 
    361   DCHECK(!channel_);
    362   scoped_ptr<net::WebSocketEventInterface> event_interface(
    363       new WebSocketEventHandler(dispatcher_, routing_id_, render_frame_id));
    364   channel_.reset(
    365       new net::WebSocketChannel(event_interface.Pass(), url_request_context_));
    366   channel_->SendAddChannelRequest(socket_url, requested_protocols, origin);
    367 }
    368 
    369 void WebSocketHost::OnSendFrame(bool fin,
    370                                 WebSocketMessageType type,
    371                                 const std::vector<char>& data) {
    372   DVLOG(3) << "WebSocketHost::OnSendFrame"
    373            << " routing_id=" << routing_id_ << " fin=" << fin
    374            << " type=" << type << " data is " << data.size() << " bytes";
    375 
    376   DCHECK(channel_);
    377   channel_->SendFrame(fin, MessageTypeToOpCode(type), data);
    378 }
    379 
    380 void WebSocketHost::OnFlowControl(int64 quota) {
    381   DVLOG(3) << "WebSocketHost::OnFlowControl"
    382            << " routing_id=" << routing_id_ << " quota=" << quota;
    383 
    384   DCHECK(channel_);
    385   channel_->SendFlowControl(quota);
    386 }
    387 
    388 void WebSocketHost::OnDropChannel(bool was_clean,
    389                                   uint16 code,
    390                                   const std::string& reason) {
    391   DVLOG(3) << "WebSocketHost::OnDropChannel"
    392            << " routing_id=" << routing_id_ << " was_clean=" << was_clean
    393            << " code=" << code << " reason=\"" << reason << "\"";
    394 
    395   DCHECK(channel_);
    396   // TODO(yhirano): Handle |was_clean| appropriately.
    397   channel_->StartClosingHandshake(code, reason);
    398 }
    399 
    400 }  // namespace content
    401