Home | History | Annotate | Download | only in websockets
      1 /*
      2  * Copyright (C) 2011 Google Inc.  All rights reserved.
      3  * Copyright (C) Research In Motion Limited 2011. All rights reserved.
      4  *
      5  * Redistribution and use in source and binary forms, with or without
      6  * modification, are permitted provided that the following conditions are
      7  * met:
      8  *
      9  *     * Redistributions of source code must retain the above copyright
     10  * notice, this list of conditions and the following disclaimer.
     11  *     * Redistributions in binary form must reproduce the above
     12  * copyright notice, this list of conditions and the following disclaimer
     13  * in the documentation and/or other materials provided with the
     14  * distribution.
     15  *     * Neither the name of Google Inc. nor the names of its
     16  * contributors may be used to endorse or promote products derived from
     17  * this software without specific prior written permission.
     18  *
     19  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
     20  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
     21  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
     22  * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
     23  * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
     24  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
     25  * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
     26  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
     27  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
     28  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
     29  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
     30  */
     31 
     32 #include "config.h"
     33 
     34 #include "modules/websockets/WebSocketHandshake.h"
     35 
     36 #include "core/dom/Document.h"
     37 #include "core/inspector/ScriptCallStack.h"
     38 #include "core/loader/CookieJar.h"
     39 #include "modules/websockets/DOMWebSocket.h"
     40 #include "platform/Cookie.h"
     41 #include "platform/Crypto.h"
     42 #include "platform/Logging.h"
     43 #include "platform/network/HTTPHeaderMap.h"
     44 #include "platform/network/HTTPParsers.h"
     45 #include "platform/weborigin/SecurityOrigin.h"
     46 #include "public/platform/Platform.h"
     47 #include "wtf/CryptographicallyRandomNumber.h"
     48 #include "wtf/StdLibExtras.h"
     49 #include "wtf/StringExtras.h"
     50 #include "wtf/Vector.h"
     51 #include "wtf/text/Base64.h"
     52 #include "wtf/text/CString.h"
     53 #include "wtf/text/StringBuilder.h"
     54 #include "wtf/unicode/CharacterNames.h"
     55 
     56 namespace blink {
     57 
     58 String formatHandshakeFailureReason(const String& detail)
     59 {
     60     return "Error during WebSocket handshake: " + detail;
     61 }
     62 
     63 static String resourceName(const KURL& url)
     64 {
     65     StringBuilder name;
     66     name.append(url.path());
     67     if (name.isEmpty())
     68         name.append('/');
     69     if (!url.query().isNull()) {
     70         name.append('?');
     71         name.append(url.query());
     72     }
     73     String result = name.toString();
     74     ASSERT(!result.isEmpty());
     75     ASSERT(!result.contains(' '));
     76     return result;
     77 }
     78 
     79 static String hostName(const KURL& url, bool secure)
     80 {
     81     ASSERT(url.protocolIs("wss") == secure);
     82     StringBuilder builder;
     83     builder.append(url.host().lower());
     84     if (url.port() && ((!secure && url.port() != 80) || (secure && url.port() != 443))) {
     85         builder.append(':');
     86         builder.appendNumber(url.port());
     87     }
     88     return builder.toString();
     89 }
     90 
     91 static const size_t maxInputSampleSize = 128;
     92 static String trimInputSample(const char* p, size_t len)
     93 {
     94     if (len > maxInputSampleSize)
     95         return String(p, maxInputSampleSize) + horizontalEllipsis;
     96     return String(p, len);
     97 }
     98 
     99 static String generateSecWebSocketKey()
    100 {
    101     static const size_t nonceSize = 16;
    102     unsigned char key[nonceSize];
    103     cryptographicallyRandomValues(key, nonceSize);
    104     return base64Encode(reinterpret_cast<char*>(key), nonceSize);
    105 }
    106 
    107 String WebSocketHandshake::getExpectedWebSocketAccept(const String& secWebSocketKey)
    108 {
    109     static const char webSocketKeyGUID[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
    110     CString keyData = secWebSocketKey.ascii();
    111 
    112     StringBuilder digestable;
    113     digestable.append(secWebSocketKey);
    114     digestable.append(webSocketKeyGUID, strlen(webSocketKeyGUID));
    115     CString digestableCString = digestable.toString().utf8();
    116     DigestValue digest;
    117     bool digestSuccess = computeDigest(HashAlgorithmSha1, digestableCString.data(), digestableCString.length(), digest);
    118     RELEASE_ASSERT(digestSuccess);
    119 
    120     return base64Encode(reinterpret_cast<const char*>(digest.data()), sha1HashSize);
    121 }
    122 
    123 WebSocketHandshake::WebSocketHandshake(const KURL& url, const String& protocol, Document* document)
    124     : m_url(url)
    125     , m_clientProtocol(protocol)
    126     , m_secure(m_url.protocolIs("wss"))
    127     , m_document(document)
    128     , m_mode(Incomplete)
    129 {
    130     m_secWebSocketKey = generateSecWebSocketKey();
    131     m_expectedAccept = getExpectedWebSocketAccept(m_secWebSocketKey);
    132 }
    133 
    134 WebSocketHandshake::~WebSocketHandshake()
    135 {
    136     Platform::current()->histogramEnumeration("WebCore.WebSocket.HandshakeResult", m_mode, WebSocketHandshake::ModeMax);
    137 }
    138 
    139 const KURL& WebSocketHandshake::url() const
    140 {
    141     return m_url;
    142 }
    143 
    144 void WebSocketHandshake::setURL(const KURL& url)
    145 {
    146     m_url = url.copy();
    147 }
    148 
    149 const String WebSocketHandshake::host() const
    150 {
    151     return m_url.host().lower();
    152 }
    153 
    154 const String& WebSocketHandshake::clientProtocol() const
    155 {
    156     return m_clientProtocol;
    157 }
    158 
    159 void WebSocketHandshake::setClientProtocol(const String& protocol)
    160 {
    161     m_clientProtocol = protocol;
    162 }
    163 
    164 bool WebSocketHandshake::secure() const
    165 {
    166     return m_secure;
    167 }
    168 
    169 String WebSocketHandshake::clientOrigin() const
    170 {
    171     return m_document->securityOrigin()->toString();
    172 }
    173 
    174 String WebSocketHandshake::clientLocation() const
    175 {
    176     StringBuilder builder;
    177     if (m_secure)
    178         builder.appendLiteral("wss");
    179     else
    180         builder.appendLiteral("ws");
    181     builder.appendLiteral("://");
    182     builder.append(hostName(m_url, m_secure));
    183     builder.append(resourceName(m_url));
    184     return builder.toString();
    185 }
    186 
    187 CString WebSocketHandshake::clientHandshakeMessage() const
    188 {
    189     ASSERT(m_document);
    190 
    191     // Keep the following consistent with clientHandshakeRequest().
    192     StringBuilder builder;
    193 
    194     builder.appendLiteral("GET ");
    195     builder.append(resourceName(m_url));
    196     builder.appendLiteral(" HTTP/1.1\r\n");
    197 
    198     Vector<String> fields;
    199     fields.append("Upgrade: websocket");
    200     fields.append("Connection: Upgrade");
    201     fields.append("Host: " + hostName(m_url, m_secure));
    202     fields.append("Origin: " + clientOrigin());
    203     if (!m_clientProtocol.isEmpty())
    204         fields.append("Sec-WebSocket-Protocol: " + m_clientProtocol);
    205 
    206     // Add no-cache headers to avoid compatibility issue.
    207     // There are some proxies that rewrite "Connection: upgrade"
    208     // to "Connection: close" in the response if a request doesn't contain
    209     // these headers.
    210     fields.append("Pragma: no-cache");
    211     fields.append("Cache-Control: no-cache");
    212 
    213     fields.append("Sec-WebSocket-Key: " + m_secWebSocketKey);
    214     fields.append("Sec-WebSocket-Version: 13");
    215     const String extensionValue = m_extensionDispatcher.createHeaderValue();
    216     if (extensionValue.length())
    217         fields.append("Sec-WebSocket-Extensions: " + extensionValue);
    218 
    219     fields.append("User-Agent: " + m_document->userAgent(m_document->url()));
    220 
    221     // Fields in the handshake are sent by the client in a random order; the
    222     // order is not meaningful. Thus, it's ok to send the order we constructed
    223     // the fields.
    224 
    225     for (size_t i = 0; i < fields.size(); i++) {
    226         builder.append(fields[i]);
    227         builder.appendLiteral("\r\n");
    228     }
    229 
    230     builder.appendLiteral("\r\n");
    231 
    232     return builder.toString().utf8();
    233 }
    234 
    235 PassRefPtr<WebSocketHandshakeRequest> WebSocketHandshake::clientHandshakeRequest() const
    236 {
    237     ASSERT(m_document);
    238 
    239     // Keep the following consistent with clientHandshakeMessage().
    240     // FIXME: do we need to store m_secWebSocketKey1, m_secWebSocketKey2 and
    241     // m_key3 in WebSocketHandshakeRequest?
    242     RefPtr<WebSocketHandshakeRequest> request = WebSocketHandshakeRequest::create(m_url);
    243     request->addHeaderField("Upgrade", "websocket");
    244     request->addHeaderField("Connection", "Upgrade");
    245     request->addHeaderField("Host", AtomicString(hostName(m_url, m_secure)));
    246     request->addHeaderField("Origin", AtomicString(clientOrigin()));
    247     if (!m_clientProtocol.isEmpty())
    248         request->addHeaderField("Sec-WebSocket-Protocol", AtomicString(m_clientProtocol));
    249 
    250     KURL url = httpURLForAuthenticationAndCookies();
    251 
    252     String cookie = cookieRequestHeaderFieldValue(m_document, url);
    253     if (!cookie.isEmpty())
    254         request->addHeaderField("Cookie", AtomicString(cookie));
    255     // Set "Cookie2: <cookie>" if cookies 2 exists for url?
    256 
    257     request->addHeaderField("Pragma", "no-cache");
    258     request->addHeaderField("Cache-Control", "no-cache");
    259 
    260     request->addHeaderField("Sec-WebSocket-Key", AtomicString(m_secWebSocketKey));
    261     request->addHeaderField("Sec-WebSocket-Version", "13");
    262     const String extensionValue = m_extensionDispatcher.createHeaderValue();
    263     if (extensionValue.length())
    264         request->addHeaderField("Sec-WebSocket-Extensions", AtomicString(extensionValue));
    265 
    266     request->addHeaderField("User-Agent", AtomicString(m_document->userAgent(m_document->url())));
    267 
    268     return request.release();
    269 }
    270 
    271 void WebSocketHandshake::reset()
    272 {
    273     m_mode = Incomplete;
    274     m_extensionDispatcher.reset();
    275 }
    276 
    277 void WebSocketHandshake::clearDocument()
    278 {
    279     m_document = nullptr;
    280 }
    281 
    282 int WebSocketHandshake::readServerHandshake(const char* header, size_t len)
    283 {
    284     m_mode = Incomplete;
    285     int statusCode;
    286     String statusText;
    287     int lineLength = readStatusLine(header, len, statusCode, statusText);
    288     if (lineLength == -1)
    289         return -1;
    290     if (statusCode == -1) {
    291         m_mode = Failed; // m_failureReason is set inside readStatusLine().
    292         return len;
    293     }
    294     WTF_LOG(Network, "WebSocketHandshake %p readServerHandshake() Status code is %d", this, statusCode);
    295     m_response.setStatusCode(statusCode);
    296     m_response.setStatusText(statusText);
    297     if (statusCode != 101) {
    298         m_mode = Failed;
    299         m_failureReason = formatHandshakeFailureReason("Unexpected response code: " + String::number(statusCode));
    300         return len;
    301     }
    302     m_mode = Normal;
    303     if (!strnstr(header, "\r\n\r\n", len)) {
    304         // Just hasn't been received fully yet.
    305         m_mode = Incomplete;
    306         return -1;
    307     }
    308     const char* p = readHTTPHeaders(header + lineLength, header + len);
    309     if (!p) {
    310         WTF_LOG(Network, "WebSocketHandshake %p readServerHandshake() readHTTPHeaders() failed", this);
    311         m_mode = Failed; // m_failureReason is set inside readHTTPHeaders().
    312         return len;
    313     }
    314     if (!checkResponseHeaders()) {
    315         WTF_LOG(Network, "WebSocketHandshake %p readServerHandshake() checkResponseHeaders() failed", this);
    316         m_mode = Failed;
    317         return p - header;
    318     }
    319 
    320     m_mode = Connected;
    321     return p - header;
    322 }
    323 
    324 WebSocketHandshake::Mode WebSocketHandshake::mode() const
    325 {
    326     return m_mode;
    327 }
    328 
    329 String WebSocketHandshake::failureReason() const
    330 {
    331     return m_failureReason;
    332 }
    333 
    334 const AtomicString& WebSocketHandshake::serverWebSocketProtocol() const
    335 {
    336     return m_response.headerFields().get("sec-websocket-protocol");
    337 }
    338 
    339 const AtomicString& WebSocketHandshake::serverUpgrade() const
    340 {
    341     return m_response.headerFields().get("upgrade");
    342 }
    343 
    344 const AtomicString& WebSocketHandshake::serverConnection() const
    345 {
    346     return m_response.headerFields().get("connection");
    347 }
    348 
    349 const AtomicString& WebSocketHandshake::serverWebSocketAccept() const
    350 {
    351     return m_response.headerFields().get("sec-websocket-accept");
    352 }
    353 
    354 String WebSocketHandshake::acceptedExtensions() const
    355 {
    356     return m_extensionDispatcher.acceptedExtensions();
    357 }
    358 
    359 const WebSocketHandshakeResponse& WebSocketHandshake::serverHandshakeResponse() const
    360 {
    361     return m_response;
    362 }
    363 
    364 void WebSocketHandshake::addExtensionProcessor(PassOwnPtr<WebSocketExtensionProcessor> processor)
    365 {
    366     m_extensionDispatcher.addProcessor(processor);
    367 }
    368 
    369 KURL WebSocketHandshake::httpURLForAuthenticationAndCookies() const
    370 {
    371     KURL url = m_url.copy();
    372     bool couldSetProtocol = url.setProtocol(m_secure ? "https" : "http");
    373     ASSERT_UNUSED(couldSetProtocol, couldSetProtocol);
    374     return url;
    375 }
    376 
    377 // Returns the header length (including "\r\n"), or -1 if we have not received enough data yet.
    378 // If the line is malformed or the status code is not a 3-digit number,
    379 // statusCode and statusText will be set to -1 and a null string, respectively.
    380 int WebSocketHandshake::readStatusLine(const char* header, size_t headerLength, int& statusCode, String& statusText)
    381 {
    382     // Arbitrary size limit to prevent the server from sending an unbounded
    383     // amount of data with no newlines and forcing us to buffer it all.
    384     static const int maximumLength = 1024;
    385 
    386     statusCode = -1;
    387     statusText = String();
    388 
    389     const char* space1 = 0;
    390     const char* space2 = 0;
    391     const char* p;
    392     size_t consumedLength;
    393 
    394     for (p = header, consumedLength = 0; consumedLength < headerLength; p++, consumedLength++) {
    395         if (*p == ' ') {
    396             if (!space1)
    397                 space1 = p;
    398             else if (!space2)
    399                 space2 = p;
    400         } else if (*p == '\0') {
    401             // The caller isn't prepared to deal with null bytes in status
    402             // line. WebSockets specification doesn't prohibit this, but HTTP
    403             // does, so we'll just treat this as an error.
    404             m_failureReason = formatHandshakeFailureReason("Status line contains embedded null");
    405             return p + 1 - header;
    406         } else if (*p == '\n') {
    407             break;
    408         }
    409     }
    410     if (consumedLength == headerLength)
    411         return -1; // We have not received '\n' yet.
    412 
    413     const char* end = p + 1;
    414     int lineLength = end - header;
    415     if (lineLength > maximumLength) {
    416         m_failureReason = formatHandshakeFailureReason("Status line is too long");
    417         return maximumLength;
    418     }
    419 
    420     // The line must end with "\r\n".
    421     if (lineLength < 2 || *(end - 2) != '\r') {
    422         m_failureReason = formatHandshakeFailureReason("Status line does not end with CRLF");
    423         return lineLength;
    424     }
    425 
    426     if (!space1 || !space2) {
    427         m_failureReason = formatHandshakeFailureReason("No response code found in status line: " + trimInputSample(header, lineLength - 2));
    428         return lineLength;
    429     }
    430 
    431     String statusCodeString(space1 + 1, space2 - space1 - 1);
    432     if (statusCodeString.length() != 3) // Status code must consist of three digits.
    433         return lineLength;
    434     for (int i = 0; i < 3; ++i) {
    435         if (statusCodeString[i] < '0' || statusCodeString[i] > '9') {
    436             m_failureReason = formatHandshakeFailureReason("Invalid status code: " + statusCodeString);
    437             return lineLength;
    438         }
    439     }
    440 
    441     bool ok = false;
    442     statusCode = statusCodeString.toInt(&ok);
    443     ASSERT(ok);
    444 
    445     statusText = String(space2 + 1, end - space2 - 3); // Exclude "\r\n".
    446     return lineLength;
    447 }
    448 
    449 const char* WebSocketHandshake::readHTTPHeaders(const char* start, const char* end)
    450 {
    451     m_response.clearHeaderFields();
    452 
    453     AtomicString name;
    454     AtomicString value;
    455     bool sawSecWebSocketAcceptHeaderField = false;
    456     bool sawSecWebSocketProtocolHeaderField = false;
    457     const char* p = start;
    458     while (p < end) {
    459         size_t consumedLength = parseHTTPHeader(p, end - p, m_failureReason, name, value);
    460         if (!consumedLength)
    461             return 0;
    462         p += consumedLength;
    463 
    464         // Stop once we consumed an empty line.
    465         if (name.isEmpty())
    466             break;
    467 
    468         // Sec-WebSocket-Extensions may be split. We parse and check the
    469         // header value every time the header appears.
    470         if (equalIgnoringCase("Sec-WebSocket-Extensions", name)) {
    471             if (!m_extensionDispatcher.processHeaderValue(value)) {
    472                 m_failureReason = formatHandshakeFailureReason(m_extensionDispatcher.failureReason());
    473                 return 0;
    474             }
    475         } else if (equalIgnoringCase("Sec-WebSocket-Accept", name)) {
    476             if (sawSecWebSocketAcceptHeaderField) {
    477                 m_failureReason = formatHandshakeFailureReason("'Sec-WebSocket-Accept' header must not appear more than once in a response");
    478                 return 0;
    479             }
    480             m_response.addHeaderField(name, value);
    481             sawSecWebSocketAcceptHeaderField = true;
    482         } else if (equalIgnoringCase("Sec-WebSocket-Protocol", name)) {
    483             if (sawSecWebSocketProtocolHeaderField) {
    484                 m_failureReason = formatHandshakeFailureReason("'Sec-WebSocket-Protocol' header must not appear more than once in a response");
    485                 return 0;
    486             }
    487             m_response.addHeaderField(name, value);
    488             sawSecWebSocketProtocolHeaderField = true;
    489         } else {
    490             m_response.addHeaderField(name, value);
    491         }
    492     }
    493 
    494     String extensions = m_extensionDispatcher.acceptedExtensions();
    495     if (!extensions.isEmpty())
    496         m_response.addHeaderField("Sec-WebSocket-Extensions", AtomicString(extensions));
    497     return p;
    498 }
    499 
    500 bool WebSocketHandshake::checkResponseHeaders()
    501 {
    502     const AtomicString& serverWebSocketProtocol = this->serverWebSocketProtocol();
    503     const AtomicString& serverUpgrade = this->serverUpgrade();
    504     const AtomicString& serverConnection = this->serverConnection();
    505     const AtomicString& serverWebSocketAccept = this->serverWebSocketAccept();
    506 
    507     if (serverUpgrade.isNull()) {
    508         m_failureReason = formatHandshakeFailureReason("'Upgrade' header is missing");
    509         return false;
    510     }
    511     if (serverConnection.isNull()) {
    512         m_failureReason = formatHandshakeFailureReason("'Connection' header is missing");
    513         return false;
    514     }
    515     if (serverWebSocketAccept.isNull()) {
    516         m_failureReason = formatHandshakeFailureReason("'Sec-WebSocket-Accept' header is missing");
    517         return false;
    518     }
    519 
    520     if (!equalIgnoringCase(serverUpgrade, "websocket")) {
    521         m_failureReason = formatHandshakeFailureReason("'Upgrade' header value is not 'WebSocket': " + serverUpgrade);
    522         return false;
    523     }
    524     if (!equalIgnoringCase(serverConnection, "upgrade")) {
    525         m_failureReason = formatHandshakeFailureReason("'Connection' header value is not 'Upgrade': " + serverConnection);
    526         return false;
    527     }
    528 
    529     if (serverWebSocketAccept != m_expectedAccept) {
    530         m_failureReason = formatHandshakeFailureReason("Incorrect 'Sec-WebSocket-Accept' header value");
    531         return false;
    532     }
    533     if (!serverWebSocketProtocol.isNull()) {
    534         if (m_clientProtocol.isEmpty()) {
    535             m_failureReason = formatHandshakeFailureReason("Response must not include 'Sec-WebSocket-Protocol' header if not present in request: " + serverWebSocketProtocol);
    536             return false;
    537         }
    538         Vector<String> result;
    539         m_clientProtocol.split(String(DOMWebSocket::subprotocolSeperator()), result);
    540         if (!result.contains(serverWebSocketProtocol)) {
    541             m_failureReason = formatHandshakeFailureReason("'Sec-WebSocket-Protocol' header value '" + serverWebSocketProtocol + "' in response does not match any of sent values");
    542             return false;
    543         }
    544     } else if (!m_clientProtocol.isEmpty()) {
    545         m_failureReason = formatHandshakeFailureReason("Sent non-empty 'Sec-WebSocket-Protocol' header but no response was received");
    546         return false;
    547     }
    548     return true;
    549 }
    550 
    551 void WebSocketHandshake::trace(Visitor* visitor)
    552 {
    553     visitor->trace(m_document);
    554 }
    555 
    556 } // namespace blink
    557