Home | History | Annotate | Download | only in websockets
      1 /*
      2  * Copyright (C) 2011 Google Inc.  All rights reserved.
      3  * Copyright (C) Research In Motion Limited 2011. All rights reserved.
      4  *
      5  * Redistribution and use in source and binary forms, with or without
      6  * modification, are permitted provided that the following conditions are
      7  * met:
      8  *
      9  *     * Redistributions of source code must retain the above copyright
     10  * notice, this list of conditions and the following disclaimer.
     11  *     * Redistributions in binary form must reproduce the above
     12  * copyright notice, this list of conditions and the following disclaimer
     13  * in the documentation and/or other materials provided with the
     14  * distribution.
     15  *     * Neither the name of Google Inc. nor the names of its
     16  * contributors may be used to endorse or promote products derived from
     17  * this software without specific prior written permission.
     18  *
     19  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
     20  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
     21  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
     22  * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
     23  * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
     24  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
     25  * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
     26  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
     27  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
     28  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
     29  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
     30  */
     31 
     32 #include "config.h"
     33 
     34 #include "modules/websockets/WebSocketHandshake.h"
     35 
     36 #include "core/dom/Document.h"
     37 #include "core/inspector/ScriptCallStack.h"
     38 #include "core/loader/CookieJar.h"
     39 #include "modules/websockets/WebSocket.h"
     40 #include "platform/Cookie.h"
     41 #include "platform/Crypto.h"
     42 #include "platform/Logging.h"
     43 #include "platform/network/HTTPHeaderMap.h"
     44 #include "platform/network/HTTPParsers.h"
     45 #include "platform/weborigin/SecurityOrigin.h"
     46 #include "public/platform/Platform.h"
     47 #include "wtf/CryptographicallyRandomNumber.h"
     48 #include "wtf/StdLibExtras.h"
     49 #include "wtf/StringExtras.h"
     50 #include "wtf/Vector.h"
     51 #include "wtf/text/Base64.h"
     52 #include "wtf/text/CString.h"
     53 #include "wtf/text/StringBuilder.h"
     54 #include "wtf/unicode/CharacterNames.h"
     55 
     56 namespace WebCore {
     57 
     58 String formatHandshakeFailureReason(const String& detail)
     59 {
     60     return "Error during WebSocket handshake: " + detail;
     61 }
     62 
     63 static String resourceName(const KURL& url)
     64 {
     65     StringBuilder name;
     66     name.append(url.path());
     67     if (name.isEmpty())
     68         name.append('/');
     69     if (!url.query().isNull()) {
     70         name.append('?');
     71         name.append(url.query());
     72     }
     73     String result = name.toString();
     74     ASSERT(!result.isEmpty());
     75     ASSERT(!result.contains(' '));
     76     return result;
     77 }
     78 
     79 static String hostName(const KURL& url, bool secure)
     80 {
     81     ASSERT(url.protocolIs("wss") == secure);
     82     StringBuilder builder;
     83     builder.append(url.host().lower());
     84     if (url.port() && ((!secure && url.port() != 80) || (secure && url.port() != 443))) {
     85         builder.append(':');
     86         builder.appendNumber(url.port());
     87     }
     88     return builder.toString();
     89 }
     90 
     91 static const size_t maxInputSampleSize = 128;
     92 static String trimInputSample(const char* p, size_t len)
     93 {
     94     if (len > maxInputSampleSize)
     95         return String(p, maxInputSampleSize) + horizontalEllipsis;
     96     return String(p, len);
     97 }
     98 
     99 static String generateSecWebSocketKey()
    100 {
    101     static const size_t nonceSize = 16;
    102     unsigned char key[nonceSize];
    103     cryptographicallyRandomValues(key, nonceSize);
    104     return base64Encode(reinterpret_cast<char*>(key), nonceSize);
    105 }
    106 
    107 String WebSocketHandshake::getExpectedWebSocketAccept(const String& secWebSocketKey)
    108 {
    109     static const char webSocketKeyGUID[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
    110     CString keyData = secWebSocketKey.ascii();
    111 
    112     StringBuilder digestable;
    113     digestable.append(secWebSocketKey);
    114     digestable.append(webSocketKeyGUID, strlen(webSocketKeyGUID));
    115     CString digestableCString = digestable.toString().utf8();
    116     DigestValue digest;
    117     bool digestSuccess = computeDigest(HashAlgorithmSha1, digestableCString.data(), digestableCString.length(), digest);
    118     RELEASE_ASSERT(digestSuccess);
    119 
    120     return base64Encode(reinterpret_cast<const char*>(digest.data()), sha1HashSize);
    121 }
    122 
    123 WebSocketHandshake::WebSocketHandshake(const KURL& url, const String& protocol, Document* document)
    124     : m_url(url)
    125     , m_clientProtocol(protocol)
    126     , m_secure(m_url.protocolIs("wss"))
    127     , m_document(document)
    128     , m_mode(Incomplete)
    129 {
    130     m_secWebSocketKey = generateSecWebSocketKey();
    131     m_expectedAccept = getExpectedWebSocketAccept(m_secWebSocketKey);
    132 }
    133 
    134 WebSocketHandshake::~WebSocketHandshake()
    135 {
    136     blink::Platform::current()->histogramEnumeration("WebCore.WebSocket.HandshakeResult", m_mode, WebSocketHandshake::ModeMax);
    137 }
    138 
    139 const KURL& WebSocketHandshake::url() const
    140 {
    141     return m_url;
    142 }
    143 
    144 void WebSocketHandshake::setURL(const KURL& url)
    145 {
    146     m_url = url.copy();
    147 }
    148 
    149 const String WebSocketHandshake::host() const
    150 {
    151     return m_url.host().lower();
    152 }
    153 
    154 const String& WebSocketHandshake::clientProtocol() const
    155 {
    156     return m_clientProtocol;
    157 }
    158 
    159 void WebSocketHandshake::setClientProtocol(const String& protocol)
    160 {
    161     m_clientProtocol = protocol;
    162 }
    163 
    164 bool WebSocketHandshake::secure() const
    165 {
    166     return m_secure;
    167 }
    168 
    169 String WebSocketHandshake::clientOrigin() const
    170 {
    171     return m_document->securityOrigin()->toString();
    172 }
    173 
    174 String WebSocketHandshake::clientLocation() const
    175 {
    176     StringBuilder builder;
    177     builder.append(m_secure ? "wss" : "ws");
    178     builder.append("://");
    179     builder.append(hostName(m_url, m_secure));
    180     builder.append(resourceName(m_url));
    181     return builder.toString();
    182 }
    183 
    184 CString WebSocketHandshake::clientHandshakeMessage() const
    185 {
    186     ASSERT(m_document);
    187 
    188     // Keep the following consistent with clientHandshakeRequest().
    189     StringBuilder builder;
    190 
    191     builder.append("GET ");
    192     builder.append(resourceName(m_url));
    193     builder.append(" HTTP/1.1\r\n");
    194 
    195     Vector<String> fields;
    196     fields.append("Upgrade: websocket");
    197     fields.append("Connection: Upgrade");
    198     fields.append("Host: " + hostName(m_url, m_secure));
    199     fields.append("Origin: " + clientOrigin());
    200     if (!m_clientProtocol.isEmpty())
    201         fields.append("Sec-WebSocket-Protocol: " + m_clientProtocol);
    202 
    203     // Add no-cache headers to avoid compatibility issue.
    204     // There are some proxies that rewrite "Connection: upgrade"
    205     // to "Connection: close" in the response if a request doesn't contain
    206     // these headers.
    207     fields.append("Pragma: no-cache");
    208     fields.append("Cache-Control: no-cache");
    209 
    210     fields.append("Sec-WebSocket-Key: " + m_secWebSocketKey);
    211     fields.append("Sec-WebSocket-Version: 13");
    212     const String extensionValue = m_extensionDispatcher.createHeaderValue();
    213     if (extensionValue.length())
    214         fields.append("Sec-WebSocket-Extensions: " + extensionValue);
    215 
    216     fields.append("User-Agent: " + m_document->userAgent(m_document->url()));
    217 
    218     // Fields in the handshake are sent by the client in a random order; the
    219     // order is not meaningful. Thus, it's ok to send the order we constructed
    220     // the fields.
    221 
    222     for (size_t i = 0; i < fields.size(); i++) {
    223         builder.append(fields[i]);
    224         builder.append("\r\n");
    225     }
    226 
    227     builder.append("\r\n");
    228 
    229     return builder.toString().utf8();
    230 }
    231 
    232 PassRefPtr<WebSocketHandshakeRequest> WebSocketHandshake::clientHandshakeRequest() const
    233 {
    234     ASSERT(m_document);
    235 
    236     // Keep the following consistent with clientHandshakeMessage().
    237     // FIXME: do we need to store m_secWebSocketKey1, m_secWebSocketKey2 and
    238     // m_key3 in WebSocketHandshakeRequest?
    239     RefPtr<WebSocketHandshakeRequest> request = WebSocketHandshakeRequest::create(m_url);
    240     request->addHeaderField("Upgrade", "websocket");
    241     request->addHeaderField("Connection", "Upgrade");
    242     request->addHeaderField("Host", AtomicString(hostName(m_url, m_secure)));
    243     request->addHeaderField("Origin", AtomicString(clientOrigin()));
    244     if (!m_clientProtocol.isEmpty())
    245         request->addHeaderField("Sec-WebSocket-Protocol", AtomicString(m_clientProtocol));
    246 
    247     KURL url = httpURLForAuthenticationAndCookies();
    248 
    249     String cookie = cookieRequestHeaderFieldValue(m_document, url);
    250     if (!cookie.isEmpty())
    251         request->addHeaderField("Cookie", AtomicString(cookie));
    252     // Set "Cookie2: <cookie>" if cookies 2 exists for url?
    253 
    254     request->addHeaderField("Pragma", "no-cache");
    255     request->addHeaderField("Cache-Control", "no-cache");
    256 
    257     request->addHeaderField("Sec-WebSocket-Key", AtomicString(m_secWebSocketKey));
    258     request->addHeaderField("Sec-WebSocket-Version", "13");
    259     const String extensionValue = m_extensionDispatcher.createHeaderValue();
    260     if (extensionValue.length())
    261         request->addHeaderField("Sec-WebSocket-Extensions", AtomicString(extensionValue));
    262 
    263     request->addHeaderField("User-Agent", AtomicString(m_document->userAgent(m_document->url())));
    264 
    265     return request.release();
    266 }
    267 
    268 void WebSocketHandshake::reset()
    269 {
    270     m_mode = Incomplete;
    271     m_extensionDispatcher.reset();
    272 }
    273 
    274 void WebSocketHandshake::clearDocument()
    275 {
    276     m_document = 0;
    277 }
    278 
    279 int WebSocketHandshake::readServerHandshake(const char* header, size_t len)
    280 {
    281     m_mode = Incomplete;
    282     int statusCode;
    283     String statusText;
    284     int lineLength = readStatusLine(header, len, statusCode, statusText);
    285     if (lineLength == -1)
    286         return -1;
    287     if (statusCode == -1) {
    288         m_mode = Failed; // m_failureReason is set inside readStatusLine().
    289         return len;
    290     }
    291     WTF_LOG(Network, "WebSocketHandshake %p readServerHandshake() Status code is %d", this, statusCode);
    292     m_response.setStatusCode(statusCode);
    293     m_response.setStatusText(statusText);
    294     if (statusCode != 101) {
    295         m_mode = Failed;
    296         m_failureReason = formatHandshakeFailureReason("Unexpected response code: " + String::number(statusCode));
    297         return len;
    298     }
    299     m_mode = Normal;
    300     if (!strnstr(header, "\r\n\r\n", len)) {
    301         // Just hasn't been received fully yet.
    302         m_mode = Incomplete;
    303         return -1;
    304     }
    305     const char* p = readHTTPHeaders(header + lineLength, header + len);
    306     if (!p) {
    307         WTF_LOG(Network, "WebSocketHandshake %p readServerHandshake() readHTTPHeaders() failed", this);
    308         m_mode = Failed; // m_failureReason is set inside readHTTPHeaders().
    309         return len;
    310     }
    311     if (!checkResponseHeaders()) {
    312         WTF_LOG(Network, "WebSocketHandshake %p readServerHandshake() checkResponseHeaders() failed", this);
    313         m_mode = Failed;
    314         return p - header;
    315     }
    316 
    317     m_mode = Connected;
    318     return p - header;
    319 }
    320 
    321 WebSocketHandshake::Mode WebSocketHandshake::mode() const
    322 {
    323     return m_mode;
    324 }
    325 
    326 String WebSocketHandshake::failureReason() const
    327 {
    328     return m_failureReason;
    329 }
    330 
    331 const AtomicString& WebSocketHandshake::serverWebSocketProtocol() const
    332 {
    333     return m_response.headerFields().get("sec-websocket-protocol");
    334 }
    335 
    336 const AtomicString& WebSocketHandshake::serverUpgrade() const
    337 {
    338     return m_response.headerFields().get("upgrade");
    339 }
    340 
    341 const AtomicString& WebSocketHandshake::serverConnection() const
    342 {
    343     return m_response.headerFields().get("connection");
    344 }
    345 
    346 const AtomicString& WebSocketHandshake::serverWebSocketAccept() const
    347 {
    348     return m_response.headerFields().get("sec-websocket-accept");
    349 }
    350 
    351 String WebSocketHandshake::acceptedExtensions() const
    352 {
    353     return m_extensionDispatcher.acceptedExtensions();
    354 }
    355 
    356 const WebSocketHandshakeResponse& WebSocketHandshake::serverHandshakeResponse() const
    357 {
    358     return m_response;
    359 }
    360 
    361 void WebSocketHandshake::addExtensionProcessor(PassOwnPtr<WebSocketExtensionProcessor> processor)
    362 {
    363     m_extensionDispatcher.addProcessor(processor);
    364 }
    365 
    366 KURL WebSocketHandshake::httpURLForAuthenticationAndCookies() const
    367 {
    368     KURL url = m_url.copy();
    369     bool couldSetProtocol = url.setProtocol(m_secure ? "https" : "http");
    370     ASSERT_UNUSED(couldSetProtocol, couldSetProtocol);
    371     return url;
    372 }
    373 
    374 // Returns the header length (including "\r\n"), or -1 if we have not received enough data yet.
    375 // If the line is malformed or the status code is not a 3-digit number,
    376 // statusCode and statusText will be set to -1 and a null string, respectively.
    377 int WebSocketHandshake::readStatusLine(const char* header, size_t headerLength, int& statusCode, String& statusText)
    378 {
    379     // Arbitrary size limit to prevent the server from sending an unbounded
    380     // amount of data with no newlines and forcing us to buffer it all.
    381     static const int maximumLength = 1024;
    382 
    383     statusCode = -1;
    384     statusText = String();
    385 
    386     const char* space1 = 0;
    387     const char* space2 = 0;
    388     const char* p;
    389     size_t consumedLength;
    390 
    391     for (p = header, consumedLength = 0; consumedLength < headerLength; p++, consumedLength++) {
    392         if (*p == ' ') {
    393             if (!space1)
    394                 space1 = p;
    395             else if (!space2)
    396                 space2 = p;
    397         } else if (*p == '\0') {
    398             // The caller isn't prepared to deal with null bytes in status
    399             // line. WebSockets specification doesn't prohibit this, but HTTP
    400             // does, so we'll just treat this as an error.
    401             m_failureReason = formatHandshakeFailureReason("Status line contains embedded null");
    402             return p + 1 - header;
    403         } else if (*p == '\n') {
    404             break;
    405         }
    406     }
    407     if (consumedLength == headerLength)
    408         return -1; // We have not received '\n' yet.
    409 
    410     const char* end = p + 1;
    411     int lineLength = end - header;
    412     if (lineLength > maximumLength) {
    413         m_failureReason = formatHandshakeFailureReason("Status line is too long");
    414         return maximumLength;
    415     }
    416 
    417     // The line must end with "\r\n".
    418     if (lineLength < 2 || *(end - 2) != '\r') {
    419         m_failureReason = formatHandshakeFailureReason("Status line does not end with CRLF");
    420         return lineLength;
    421     }
    422 
    423     if (!space1 || !space2) {
    424         m_failureReason = formatHandshakeFailureReason("No response code found in status line: " + trimInputSample(header, lineLength - 2));
    425         return lineLength;
    426     }
    427 
    428     String statusCodeString(space1 + 1, space2 - space1 - 1);
    429     if (statusCodeString.length() != 3) // Status code must consist of three digits.
    430         return lineLength;
    431     for (int i = 0; i < 3; ++i) {
    432         if (statusCodeString[i] < '0' || statusCodeString[i] > '9') {
    433             m_failureReason = formatHandshakeFailureReason("Invalid status code: " + statusCodeString);
    434             return lineLength;
    435         }
    436     }
    437 
    438     bool ok = false;
    439     statusCode = statusCodeString.toInt(&ok);
    440     ASSERT(ok);
    441 
    442     statusText = String(space2 + 1, end - space2 - 3); // Exclude "\r\n".
    443     return lineLength;
    444 }
    445 
    446 const char* WebSocketHandshake::readHTTPHeaders(const char* start, const char* end)
    447 {
    448     m_response.clearHeaderFields();
    449 
    450     AtomicString name;
    451     AtomicString value;
    452     bool sawSecWebSocketAcceptHeaderField = false;
    453     bool sawSecWebSocketProtocolHeaderField = false;
    454     const char* p = start;
    455     while (p < end) {
    456         size_t consumedLength = parseHTTPHeader(p, end - p, m_failureReason, name, value);
    457         if (!consumedLength)
    458             return 0;
    459         p += consumedLength;
    460 
    461         // Stop once we consumed an empty line.
    462         if (name.isEmpty())
    463             break;
    464 
    465         // Sec-WebSocket-Extensions may be split. We parse and check the
    466         // header value every time the header appears.
    467         if (equalIgnoringCase("Sec-WebSocket-Extensions", name)) {
    468             if (!m_extensionDispatcher.processHeaderValue(value)) {
    469                 m_failureReason = formatHandshakeFailureReason(m_extensionDispatcher.failureReason());
    470                 return 0;
    471             }
    472         } else if (equalIgnoringCase("Sec-WebSocket-Accept", name)) {
    473             if (sawSecWebSocketAcceptHeaderField) {
    474                 m_failureReason = formatHandshakeFailureReason("'Sec-WebSocket-Accept' header must not appear more than once in a response");
    475                 return 0;
    476             }
    477             m_response.addHeaderField(name, value);
    478             sawSecWebSocketAcceptHeaderField = true;
    479         } else if (equalIgnoringCase("Sec-WebSocket-Protocol", name)) {
    480             if (sawSecWebSocketProtocolHeaderField) {
    481                 m_failureReason = formatHandshakeFailureReason("'Sec-WebSocket-Protocol' header must not appear more than once in a response");
    482                 return 0;
    483             }
    484             m_response.addHeaderField(name, value);
    485             sawSecWebSocketProtocolHeaderField = true;
    486         } else {
    487             m_response.addHeaderField(name, value);
    488         }
    489     }
    490 
    491     String extensions = m_extensionDispatcher.acceptedExtensions();
    492     if (!extensions.isEmpty())
    493         m_response.addHeaderField("Sec-WebSocket-Extensions", AtomicString(extensions));
    494     return p;
    495 }
    496 
    497 bool WebSocketHandshake::checkResponseHeaders()
    498 {
    499     const AtomicString& serverWebSocketProtocol = this->serverWebSocketProtocol();
    500     const AtomicString& serverUpgrade = this->serverUpgrade();
    501     const AtomicString& serverConnection = this->serverConnection();
    502     const AtomicString& serverWebSocketAccept = this->serverWebSocketAccept();
    503 
    504     if (serverUpgrade.isNull()) {
    505         m_failureReason = formatHandshakeFailureReason("'Upgrade' header is missing");
    506         return false;
    507     }
    508     if (serverConnection.isNull()) {
    509         m_failureReason = formatHandshakeFailureReason("'Connection' header is missing");
    510         return false;
    511     }
    512     if (serverWebSocketAccept.isNull()) {
    513         m_failureReason = formatHandshakeFailureReason("'Sec-WebSocket-Accept' header is missing");
    514         return false;
    515     }
    516 
    517     if (!equalIgnoringCase(serverUpgrade, "websocket")) {
    518         m_failureReason = formatHandshakeFailureReason("'Upgrade' header value is not 'WebSocket': " + serverUpgrade);
    519         return false;
    520     }
    521     if (!equalIgnoringCase(serverConnection, "upgrade")) {
    522         m_failureReason = formatHandshakeFailureReason("'Connection' header value is not 'Upgrade': " + serverConnection);
    523         return false;
    524     }
    525 
    526     if (serverWebSocketAccept != m_expectedAccept) {
    527         m_failureReason = formatHandshakeFailureReason("Incorrect 'Sec-WebSocket-Accept' header value");
    528         return false;
    529     }
    530     if (!serverWebSocketProtocol.isNull()) {
    531         if (m_clientProtocol.isEmpty()) {
    532             m_failureReason = formatHandshakeFailureReason("Response must not include 'Sec-WebSocket-Protocol' header if not present in request: " + serverWebSocketProtocol);
    533             return false;
    534         }
    535         Vector<String> result;
    536         m_clientProtocol.split(String(WebSocket::subprotocolSeperator()), result);
    537         if (!result.contains(serverWebSocketProtocol)) {
    538             m_failureReason = formatHandshakeFailureReason("'Sec-WebSocket-Protocol' header value '" + serverWebSocketProtocol + "' in response does not match any of sent values");
    539             return false;
    540         }
    541     } else if (!m_clientProtocol.isEmpty()) {
    542         m_failureReason = formatHandshakeFailureReason("Sent non-empty 'Sec-WebSocket-Protocol' header but no response was received");
    543         return false;
    544     }
    545     return true;
    546 }
    547 
    548 } // namespace WebCore
    549