Home | History | Annotate | Download | only in network
      1 // Copyright 2014 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "mojo/services/network/web_socket_impl.h"
      6 
      7 #include "base/logging.h"
      8 #include "base/message_loop/message_loop.h"
      9 #include "mojo/common/handle_watcher.h"
     10 #include "mojo/services/network/network_context.h"
     11 #include "mojo/services/public/cpp/network/web_socket_read_queue.h"
     12 #include "mojo/services/public/cpp/network/web_socket_write_queue.h"
     13 #include "net/websockets/websocket_channel.h"
     14 #include "net/websockets/websocket_errors.h"
     15 #include "net/websockets/websocket_event_interface.h"
     16 #include "net/websockets/websocket_frame.h"  // for WebSocketFrameHeader::OpCode
     17 #include "net/websockets/websocket_handshake_request_info.h"
     18 #include "net/websockets/websocket_handshake_response_info.h"
     19 #include "url/origin.h"
     20 
     21 namespace mojo {
     22 
     23 template <>
     24 struct TypeConverter<net::WebSocketFrameHeader::OpCode,
     25                      WebSocket::MessageType> {
     26   static net::WebSocketFrameHeader::OpCode Convert(
     27       WebSocket::MessageType type) {
     28     DCHECK(type == WebSocket::MESSAGE_TYPE_CONTINUATION ||
     29            type == WebSocket::MESSAGE_TYPE_TEXT ||
     30            type == WebSocket::MESSAGE_TYPE_BINARY);
     31     typedef net::WebSocketFrameHeader::OpCode OpCode;
     32     // These compile asserts verify that the same underlying values are used for
     33     // both types, so we can simply cast between them.
     34     COMPILE_ASSERT(static_cast<OpCode>(WebSocket::MESSAGE_TYPE_CONTINUATION) ==
     35                        net::WebSocketFrameHeader::kOpCodeContinuation,
     36                    enum_values_must_match_for_opcode_continuation);
     37     COMPILE_ASSERT(static_cast<OpCode>(WebSocket::MESSAGE_TYPE_TEXT) ==
     38                        net::WebSocketFrameHeader::kOpCodeText,
     39                    enum_values_must_match_for_opcode_text);
     40     COMPILE_ASSERT(static_cast<OpCode>(WebSocket::MESSAGE_TYPE_BINARY) ==
     41                        net::WebSocketFrameHeader::kOpCodeBinary,
     42                    enum_values_must_match_for_opcode_binary);
     43     return static_cast<OpCode>(type);
     44   }
     45 };
     46 
     47 template <>
     48 struct TypeConverter<WebSocket::MessageType,
     49                      net::WebSocketFrameHeader::OpCode> {
     50   static WebSocket::MessageType Convert(
     51       net::WebSocketFrameHeader::OpCode type) {
     52     DCHECK(type == net::WebSocketFrameHeader::kOpCodeContinuation ||
     53            type == net::WebSocketFrameHeader::kOpCodeText ||
     54            type == net::WebSocketFrameHeader::kOpCodeBinary);
     55     return static_cast<WebSocket::MessageType>(type);
     56   }
     57 };
     58 
     59 namespace {
     60 
     61 typedef net::WebSocketEventInterface::ChannelState ChannelState;
     62 
     63 struct WebSocketEventHandler : public net::WebSocketEventInterface {
     64  public:
     65   WebSocketEventHandler(WebSocketClientPtr client)
     66       : client_(client.Pass()) {
     67   }
     68   virtual ~WebSocketEventHandler() {}
     69 
     70  private:
     71   // net::WebSocketEventInterface methods:
     72   virtual ChannelState OnAddChannelResponse(
     73       bool fail,
     74       const std::string& selected_subprotocol,
     75       const std::string& extensions) OVERRIDE;
     76   virtual ChannelState OnDataFrame(bool fin,
     77                                    WebSocketMessageType type,
     78                                    const std::vector<char>& data) OVERRIDE;
     79   virtual ChannelState OnClosingHandshake() OVERRIDE;
     80   virtual ChannelState OnFlowControl(int64 quota) OVERRIDE;
     81   virtual ChannelState OnDropChannel(bool was_clean,
     82                                      uint16 code,
     83                                      const std::string& reason) OVERRIDE;
     84   virtual ChannelState OnFailChannel(const std::string& message) OVERRIDE;
     85   virtual ChannelState OnStartOpeningHandshake(
     86       scoped_ptr<net::WebSocketHandshakeRequestInfo> request) OVERRIDE;
     87   virtual ChannelState OnFinishOpeningHandshake(
     88       scoped_ptr<net::WebSocketHandshakeResponseInfo> response) OVERRIDE;
     89   virtual ChannelState OnSSLCertificateError(
     90       scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks,
     91       const GURL& url,
     92       const net::SSLInfo& ssl_info,
     93       bool fatal) OVERRIDE;
     94 
     95   // Called once we've written to |receive_stream_|.
     96   void DidWriteToReceiveStream(bool fin,
     97                                net::WebSocketFrameHeader::OpCode type,
     98                                uint32_t num_bytes,
     99                                const char* buffer);
    100   WebSocketClientPtr client_;
    101   ScopedDataPipeProducerHandle receive_stream_;
    102   scoped_ptr<WebSocketWriteQueue> write_queue_;
    103 
    104   DISALLOW_COPY_AND_ASSIGN(WebSocketEventHandler);
    105 };
    106 
    107 ChannelState WebSocketEventHandler::OnAddChannelResponse(
    108     bool fail,
    109     const std::string& selected_protocol,
    110     const std::string& extensions) {
    111   DataPipe data_pipe;
    112   receive_stream_ = data_pipe.producer_handle.Pass();
    113   write_queue_.reset(new WebSocketWriteQueue(receive_stream_.get()));
    114   client_->DidConnect(
    115       fail, selected_protocol, extensions, data_pipe.consumer_handle.Pass());
    116   if (fail)
    117     return WebSocketEventInterface::CHANNEL_DELETED;
    118   return WebSocketEventInterface::CHANNEL_ALIVE;
    119 }
    120 
    121 ChannelState WebSocketEventHandler::OnDataFrame(
    122     bool fin,
    123     net::WebSocketFrameHeader::OpCode type,
    124     const std::vector<char>& data) {
    125   uint32_t size = static_cast<uint32_t>(data.size());
    126   write_queue_->Write(
    127       &data[0], size,
    128       base::Bind(&WebSocketEventHandler::DidWriteToReceiveStream,
    129                  base::Unretained(this),
    130                  fin, type, size));
    131   return WebSocketEventInterface::CHANNEL_ALIVE;
    132 }
    133 
    134 ChannelState WebSocketEventHandler::OnClosingHandshake() {
    135   return WebSocketEventInterface::CHANNEL_ALIVE;
    136 }
    137 
    138 ChannelState WebSocketEventHandler::OnFlowControl(int64 quota) {
    139   client_->DidReceiveFlowControl(quota);
    140   return WebSocketEventInterface::CHANNEL_ALIVE;
    141 }
    142 
    143 ChannelState WebSocketEventHandler::OnDropChannel(bool was_clean,
    144                                                   uint16 code,
    145                                                   const std::string& reason) {
    146   client_->DidClose(was_clean, code, reason);
    147   return WebSocketEventInterface::CHANNEL_DELETED;
    148 }
    149 
    150 ChannelState WebSocketEventHandler::OnFailChannel(const std::string& message) {
    151   client_->DidFail(message);
    152   return WebSocketEventInterface::CHANNEL_DELETED;
    153 }
    154 
    155 ChannelState WebSocketEventHandler::OnStartOpeningHandshake(
    156     scoped_ptr<net::WebSocketHandshakeRequestInfo> request) {
    157   return WebSocketEventInterface::CHANNEL_ALIVE;
    158 }
    159 
    160 ChannelState WebSocketEventHandler::OnFinishOpeningHandshake(
    161     scoped_ptr<net::WebSocketHandshakeResponseInfo> response) {
    162   return WebSocketEventInterface::CHANNEL_ALIVE;
    163 }
    164 
    165 ChannelState WebSocketEventHandler::OnSSLCertificateError(
    166     scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks,
    167     const GURL& url,
    168     const net::SSLInfo& ssl_info,
    169     bool fatal) {
    170   client_->DidFail("SSL Error");
    171   return WebSocketEventInterface::CHANNEL_DELETED;
    172 }
    173 
    174 void WebSocketEventHandler::DidWriteToReceiveStream(
    175     bool fin,
    176     net::WebSocketFrameHeader::OpCode type,
    177     uint32_t num_bytes,
    178     const char* buffer) {
    179   client_->DidReceiveData(
    180       fin, ConvertTo<WebSocket::MessageType>(type), num_bytes);
    181 }
    182 
}  // namespace
    184 
    185 WebSocketImpl::WebSocketImpl(NetworkContext* context) : context_(context) {
    186 }
    187 
    188 WebSocketImpl::~WebSocketImpl() {
    189 }
    190 
    191 void WebSocketImpl::Connect(const String& url,
    192                             Array<String> protocols,
    193                             const String& origin,
    194                             ScopedDataPipeConsumerHandle send_stream,
    195                             WebSocketClientPtr client) {
    196   DCHECK(!channel_);
    197   send_stream_ = send_stream.Pass();
    198   read_queue_.reset(new WebSocketReadQueue(send_stream_.get()));
    199   scoped_ptr<net::WebSocketEventInterface> event_interface(
    200       new WebSocketEventHandler(client.Pass()));
    201   channel_.reset(new net::WebSocketChannel(event_interface.Pass(),
    202                                            context_->url_request_context()));
    203   channel_->SendAddChannelRequest(GURL(url.get()),
    204                                   protocols.To<std::vector<std::string> >(),
    205                                   url::Origin(origin.get()));
    206 }
    207 
    208 void WebSocketImpl::Send(bool fin,
    209                          WebSocket::MessageType type,
    210                          uint32_t num_bytes) {
    211   DCHECK(channel_);
    212   read_queue_->Read(num_bytes,
    213                     base::Bind(&WebSocketImpl::DidReadFromSendStream,
    214                                base::Unretained(this),
    215                                fin, type, num_bytes));
    216 }
    217 
    218 void WebSocketImpl::FlowControl(int64_t quota) {
    219   DCHECK(channel_);
    220   channel_->SendFlowControl(quota);
    221 }
    222 
    223 void WebSocketImpl::Close(uint16_t code, const String& reason) {
    224   DCHECK(channel_);
    225   channel_->StartClosingHandshake(code, reason);
    226 }
    227 
    228 void WebSocketImpl::DidReadFromSendStream(bool fin,
    229                                           WebSocket::MessageType type,
    230                                           uint32_t num_bytes,
    231                                           const char* data) {
    232   std::vector<char> buffer(num_bytes);
    233   memcpy(&buffer[0], data, num_bytes);
    234   DCHECK(channel_);
    235   channel_->SendFrame(
    236       fin, ConvertTo<net::WebSocketFrameHeader::OpCode>(type), buffer);
    237 }
    238 
    239 }  // namespace mojo
    240