1 // Copyright (c) 2011 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "net/websockets/websocket_handshake.h" 6 7 #include <algorithm> 8 #include <vector> 9 10 #include "base/logging.h" 11 #include "base/md5.h" 12 #include "base/memory/ref_counted.h" 13 #include "base/rand_util.h" 14 #include "base/string_number_conversions.h" 15 #include "base/string_util.h" 16 #include "base/stringprintf.h" 17 #include "net/http/http_response_headers.h" 18 #include "net/http/http_util.h" 19 20 namespace net { 21 22 const int WebSocketHandshake::kWebSocketPort = 80; 23 const int WebSocketHandshake::kSecureWebSocketPort = 443; 24 25 WebSocketHandshake::WebSocketHandshake( 26 const GURL& url, 27 const std::string& origin, 28 const std::string& location, 29 const std::string& protocol) 30 : url_(url), 31 origin_(origin), 32 location_(location), 33 protocol_(protocol), 34 mode_(MODE_INCOMPLETE) { 35 } 36 37 WebSocketHandshake::~WebSocketHandshake() { 38 } 39 40 bool WebSocketHandshake::is_secure() const { 41 return url_.SchemeIs("wss"); 42 } 43 44 std::string WebSocketHandshake::CreateClientHandshakeMessage() { 45 if (!parameter_.get()) { 46 parameter_.reset(new Parameter); 47 parameter_->GenerateKeys(); 48 } 49 std::string msg; 50 51 // WebSocket protocol 4.1 Opening handshake. 52 53 msg = "GET "; 54 msg += GetResourceName(); 55 msg += " HTTP/1.1\r\n"; 56 57 std::vector<std::string> fields; 58 59 fields.push_back("Upgrade: WebSocket"); 60 fields.push_back("Connection: Upgrade"); 61 62 fields.push_back("Host: " + GetHostFieldValue()); 63 64 fields.push_back("Origin: " + GetOriginFieldValue()); 65 66 if (!protocol_.empty()) 67 fields.push_back("Sec-WebSocket-Protocol: " + protocol_); 68 69 // TODO(ukai): Add cookie if necessary. 70 71 fields.push_back("Sec-WebSocket-Key1: " + parameter_->GetSecWebSocketKey1()); 72 fields.push_back("Sec-WebSocket-Key2: " + parameter_->GetSecWebSocketKey2()); 73 74 std::random_shuffle(fields.begin(), fields.end(), base::RandGenerator); 75 76 for (size_t i = 0; i < fields.size(); i++) { 77 msg += fields[i] + "\r\n"; 78 } 79 msg += "\r\n"; 80 81 msg.append(parameter_->GetKey3()); 82 return msg; 83 } 84 85 int WebSocketHandshake::ReadServerHandshake(const char* data, size_t len) { 86 mode_ = MODE_INCOMPLETE; 87 int eoh = HttpUtil::LocateEndOfHeaders(data, len); 88 if (eoh < 0) 89 return -1; 90 91 scoped_refptr<HttpResponseHeaders> headers( 92 new HttpResponseHeaders(HttpUtil::AssembleRawHeaders(data, eoh))); 93 94 if (headers->response_code() != 101) { 95 mode_ = MODE_FAILED; 96 DVLOG(1) << "Bad response code: " << headers->response_code(); 97 return eoh; 98 } 99 mode_ = MODE_NORMAL; 100 if (!ProcessHeaders(*headers) || !CheckResponseHeaders()) { 101 DVLOG(1) << "Process Headers failed: " << std::string(data, eoh); 102 mode_ = MODE_FAILED; 103 return eoh; 104 } 105 if (len < static_cast<size_t>(eoh + Parameter::kExpectedResponseSize)) { 106 mode_ = MODE_INCOMPLETE; 107 return -1; 108 } 109 uint8 expected[Parameter::kExpectedResponseSize]; 110 parameter_->GetExpectedResponse(expected); 111 if (memcmp(&data[eoh], expected, Parameter::kExpectedResponseSize)) { 112 mode_ = MODE_FAILED; 113 return eoh + Parameter::kExpectedResponseSize; 114 } 115 mode_ = MODE_CONNECTED; 116 return eoh + Parameter::kExpectedResponseSize; 117 } 118 119 std::string WebSocketHandshake::GetResourceName() const { 120 std::string resource_name = url_.path(); 121 if (url_.has_query()) { 122 resource_name += "?"; 123 resource_name += url_.query(); 124 } 125 return resource_name; 126 } 127 128 std::string WebSocketHandshake::GetHostFieldValue() const { 129 // url_.host() is expected to be encoded in punnycode here. 130 std::string host = StringToLowerASCII(url_.host()); 131 if (url_.has_port()) { 132 bool secure = is_secure(); 133 int port = url_.EffectiveIntPort(); 134 if ((!secure && 135 port != kWebSocketPort && port != url_parse::PORT_UNSPECIFIED) || 136 (secure && 137 port != kSecureWebSocketPort && port != url_parse::PORT_UNSPECIFIED)) { 138 host += ":"; 139 host += base::IntToString(port); 140 } 141 } 142 return host; 143 } 144 145 std::string WebSocketHandshake::GetOriginFieldValue() const { 146 // It's OK to lowercase the origin as the Origin header does not contain 147 // the path or query portions, as per 148 // http://tools.ietf.org/html/draft-abarth-origin-00. 149 // 150 // TODO(satorux): Should we trim the port portion here if it's 80 for 151 // http:// or 443 for https:// ? Or can we assume it's done by the 152 // client of the library? 153 return StringToLowerASCII(origin_); 154 } 155 156 /* static */ 157 bool WebSocketHandshake::GetSingleHeader(const HttpResponseHeaders& headers, 158 const std::string& name, 159 std::string* value) { 160 std::string first_value; 161 void* iter = NULL; 162 if (!headers.EnumerateHeader(&iter, name, &first_value)) 163 return false; 164 165 // Checks no more |name| found in |headers|. 166 // Second call of EnumerateHeader() must return false. 167 std::string second_value; 168 if (headers.EnumerateHeader(&iter, name, &second_value)) 169 return false; 170 *value = first_value; 171 return true; 172 } 173 174 bool WebSocketHandshake::ProcessHeaders(const HttpResponseHeaders& headers) { 175 std::string value; 176 if (!GetSingleHeader(headers, "upgrade", &value) || 177 value != "WebSocket") 178 return false; 179 180 if (!GetSingleHeader(headers, "connection", &value) || 181 !LowerCaseEqualsASCII(value, "upgrade")) 182 return false; 183 184 if (!GetSingleHeader(headers, "sec-websocket-origin", &ws_origin_)) 185 return false; 186 187 if (!GetSingleHeader(headers, "sec-websocket-location", &ws_location_)) 188 return false; 189 190 // If |protocol_| is not specified by client, we don't care if there's 191 // protocol field or not as specified in the spec. 192 if (!protocol_.empty() 193 && !GetSingleHeader(headers, "sec-websocket-protocol", &ws_protocol_)) 194 return false; 195 return true; 196 } 197 198 bool WebSocketHandshake::CheckResponseHeaders() const { 199 DCHECK(mode_ == MODE_NORMAL); 200 if (!LowerCaseEqualsASCII(origin_, ws_origin_.c_str())) 201 return false; 202 if (location_ != ws_location_) 203 return false; 204 if (!protocol_.empty() && protocol_ != ws_protocol_) 205 return false; 206 return true; 207 } 208 209 namespace { 210 211 // unsigned int version of base::RandInt(). 212 // we can't use base::RandInt(), because max would be negative if it is 213 // represented as int, so DCHECK(min <= max) fails. 214 uint32 RandUint32(uint32 min, uint32 max) { 215 DCHECK(min <= max); 216 217 uint64 range = static_cast<int64>(max) - min + 1; 218 uint64 number = base::RandUint64(); 219 // TODO(ukai): fix to be uniform. 220 // the distribution of the result of modulo will be biased. 221 uint32 result = min + static_cast<uint32>(number % range); 222 DCHECK(result >= min && result <= max); 223 return result; 224 } 225 226 } 227 228 uint32 (*WebSocketHandshake::Parameter::rand_)(uint32 min, uint32 max) = 229 RandUint32; 230 uint8 randomCharacterInSecWebSocketKey[0x2F - 0x20 + 0x7E - 0x39]; 231 232 WebSocketHandshake::Parameter::Parameter() 233 : number_1_(0), number_2_(0) { 234 if (randomCharacterInSecWebSocketKey[0] == '\0') { 235 int i = 0; 236 for (int ch = 0x21; ch <= 0x2F; ch++, i++) 237 randomCharacterInSecWebSocketKey[i] = ch; 238 for (int ch = 0x3A; ch <= 0x7E; ch++, i++) 239 randomCharacterInSecWebSocketKey[i] = ch; 240 } 241 } 242 243 WebSocketHandshake::Parameter::~Parameter() {} 244 245 void WebSocketHandshake::Parameter::GenerateKeys() { 246 GenerateSecWebSocketKey(&number_1_, &key_1_); 247 GenerateSecWebSocketKey(&number_2_, &key_2_); 248 GenerateKey3(); 249 } 250 251 static void SetChallengeNumber(uint8* buf, uint32 number) { 252 uint8* p = buf + 3; 253 for (int i = 0; i < 4; i++) { 254 *p = (uint8)(number & 0xFF); 255 --p; 256 number >>= 8; 257 } 258 } 259 260 void WebSocketHandshake::Parameter::GetExpectedResponse(uint8 *expected) const { 261 uint8 challenge[kExpectedResponseSize]; 262 SetChallengeNumber(&challenge[0], number_1_); 263 SetChallengeNumber(&challenge[4], number_2_); 264 memcpy(&challenge[8], key_3_.data(), kKey3Size); 265 MD5Digest digest; 266 MD5Sum(challenge, kExpectedResponseSize, &digest); 267 memcpy(expected, digest.a, kExpectedResponseSize); 268 } 269 270 /* static */ 271 void WebSocketHandshake::Parameter::SetRandomNumberGenerator( 272 uint32 (*rand)(uint32 min, uint32 max)) { 273 rand_ = rand; 274 } 275 276 void WebSocketHandshake::Parameter::GenerateSecWebSocketKey( 277 uint32* number, std::string* key) { 278 uint32 space = rand_(1, 12); 279 uint32 max = 4294967295U / space; 280 *number = rand_(0, max); 281 uint32 product = *number * space; 282 283 std::string s = base::StringPrintf("%u", product); 284 int n = rand_(1, 12); 285 for (int i = 0; i < n; i++) { 286 int pos = rand_(0, s.length()); 287 int chpos = rand_(0, sizeof(randomCharacterInSecWebSocketKey) - 1); 288 s = s.substr(0, pos).append(1, randomCharacterInSecWebSocketKey[chpos]) + 289 s.substr(pos); 290 } 291 for (uint32 i = 0; i < space; i++) { 292 int pos = rand_(1, s.length() - 1); 293 s = s.substr(0, pos) + " " + s.substr(pos); 294 } 295 *key = s; 296 } 297 298 void WebSocketHandshake::Parameter::GenerateKey3() { 299 key_3_.clear(); 300 for (int i = 0; i < 8; i++) { 301 key_3_.append(1, rand_(0, 255)); 302 } 303 } 304 305 } // namespace net 306