Home | History | Annotate | Download | only in websockets
      1 // Copyright 2013 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/websockets/websocket_stream.h"
      6 
      7 #include "base/logging.h"
      8 #include "base/memory/scoped_ptr.h"
      9 #include "net/http/http_request_headers.h"
     10 #include "net/http/http_status_code.h"
     11 #include "net/url_request/url_request.h"
     12 #include "net/url_request/url_request_context.h"
     13 #include "net/websockets/websocket_errors.h"
     14 #include "net/websockets/websocket_handshake_constants.h"
     15 #include "net/websockets/websocket_handshake_stream_base.h"
     16 #include "net/websockets/websocket_handshake_stream_create_helper.h"
     17 #include "net/websockets/websocket_test_util.h"
     18 #include "url/gurl.h"
     19 
     20 namespace net {
     21 namespace {
     22 
     23 class StreamRequestImpl;
     24 
     25 class Delegate : public URLRequest::Delegate {
     26  public:
     27   explicit Delegate(StreamRequestImpl* owner) : owner_(owner) {}
     28   virtual ~Delegate() {}
     29 
     30   // Implementation of URLRequest::Delegate methods.
     31   virtual void OnResponseStarted(URLRequest* request) OVERRIDE;
     32 
     33   virtual void OnAuthRequired(URLRequest* request,
     34                               AuthChallengeInfo* auth_info) OVERRIDE;
     35 
     36   virtual void OnCertificateRequested(URLRequest* request,
     37                                       SSLCertRequestInfo* cert_request_info)
     38       OVERRIDE;
     39 
     40   virtual void OnSSLCertificateError(URLRequest* request,
     41                                      const SSLInfo& ssl_info,
     42                                      bool fatal) OVERRIDE;
     43 
     44   virtual void OnReadCompleted(URLRequest* request, int bytes_read) OVERRIDE;
     45 
     46  private:
     47   StreamRequestImpl* owner_;
     48 };
     49 
     50 class StreamRequestImpl : public WebSocketStreamRequest {
     51  public:
     52   StreamRequestImpl(
     53       const GURL& url,
     54       const URLRequestContext* context,
     55       scoped_ptr<WebSocketStream::ConnectDelegate> connect_delegate,
     56       WebSocketHandshakeStreamCreateHelper* create_helper)
     57       : delegate_(new Delegate(this)),
     58         url_request_(url, DEFAULT_PRIORITY, delegate_.get(), context),
     59         connect_delegate_(connect_delegate.Pass()),
     60         create_helper_(create_helper) {}
     61 
     62   // Destroying this object destroys the URLRequest, which cancels the request
     63   // and so terminates the handshake if it is incomplete.
     64   virtual ~StreamRequestImpl() {}
     65 
     66   URLRequest* url_request() { return &url_request_; }
     67 
     68   void PerformUpgrade() {
     69     connect_delegate_->OnSuccess(create_helper_->stream()->Upgrade());
     70   }
     71 
     72   void ReportFailure() {
     73     connect_delegate_->OnFailure(kWebSocketErrorAbnormalClosure);
     74   }
     75 
     76  private:
     77   // |delegate_| needs to be declared before |url_request_| so that it gets
     78   // initialised first.
     79   scoped_ptr<Delegate> delegate_;
     80 
     81   // Deleting the StreamRequestImpl object deletes this URLRequest object,
     82   // cancelling the whole connection.
     83   URLRequest url_request_;
     84 
     85   scoped_ptr<WebSocketStream::ConnectDelegate> connect_delegate_;
     86 
     87   // Owned by the URLRequest.
     88   WebSocketHandshakeStreamCreateHelper* create_helper_;
     89 };
     90 
     91 void Delegate::OnResponseStarted(URLRequest* request) {
     92   switch (request->GetResponseCode()) {
     93     case HTTP_SWITCHING_PROTOCOLS:
     94       owner_->PerformUpgrade();
     95       return;
     96 
     97     case HTTP_UNAUTHORIZED:
     98     case HTTP_PROXY_AUTHENTICATION_REQUIRED:
     99       return;
    100 
    101     default:
    102       owner_->ReportFailure();
    103   }
    104 }
    105 
    106 void Delegate::OnAuthRequired(URLRequest* request,
    107                               AuthChallengeInfo* auth_info) {
    108   request->CancelAuth();
    109 }
    110 
    111 void Delegate::OnCertificateRequested(URLRequest* request,
    112                                       SSLCertRequestInfo* cert_request_info) {
    113   request->ContinueWithCertificate(NULL);
    114 }
    115 
    116 void Delegate::OnSSLCertificateError(URLRequest* request,
    117                                      const SSLInfo& ssl_info,
    118                                      bool fatal) {
    119   request->Cancel();
    120 }
    121 
    122 void Delegate::OnReadCompleted(URLRequest* request, int bytes_read) {
    123   NOTREACHED();
    124 }
    125 
    126 // Internal implementation of CreateAndConnectStream and
    127 // CreateAndConnectStreamForTesting.
    128 scoped_ptr<WebSocketStreamRequest> CreateAndConnectStreamWithCreateHelper(
    129     const GURL& socket_url,
    130     scoped_ptr<WebSocketHandshakeStreamCreateHelper> create_helper,
    131     const GURL& origin,
    132     URLRequestContext* url_request_context,
    133     const BoundNetLog& net_log,
    134     scoped_ptr<WebSocketStream::ConnectDelegate> connect_delegate) {
    135   scoped_ptr<StreamRequestImpl> request(
    136       new StreamRequestImpl(socket_url,
    137                             url_request_context,
    138                             connect_delegate.Pass(),
    139                             create_helper.get()));
    140   HttpRequestHeaders headers;
    141   headers.SetHeader(websockets::kUpgrade, websockets::kWebSocketLowercase);
    142   headers.SetHeader(HttpRequestHeaders::kConnection, websockets::kUpgrade);
    143   headers.SetHeader(HttpRequestHeaders::kOrigin, origin.spec());
    144   // TODO(ricea): Move the version number to websocket_handshake_constants.h
    145   headers.SetHeader(websockets::kSecWebSocketVersion,
    146                     websockets::kSupportedVersion);
    147   request->url_request()->SetExtraRequestHeaders(headers);
    148   request->url_request()->SetUserData(
    149       WebSocketHandshakeStreamBase::CreateHelper::DataKey(),
    150       create_helper.release());
    151   request->url_request()->SetLoadFlags(LOAD_DISABLE_CACHE |
    152                                        LOAD_DO_NOT_PROMPT_FOR_LOGIN);
    153   request->url_request()->Start();
    154   return request.PassAs<WebSocketStreamRequest>();
    155 }
    156 
    157 }  // namespace
    158 
    159 WebSocketStreamRequest::~WebSocketStreamRequest() {}
    160 
    161 WebSocketStream::WebSocketStream() {}
    162 WebSocketStream::~WebSocketStream() {}
    163 
    164 WebSocketStream::ConnectDelegate::~ConnectDelegate() {}
    165 
    166 scoped_ptr<WebSocketStreamRequest> WebSocketStream::CreateAndConnectStream(
    167     const GURL& socket_url,
    168     const std::vector<std::string>& requested_subprotocols,
    169     const GURL& origin,
    170     URLRequestContext* url_request_context,
    171     const BoundNetLog& net_log,
    172     scoped_ptr<ConnectDelegate> connect_delegate) {
    173   scoped_ptr<WebSocketHandshakeStreamCreateHelper> create_helper(
    174       new WebSocketHandshakeStreamCreateHelper(requested_subprotocols));
    175   return CreateAndConnectStreamWithCreateHelper(socket_url,
    176                                                 create_helper.Pass(),
    177                                                 origin,
    178                                                 url_request_context,
    179                                                 net_log,
    180                                                 connect_delegate.Pass());
    181 }
    182 
    183 // This is declared in websocket_test_util.h.
    184 scoped_ptr<WebSocketStreamRequest> CreateAndConnectStreamForTesting(
    185       const GURL& socket_url,
    186       scoped_ptr<WebSocketHandshakeStreamCreateHelper> create_helper,
    187       const GURL& origin,
    188       URLRequestContext* url_request_context,
    189       const BoundNetLog& net_log,
    190       scoped_ptr<WebSocketStream::ConnectDelegate> connect_delegate) {
    191   return CreateAndConnectStreamWithCreateHelper(socket_url,
    192                                                 create_helper.Pass(),
    193                                                 origin,
    194                                                 url_request_context,
    195                                                 net_log,
    196                                                 connect_delegate.Pass());
    197 }
    198 
    199 }  // namespace net
    200