1 // Copyright 2013 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "net/websockets/websocket_stream.h" 6 7 #include "base/logging.h" 8 #include "base/memory/scoped_ptr.h" 9 #include "net/http/http_request_headers.h" 10 #include "net/http/http_status_code.h" 11 #include "net/url_request/url_request.h" 12 #include "net/url_request/url_request_context.h" 13 #include "net/websockets/websocket_errors.h" 14 #include "net/websockets/websocket_handshake_constants.h" 15 #include "net/websockets/websocket_handshake_stream_base.h" 16 #include "net/websockets/websocket_handshake_stream_create_helper.h" 17 #include "net/websockets/websocket_test_util.h" 18 #include "url/gurl.h" 19 20 namespace net { 21 namespace { 22 23 class StreamRequestImpl; 24 25 class Delegate : public URLRequest::Delegate { 26 public: 27 explicit Delegate(StreamRequestImpl* owner) : owner_(owner) {} 28 virtual ~Delegate() {} 29 30 // Implementation of URLRequest::Delegate methods. 31 virtual void OnResponseStarted(URLRequest* request) OVERRIDE; 32 33 virtual void OnAuthRequired(URLRequest* request, 34 AuthChallengeInfo* auth_info) OVERRIDE; 35 36 virtual void OnCertificateRequested(URLRequest* request, 37 SSLCertRequestInfo* cert_request_info) 38 OVERRIDE; 39 40 virtual void OnSSLCertificateError(URLRequest* request, 41 const SSLInfo& ssl_info, 42 bool fatal) OVERRIDE; 43 44 virtual void OnReadCompleted(URLRequest* request, int bytes_read) OVERRIDE; 45 46 private: 47 StreamRequestImpl* owner_; 48 }; 49 50 class StreamRequestImpl : public WebSocketStreamRequest { 51 public: 52 StreamRequestImpl( 53 const GURL& url, 54 const URLRequestContext* context, 55 scoped_ptr<WebSocketStream::ConnectDelegate> connect_delegate, 56 WebSocketHandshakeStreamCreateHelper* create_helper) 57 : delegate_(new Delegate(this)), 58 url_request_(url, DEFAULT_PRIORITY, delegate_.get(), context), 59 connect_delegate_(connect_delegate.Pass()), 60 create_helper_(create_helper) {} 61 62 // Destroying this object destroys the URLRequest, which cancels the request 63 // and so terminates the handshake if it is incomplete. 64 virtual ~StreamRequestImpl() {} 65 66 URLRequest* url_request() { return &url_request_; } 67 68 void PerformUpgrade() { 69 connect_delegate_->OnSuccess(create_helper_->stream()->Upgrade()); 70 } 71 72 void ReportFailure() { 73 connect_delegate_->OnFailure(kWebSocketErrorAbnormalClosure); 74 } 75 76 private: 77 // |delegate_| needs to be declared before |url_request_| so that it gets 78 // initialised first. 79 scoped_ptr<Delegate> delegate_; 80 81 // Deleting the StreamRequestImpl object deletes this URLRequest object, 82 // cancelling the whole connection. 83 URLRequest url_request_; 84 85 scoped_ptr<WebSocketStream::ConnectDelegate> connect_delegate_; 86 87 // Owned by the URLRequest. 88 WebSocketHandshakeStreamCreateHelper* create_helper_; 89 }; 90 91 void Delegate::OnResponseStarted(URLRequest* request) { 92 switch (request->GetResponseCode()) { 93 case HTTP_SWITCHING_PROTOCOLS: 94 owner_->PerformUpgrade(); 95 return; 96 97 case HTTP_UNAUTHORIZED: 98 case HTTP_PROXY_AUTHENTICATION_REQUIRED: 99 return; 100 101 default: 102 owner_->ReportFailure(); 103 } 104 } 105 106 void Delegate::OnAuthRequired(URLRequest* request, 107 AuthChallengeInfo* auth_info) { 108 request->CancelAuth(); 109 } 110 111 void Delegate::OnCertificateRequested(URLRequest* request, 112 SSLCertRequestInfo* cert_request_info) { 113 request->ContinueWithCertificate(NULL); 114 } 115 116 void Delegate::OnSSLCertificateError(URLRequest* request, 117 const SSLInfo& ssl_info, 118 bool fatal) { 119 request->Cancel(); 120 } 121 122 void Delegate::OnReadCompleted(URLRequest* request, int bytes_read) { 123 NOTREACHED(); 124 } 125 126 // Internal implementation of CreateAndConnectStream and 127 // CreateAndConnectStreamForTesting. 128 scoped_ptr<WebSocketStreamRequest> CreateAndConnectStreamWithCreateHelper( 129 const GURL& socket_url, 130 scoped_ptr<WebSocketHandshakeStreamCreateHelper> create_helper, 131 const GURL& origin, 132 URLRequestContext* url_request_context, 133 const BoundNetLog& net_log, 134 scoped_ptr<WebSocketStream::ConnectDelegate> connect_delegate) { 135 scoped_ptr<StreamRequestImpl> request( 136 new StreamRequestImpl(socket_url, 137 url_request_context, 138 connect_delegate.Pass(), 139 create_helper.get())); 140 HttpRequestHeaders headers; 141 headers.SetHeader(websockets::kUpgrade, websockets::kWebSocketLowercase); 142 headers.SetHeader(HttpRequestHeaders::kConnection, websockets::kUpgrade); 143 headers.SetHeader(HttpRequestHeaders::kOrigin, origin.spec()); 144 // TODO(ricea): Move the version number to websocket_handshake_constants.h 145 headers.SetHeader(websockets::kSecWebSocketVersion, 146 websockets::kSupportedVersion); 147 request->url_request()->SetExtraRequestHeaders(headers); 148 request->url_request()->SetUserData( 149 WebSocketHandshakeStreamBase::CreateHelper::DataKey(), 150 create_helper.release()); 151 request->url_request()->SetLoadFlags(LOAD_DISABLE_CACHE | 152 LOAD_DO_NOT_PROMPT_FOR_LOGIN); 153 request->url_request()->Start(); 154 return request.PassAs<WebSocketStreamRequest>(); 155 } 156 157 } // namespace 158 159 WebSocketStreamRequest::~WebSocketStreamRequest() {} 160 161 WebSocketStream::WebSocketStream() {} 162 WebSocketStream::~WebSocketStream() {} 163 164 WebSocketStream::ConnectDelegate::~ConnectDelegate() {} 165 166 scoped_ptr<WebSocketStreamRequest> WebSocketStream::CreateAndConnectStream( 167 const GURL& socket_url, 168 const std::vector<std::string>& requested_subprotocols, 169 const GURL& origin, 170 URLRequestContext* url_request_context, 171 const BoundNetLog& net_log, 172 scoped_ptr<ConnectDelegate> connect_delegate) { 173 scoped_ptr<WebSocketHandshakeStreamCreateHelper> create_helper( 174 new WebSocketHandshakeStreamCreateHelper(requested_subprotocols)); 175 return CreateAndConnectStreamWithCreateHelper(socket_url, 176 create_helper.Pass(), 177 origin, 178 url_request_context, 179 net_log, 180 connect_delegate.Pass()); 181 } 182 183 // This is declared in websocket_test_util.h. 184 scoped_ptr<WebSocketStreamRequest> CreateAndConnectStreamForTesting( 185 const GURL& socket_url, 186 scoped_ptr<WebSocketHandshakeStreamCreateHelper> create_helper, 187 const GURL& origin, 188 URLRequestContext* url_request_context, 189 const BoundNetLog& net_log, 190 scoped_ptr<WebSocketStream::ConnectDelegate> connect_delegate) { 191 return CreateAndConnectStreamWithCreateHelper(socket_url, 192 create_helper.Pass(), 193 origin, 194 url_request_context, 195 net_log, 196 connect_delegate.Pass()); 197 } 198 199 } // namespace net 200