Home | History | Annotate | Download | only in websockets
      1 /*
      2  * Copyright (C) 2009 Google Inc.  All rights reserved.
      3  *
      4  * Redistribution and use in source and binary forms, with or without
      5  * modification, are permitted provided that the following conditions are
      6  * met:
      7  *
      8  *     * Redistributions of source code must retain the above copyright
      9  * notice, this list of conditions and the following disclaimer.
     10  *     * Redistributions in binary form must reproduce the above
     11  * copyright notice, this list of conditions and the following disclaimer
     12  * in the documentation and/or other materials provided with the
     13  * distribution.
     14  *     * Neither the name of Google Inc. nor the names of its
     15  * contributors may be used to endorse or promote products derived from
     16  * this software without specific prior written permission.
     17  *
     18  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
     19  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
     20  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
     21  * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
     22  * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
     23  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
     24  * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
     25  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
     26  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
     27  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
     28  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
     29  */
     30 
     31 #include "config.h"
     32 
     33 #if ENABLE(WEB_SOCKETS)
     34 
     35 #include "WebSocketHandshake.h"
     36 
     37 #include "AtomicString.h"
     38 #include "CString.h"
     39 #include "Cookie.h"
     40 #include "CookieJar.h"
     41 #include "Document.h"
     42 #include "HTTPHeaderMap.h"
     43 #include "KURL.h"
     44 #include "Logging.h"
     45 #include "ScriptExecutionContext.h"
     46 #include "SecurityOrigin.h"
     47 #include "StringBuilder.h"
     48 #include <wtf/StringExtras.h>
     49 #include <wtf/Vector.h>
     50 
     51 namespace WebCore {
     52 
     53 const char webSocketServerHandshakeHeader[] = "HTTP/1.1 101 Web Socket Protocol Handshake\r\n";
     54 const char webSocketUpgradeHeader[] = "Upgrade: WebSocket\r\n";
     55 const char webSocketConnectionHeader[] = "Connection: Upgrade\r\n";
     56 
     57 static String extractResponseCode(const char* header, int len)
     58 {
     59     const char* space1 = 0;
     60     const char* space2 = 0;
     61     const char* p;
     62     for (p = header; p - header < len; p++) {
     63         if (*p == ' ') {
     64             if (!space1)
     65                 space1 = p;
     66             else if (!space2)
     67                 space2 = p;
     68         } else if (*p == '\n')
     69             break;
     70     }
     71     if (p - header == len)
     72         return String();
     73     if (!space1 || !space2)
     74         return "";
     75     return String(space1 + 1, space2 - space1 - 1);
     76 }
     77 
     78 static String resourceName(const KURL& url)
     79 {
     80     String name = url.path();
     81     if (name.isEmpty())
     82         name = "/";
     83     if (!url.query().isNull())
     84         name += "?" + url.query();
     85     ASSERT(!name.isEmpty());
     86     ASSERT(!name.contains(' '));
     87     return name;
     88 }
     89 
     90 WebSocketHandshake::WebSocketHandshake(const KURL& url, const String& protocol, ScriptExecutionContext* context)
     91     : m_url(url)
     92     , m_clientProtocol(protocol)
     93     , m_secure(m_url.protocolIs("wss"))
     94     , m_context(context)
     95     , m_mode(Incomplete)
     96 {
     97 }
     98 
     99 WebSocketHandshake::~WebSocketHandshake()
    100 {
    101 }
    102 
    103 const KURL& WebSocketHandshake::url() const
    104 {
    105     return m_url;
    106 }
    107 
    108 void WebSocketHandshake::setURL(const KURL& url)
    109 {
    110     m_url = url.copy();
    111 }
    112 
    113 const String WebSocketHandshake::host() const
    114 {
    115     return m_url.host().lower();
    116 }
    117 
    118 const String& WebSocketHandshake::clientProtocol() const
    119 {
    120     return m_clientProtocol;
    121 }
    122 
    123 void WebSocketHandshake::setClientProtocol(const String& protocol)
    124 {
    125     m_clientProtocol = protocol;
    126 }
    127 
    128 bool WebSocketHandshake::secure() const
    129 {
    130     return m_secure;
    131 }
    132 
    133 void WebSocketHandshake::setSecure(bool secure)
    134 {
    135     m_secure = secure;
    136 }
    137 
    138 String WebSocketHandshake::clientOrigin() const
    139 {
    140     return m_context->securityOrigin()->toString();
    141 }
    142 
    143 String WebSocketHandshake::clientLocation() const
    144 {
    145     StringBuilder builder;
    146     builder.append(m_secure ? "wss" : "ws");
    147     builder.append("://");
    148     builder.append(m_url.host().lower());
    149     if (m_url.port()) {
    150         if ((!m_secure && m_url.port() != 80) || (m_secure && m_url.port() != 443)) {
    151             builder.append(":");
    152             builder.append(String::number(m_url.port()));
    153         }
    154     }
    155     builder.append(resourceName(m_url));
    156     return builder.toString();
    157 }
    158 
    159 CString WebSocketHandshake::clientHandshakeMessage() const
    160 {
    161     StringBuilder builder;
    162 
    163     builder.append("GET ");
    164     builder.append(resourceName(m_url));
    165     builder.append(" HTTP/1.1\r\n");
    166     builder.append("Upgrade: WebSocket\r\n");
    167     builder.append("Connection: Upgrade\r\n");
    168     builder.append("Host: ");
    169     builder.append(m_url.host().lower());
    170     if (m_url.port()) {
    171         if ((!m_secure && m_url.port() != 80) || (m_secure && m_url.port() != 443)) {
    172             builder.append(":");
    173             builder.append(String::number(m_url.port()));
    174         }
    175     }
    176     builder.append("\r\n");
    177     builder.append("Origin: ");
    178     builder.append(clientOrigin());
    179     builder.append("\r\n");
    180     if (!m_clientProtocol.isEmpty()) {
    181         builder.append("WebSocket-Protocol: ");
    182         builder.append(m_clientProtocol);
    183         builder.append("\r\n");
    184     }
    185     KURL url = httpURLForAuthenticationAndCookies();
    186     // FIXME: set authentication information or cookies for url.
    187     // Set "Authorization: <credentials>" if authentication information exists for url.
    188     if (m_context->isDocument()) {
    189         Document* document = static_cast<Document*>(m_context);
    190         String cookie = cookieRequestHeaderFieldValue(document, url);
    191         if (!cookie.isEmpty()) {
    192             builder.append("Cookie: ");
    193             builder.append(cookie);
    194             builder.append("\r\n");
    195         }
    196         // Set "Cookie2: <cookie>" if cookies 2 exists for url?
    197     }
    198     builder.append("\r\n");
    199     return builder.toString().utf8();
    200 }
    201 
    202 void WebSocketHandshake::reset()
    203 {
    204     m_mode = Incomplete;
    205 
    206     m_wsOrigin = String();
    207     m_wsLocation = String();
    208     m_wsProtocol = String();
    209     m_setCookie = String();
    210     m_setCookie2 = String();
    211 }
    212 
    213 int WebSocketHandshake::readServerHandshake(const char* header, size_t len)
    214 {
    215     m_mode = Incomplete;
    216     if (len < sizeof(webSocketServerHandshakeHeader) - 1) {
    217         // Just hasn't been received fully yet.
    218         return -1;
    219     }
    220     if (!memcmp(header, webSocketServerHandshakeHeader, sizeof(webSocketServerHandshakeHeader) - 1))
    221         m_mode = Normal;
    222     else {
    223         const String& code = extractResponseCode(header, len);
    224         if (code.isNull()) {
    225             m_context->addMessage(ConsoleDestination, JSMessageSource, LogMessageType, ErrorMessageLevel, "Short server handshake: " + String(header, len), 0, clientOrigin());
    226             return -1;
    227         }
    228         if (code.isEmpty()) {
    229             m_mode = Failed;
    230             m_context->addMessage(ConsoleDestination, JSMessageSource, LogMessageType, ErrorMessageLevel, "No response code found: " + String(header, len), 0, clientOrigin());
    231             return len;
    232         }
    233         LOG(Network, "response code: %s", code.utf8().data());
    234         if (code == "401") {
    235             m_mode = Failed;
    236             m_context->addMessage(ConsoleDestination, JSMessageSource, LogMessageType, ErrorMessageLevel, "Authentication required, but not implemented yet.", 0, clientOrigin());
    237             return len;
    238         } else {
    239             m_mode = Failed;
    240             m_context->addMessage(ConsoleDestination, JSMessageSource, LogMessageType, ErrorMessageLevel, "Unexpected response code:" + code, 0, clientOrigin());
    241             return len;
    242         }
    243     }
    244     const char* p = header + sizeof(webSocketServerHandshakeHeader) - 1;
    245     const char* end = header + len + 1;
    246 
    247     if (m_mode == Normal) {
    248         size_t headerSize = end - p;
    249         if (headerSize < sizeof(webSocketUpgradeHeader) - 1) {
    250             m_mode = Incomplete;
    251             return 0;
    252         }
    253         if (memcmp(p, webSocketUpgradeHeader, sizeof(webSocketUpgradeHeader) - 1)) {
    254             m_mode = Failed;
    255             m_context->addMessage(ConsoleDestination, JSMessageSource, LogMessageType, ErrorMessageLevel, "Bad Upgrade header: " + String(p, end - p), 0, clientOrigin());
    256             return p - header + sizeof(webSocketUpgradeHeader) - 1;
    257         }
    258         p += sizeof(webSocketUpgradeHeader) - 1;
    259 
    260         headerSize = end - p;
    261         if (headerSize < sizeof(webSocketConnectionHeader) - 1) {
    262             m_mode = Incomplete;
    263             return -1;
    264         }
    265         if (memcmp(p, webSocketConnectionHeader, sizeof(webSocketConnectionHeader) - 1)) {
    266             m_mode = Failed;
    267             m_context->addMessage(ConsoleDestination, JSMessageSource, LogMessageType, ErrorMessageLevel, "Bad Connection header: " + String(p, end - p), 0, clientOrigin());
    268             return p - header + sizeof(webSocketConnectionHeader) - 1;
    269         }
    270         p += sizeof(webSocketConnectionHeader) - 1;
    271     }
    272 
    273     if (!strnstr(p, "\r\n\r\n", end - p)) {
    274         // Just hasn't been received fully yet.
    275         m_mode = Incomplete;
    276         return -1;
    277     }
    278     HTTPHeaderMap headers;
    279     p = readHTTPHeaders(p, end, &headers);
    280     if (!p) {
    281         LOG(Network, "readHTTPHeaders failed");
    282         m_mode = Failed;
    283         return len;
    284     }
    285     if (!processHeaders(headers)) {
    286         LOG(Network, "header process failed");
    287         m_mode = Failed;
    288         return p - header;
    289     }
    290     switch (m_mode) {
    291     case Normal:
    292         checkResponseHeaders();
    293         break;
    294     default:
    295         m_mode = Failed;
    296         break;
    297     }
    298     return p - header;
    299 }
    300 
    301 WebSocketHandshake::Mode WebSocketHandshake::mode() const
    302 {
    303     return m_mode;
    304 }
    305 
    306 const String& WebSocketHandshake::serverWebSocketOrigin() const
    307 {
    308     return m_wsOrigin;
    309 }
    310 
    311 void WebSocketHandshake::setServerWebSocketOrigin(const String& webSocketOrigin)
    312 {
    313     m_wsOrigin = webSocketOrigin;
    314 }
    315 
    316 const String& WebSocketHandshake::serverWebSocketLocation() const
    317 {
    318     return m_wsLocation;
    319 }
    320 
    321 void WebSocketHandshake::setServerWebSocketLocation(const String& webSocketLocation)
    322 {
    323     m_wsLocation = webSocketLocation;
    324 }
    325 
    326 const String& WebSocketHandshake::serverWebSocketProtocol() const
    327 {
    328     return m_wsProtocol;
    329 }
    330 
    331 void WebSocketHandshake::setServerWebSocketProtocol(const String& webSocketProtocol)
    332 {
    333     m_wsProtocol = webSocketProtocol;
    334 }
    335 
    336 const String& WebSocketHandshake::serverSetCookie() const
    337 {
    338     return m_setCookie;
    339 }
    340 
    341 void WebSocketHandshake::setServerSetCookie(const String& setCookie)
    342 {
    343     m_setCookie = setCookie;
    344 }
    345 
    346 const String& WebSocketHandshake::serverSetCookie2() const
    347 {
    348     return m_setCookie2;
    349 }
    350 
    351 void WebSocketHandshake::setServerSetCookie2(const String& setCookie2)
    352 {
    353     m_setCookie2 = setCookie2;
    354 }
    355 
    356 KURL WebSocketHandshake::httpURLForAuthenticationAndCookies() const
    357 {
    358     KURL url = m_url.copy();
    359     bool couldSetProtocol = url.setProtocol(m_secure ? "https" : "http");
    360     ASSERT_UNUSED(couldSetProtocol, couldSetProtocol);
    361     return url;
    362 }
    363 
    364 const char* WebSocketHandshake::readHTTPHeaders(const char* start, const char* end, HTTPHeaderMap* headers)
    365 {
    366     Vector<char> name;
    367     Vector<char> value;
    368     for (const char* p = start; p < end; p++) {
    369         name.clear();
    370         value.clear();
    371 
    372         for (; p < end; p++) {
    373             switch (*p) {
    374             case '\r':
    375                 if (name.isEmpty()) {
    376                     if (p + 1 < end && *(p + 1) == '\n')
    377                         return p + 2;
    378                     m_context->addMessage(ConsoleDestination, JSMessageSource, LogMessageType, ErrorMessageLevel, "CR doesn't follow LF at " + String(p, end - p), 0, clientOrigin());
    379                     return 0;
    380                 }
    381                 m_context->addMessage(ConsoleDestination, JSMessageSource, LogMessageType, ErrorMessageLevel, "Unexpected CR in name at " + String(p, end - p), 0, clientOrigin());
    382                 return 0;
    383             case '\n':
    384                 m_context->addMessage(ConsoleDestination, JSMessageSource, LogMessageType, ErrorMessageLevel, "Unexpected LF in name at " + String(p, end - p), 0, clientOrigin());
    385                 return 0;
    386             case ':':
    387                 break;
    388             default:
    389                 if (*p >= 0x41 && *p <= 0x5a)
    390                     name.append(*p + 0x20);
    391                 else
    392                     name.append(*p);
    393                 continue;
    394             }
    395             if (*p == ':') {
    396                 ++p;
    397                 break;
    398             }
    399         }
    400 
    401         for (; p < end && *p == 0x20; p++) { }
    402 
    403         for (; p < end; p++) {
    404             switch (*p) {
    405             case '\r':
    406                 break;
    407             case '\n':
    408                 m_context->addMessage(ConsoleDestination, JSMessageSource, LogMessageType, ErrorMessageLevel, "Unexpected LF in value at " + String(p, end - p), 0, clientOrigin());
    409                 return 0;
    410             default:
    411                 value.append(*p);
    412             }
    413             if (*p == '\r') {
    414                 ++p;
    415                 break;
    416             }
    417         }
    418         if (p >= end || *p != '\n') {
    419             m_context->addMessage(ConsoleDestination, JSMessageSource, LogMessageType, ErrorMessageLevel, "CR doesn't follow LF after value at " + String(p, end - p), 0, clientOrigin());
    420             return 0;
    421         }
    422         AtomicString nameStr(String::fromUTF8(name.data(), name.size()));
    423         String valueStr = String::fromUTF8(value.data(), value.size());
    424         LOG(Network, "name=%s value=%s", nameStr.string().utf8().data(), valueStr.utf8().data());
    425         headers->add(nameStr, valueStr);
    426     }
    427     ASSERT_NOT_REACHED();
    428     return 0;
    429 }
    430 
    431 bool WebSocketHandshake::processHeaders(const HTTPHeaderMap& headers)
    432 {
    433     for (HTTPHeaderMap::const_iterator it = headers.begin(); it != headers.end(); ++it) {
    434         switch (m_mode) {
    435         case Normal:
    436             if (it->first == "websocket-origin")
    437                 m_wsOrigin = it->second;
    438             else if (it->first == "websocket-location")
    439                 m_wsLocation = it->second;
    440             else if (it->first == "websocket-protocol")
    441                 m_wsProtocol = it->second;
    442             else if (it->first == "set-cookie")
    443                 m_setCookie = it->second;
    444             else if (it->first == "set-cookie2")
    445                 m_setCookie2 = it->second;
    446             continue;
    447         case Incomplete:
    448         case Failed:
    449         case Connected:
    450             ASSERT_NOT_REACHED();
    451         }
    452         ASSERT_NOT_REACHED();
    453     }
    454     return true;
    455 }
    456 
    457 void WebSocketHandshake::checkResponseHeaders()
    458 {
    459     ASSERT(m_mode == Normal);
    460     m_mode = Failed;
    461     if (m_wsOrigin.isNull()) {
    462         m_context->addMessage(ConsoleDestination, JSMessageSource, LogMessageType, ErrorMessageLevel, "Error during WebSocket handshake: 'websocket-origin' header is missing", 0, clientOrigin());
    463         return;
    464     }
    465     if (m_wsLocation.isNull()) {
    466         m_context->addMessage(ConsoleDestination, JSMessageSource, LogMessageType, ErrorMessageLevel, "Error during WebSocket handshake: 'websocket-location' header is missing", 0, clientOrigin());
    467         return;
    468     }
    469 
    470     if (clientOrigin() != m_wsOrigin) {
    471         m_context->addMessage(ConsoleDestination, JSMessageSource, LogMessageType, ErrorMessageLevel, "Error during WebSocket handshake: origin mismatch: " + clientOrigin() + " != " + m_wsOrigin, 0, clientOrigin());
    472         return;
    473     }
    474     if (clientLocation() != m_wsLocation) {
    475         m_context->addMessage(ConsoleDestination, JSMessageSource, LogMessageType, ErrorMessageLevel, "Error during WebSocket handshake: location mismatch: " + clientLocation() + " != " + m_wsLocation, 0, clientOrigin());
    476         return;
    477     }
    478     if (!m_clientProtocol.isEmpty() && m_clientProtocol != m_wsProtocol) {
    479         m_context->addMessage(ConsoleDestination, JSMessageSource, LogMessageType, ErrorMessageLevel, "Error during WebSocket handshake: protocol mismatch: " + m_clientProtocol + " != " + m_wsProtocol, 0, clientOrigin());
    480         return;
    481     }
    482     m_mode = Connected;
    483     return;
    484 }
    485 
    486 }  // namespace WebCore
    487 
    488 #endif  // ENABLE(WEB_SOCKETS)
    489