Home | History | Annotate | Download | only in socket
      1 // Copyright (c) 2006-2008 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/socket/ssl_client_socket.h"
      6 
      7 #include "net/base/address_list.h"
      8 #include "net/base/host_resolver.h"
      9 #include "net/base/io_buffer.h"
     10 #include "net/base/load_log.h"
     11 #include "net/base/load_log_unittest.h"
     12 #include "net/base/net_errors.h"
     13 #include "net/base/ssl_config_service.h"
     14 #include "net/base/test_completion_callback.h"
     15 #include "net/socket/client_socket_factory.h"
     16 #include "net/socket/ssl_test_util.h"
     17 #include "net/socket/tcp_client_socket.h"
     18 #include "testing/gtest/include/gtest/gtest.h"
     19 #include "testing/platform_test.h"
     20 
     21 //-----------------------------------------------------------------------------
     22 
     23 const net::SSLConfig kDefaultSSLConfig;
     24 
     25 class SSLClientSocketTest : public PlatformTest {
     26  public:
     27   SSLClientSocketTest()
     28     : resolver_(net::CreateSystemHostResolver(NULL)),
     29         socket_factory_(net::ClientSocketFactory::GetDefaultFactory()) {
     30   }
     31 
     32   void StartOKServer() {
     33     bool success = server_.Start(net::TestServerLauncher::ProtoHTTP,
     34         server_.kHostName, server_.kOKHTTPSPort,
     35         FilePath(), server_.GetOKCertPath(), std::wstring());
     36     ASSERT_TRUE(success);
     37   }
     38 
     39   void StartMismatchedServer() {
     40     bool success = server_.Start(net::TestServerLauncher::ProtoHTTP,
     41         server_.kMismatchedHostName, server_.kOKHTTPSPort,
     42         FilePath(), server_.GetOKCertPath(), std::wstring());
     43     ASSERT_TRUE(success);
     44   }
     45 
     46   void StartExpiredServer() {
     47     bool success = server_.Start(net::TestServerLauncher::ProtoHTTP,
     48         server_.kHostName, server_.kBadHTTPSPort,
     49         FilePath(), server_.GetExpiredCertPath(), std::wstring());
     50     ASSERT_TRUE(success);
     51   }
     52 
     53  protected:
     54   scoped_refptr<net::HostResolver> resolver_;
     55   net::ClientSocketFactory* socket_factory_;
     56   net::TestServerLauncher server_;
     57 };
     58 
     59 //-----------------------------------------------------------------------------
     60 
     61 TEST_F(SSLClientSocketTest, Connect) {
     62   StartOKServer();
     63 
     64   net::AddressList addr;
     65   TestCompletionCallback callback;
     66 
     67   net::HostResolver::RequestInfo info(server_.kHostName, server_.kOKHTTPSPort);
     68   int rv = resolver_->Resolve(info, &addr, NULL, NULL, NULL);
     69   EXPECT_EQ(net::OK, rv);
     70 
     71   net::ClientSocket *transport = new net::TCPClientSocket(addr);
     72   rv = transport->Connect(&callback, NULL);
     73   if (rv == net::ERR_IO_PENDING)
     74     rv = callback.WaitForResult();
     75   EXPECT_EQ(net::OK, rv);
     76 
     77   scoped_ptr<net::SSLClientSocket> sock(
     78       socket_factory_->CreateSSLClientSocket(transport,
     79           server_.kHostName, kDefaultSSLConfig));
     80 
     81   EXPECT_FALSE(sock->IsConnected());
     82 
     83   scoped_refptr<net::LoadLog> log(new net::LoadLog(net::LoadLog::kUnbounded));
     84   rv = sock->Connect(&callback, log);
     85   EXPECT_TRUE(net::LogContainsBeginEvent(
     86       *log, 0, net::LoadLog::TYPE_SSL_CONNECT));
     87   if (rv != net::OK) {
     88     ASSERT_EQ(net::ERR_IO_PENDING, rv);
     89     EXPECT_FALSE(sock->IsConnected());
     90     EXPECT_FALSE(net::LogContainsEndEvent(
     91         *log, -1, net::LoadLog::TYPE_SSL_CONNECT));
     92 
     93     rv = callback.WaitForResult();
     94     EXPECT_EQ(net::OK, rv);
     95   }
     96 
     97   EXPECT_TRUE(sock->IsConnected());
     98   EXPECT_TRUE(net::LogContainsEndEvent(
     99       *log, -1, net::LoadLog::TYPE_SSL_CONNECT));
    100 
    101   sock->Disconnect();
    102   EXPECT_FALSE(sock->IsConnected());
    103 }
    104 
    105 TEST_F(SSLClientSocketTest, ConnectExpired) {
    106   StartExpiredServer();
    107 
    108   net::AddressList addr;
    109   TestCompletionCallback callback;
    110 
    111   net::HostResolver::RequestInfo info(server_.kHostName, server_.kBadHTTPSPort);
    112   int rv = resolver_->Resolve(info, &addr, NULL, NULL, NULL);
    113   EXPECT_EQ(net::OK, rv);
    114 
    115   net::ClientSocket *transport = new net::TCPClientSocket(addr);
    116   rv = transport->Connect(&callback, NULL);
    117   if (rv == net::ERR_IO_PENDING)
    118     rv = callback.WaitForResult();
    119   EXPECT_EQ(net::OK, rv);
    120 
    121   scoped_ptr<net::SSLClientSocket> sock(
    122       socket_factory_->CreateSSLClientSocket(transport,
    123           server_.kHostName, kDefaultSSLConfig));
    124 
    125   EXPECT_FALSE(sock->IsConnected());
    126 
    127   scoped_refptr<net::LoadLog> log(new net::LoadLog(net::LoadLog::kUnbounded));
    128   rv = sock->Connect(&callback, log);
    129   EXPECT_TRUE(net::LogContainsBeginEvent(
    130       *log, 0, net::LoadLog::TYPE_SSL_CONNECT));
    131   if (rv != net::OK) {
    132     ASSERT_EQ(net::ERR_IO_PENDING, rv);
    133     EXPECT_FALSE(sock->IsConnected());
    134     EXPECT_FALSE(net::LogContainsEndEvent(
    135         *log, -1, net::LoadLog::TYPE_SSL_CONNECT));
    136 
    137     rv = callback.WaitForResult();
    138     EXPECT_EQ(net::ERR_CERT_DATE_INVALID, rv);
    139   }
    140 
    141   // We cannot test sock->IsConnected(), as the NSS implementation disconnects
    142   // the socket when it encounters an error, whereas other implementations
    143   // leave it connected.
    144 
    145   EXPECT_TRUE(net::LogContainsEndEvent(
    146       *log, -1, net::LoadLog::TYPE_SSL_CONNECT));
    147 }
    148 
    149 TEST_F(SSLClientSocketTest, ConnectMismatched) {
    150   StartMismatchedServer();
    151 
    152   net::AddressList addr;
    153   TestCompletionCallback callback;
    154 
    155   net::HostResolver::RequestInfo info(server_.kMismatchedHostName,
    156                                       server_.kOKHTTPSPort);
    157   int rv = resolver_->Resolve(info, &addr, NULL, NULL, NULL);
    158   EXPECT_EQ(net::OK, rv);
    159 
    160   net::ClientSocket *transport = new net::TCPClientSocket(addr);
    161   rv = transport->Connect(&callback, NULL);
    162   if (rv == net::ERR_IO_PENDING)
    163     rv = callback.WaitForResult();
    164   EXPECT_EQ(net::OK, rv);
    165 
    166   scoped_ptr<net::SSLClientSocket> sock(
    167       socket_factory_->CreateSSLClientSocket(transport,
    168           server_.kMismatchedHostName, kDefaultSSLConfig));
    169 
    170   EXPECT_FALSE(sock->IsConnected());
    171 
    172   scoped_refptr<net::LoadLog> log(new net::LoadLog(net::LoadLog::kUnbounded));
    173   rv = sock->Connect(&callback, log);
    174   EXPECT_TRUE(net::LogContainsBeginEvent(
    175       *log, 0, net::LoadLog::TYPE_SSL_CONNECT));
    176   if (rv != net::ERR_CERT_COMMON_NAME_INVALID) {
    177     ASSERT_EQ(net::ERR_IO_PENDING, rv);
    178     EXPECT_FALSE(sock->IsConnected());
    179     EXPECT_FALSE(net::LogContainsEndEvent(
    180         *log, -1, net::LoadLog::TYPE_SSL_CONNECT));
    181 
    182     rv = callback.WaitForResult();
    183     EXPECT_EQ(net::ERR_CERT_COMMON_NAME_INVALID, rv);
    184   }
    185 
    186   // We cannot test sock->IsConnected(), as the NSS implementation disconnects
    187   // the socket when it encounters an error, whereas other implementations
    188   // leave it connected.
    189 
    190   EXPECT_TRUE(net::LogContainsEndEvent(
    191       *log, -1, net::LoadLog::TYPE_SSL_CONNECT));
    192 }
    193 
    194 // TODO(wtc): Add unit tests for IsConnectedAndIdle:
    195 //   - Server closes an SSL connection (with a close_notify alert message).
    196 //   - Server closes the underlying TCP connection directly.
    197 //   - Server sends data unexpectedly.
    198 
    199 TEST_F(SSLClientSocketTest, Read) {
    200   StartOKServer();
    201 
    202   net::AddressList addr;
    203   TestCompletionCallback callback;
    204 
    205   net::HostResolver::RequestInfo info(server_.kHostName, server_.kOKHTTPSPort);
    206   int rv = resolver_->Resolve(info, &addr, &callback, NULL, NULL);
    207   EXPECT_EQ(net::ERR_IO_PENDING, rv);
    208 
    209   rv = callback.WaitForResult();
    210   EXPECT_EQ(net::OK, rv);
    211 
    212   net::ClientSocket *transport = new net::TCPClientSocket(addr);
    213   rv = transport->Connect(&callback, NULL);
    214   if (rv == net::ERR_IO_PENDING)
    215     rv = callback.WaitForResult();
    216   EXPECT_EQ(net::OK, rv);
    217 
    218   scoped_ptr<net::SSLClientSocket> sock(
    219       socket_factory_->CreateSSLClientSocket(transport,
    220                                              server_.kHostName,
    221                                              kDefaultSSLConfig));
    222 
    223   rv = sock->Connect(&callback, NULL);
    224   if (rv != net::OK) {
    225     ASSERT_EQ(net::ERR_IO_PENDING, rv);
    226 
    227     rv = callback.WaitForResult();
    228     EXPECT_EQ(net::OK, rv);
    229   }
    230   EXPECT_TRUE(sock->IsConnected());
    231 
    232   const char request_text[] = "GET / HTTP/1.0\r\n\r\n";
    233   scoped_refptr<net::IOBuffer> request_buffer =
    234       new net::IOBuffer(arraysize(request_text) - 1);
    235   memcpy(request_buffer->data(), request_text, arraysize(request_text) - 1);
    236 
    237   rv = sock->Write(request_buffer, arraysize(request_text) - 1, &callback);
    238   EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING);
    239 
    240   if (rv == net::ERR_IO_PENDING)
    241     rv = callback.WaitForResult();
    242   EXPECT_EQ(static_cast<int>(arraysize(request_text) - 1), rv);
    243 
    244   scoped_refptr<net::IOBuffer> buf = new net::IOBuffer(4096);
    245   for (;;) {
    246     rv = sock->Read(buf, 4096, &callback);
    247     EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING);
    248 
    249     if (rv == net::ERR_IO_PENDING)
    250       rv = callback.WaitForResult();
    251 
    252     EXPECT_GE(rv, 0);
    253     if (rv <= 0)
    254       break;
    255   }
    256 }
    257 
    258 // Test the full duplex mode, with Read and Write pending at the same time.
    259 // This test also serves as a regression test for http://crbug.com/29815.
    260 TEST_F(SSLClientSocketTest, Read_FullDuplex) {
    261   StartOKServer();
    262 
    263   net::AddressList addr;
    264   TestCompletionCallback callback;  // Used for everything except Write.
    265   TestCompletionCallback callback2;  // Used for Write only.
    266 
    267   net::HostResolver::RequestInfo info(server_.kHostName, server_.kOKHTTPSPort);
    268   int rv = resolver_->Resolve(info, &addr, &callback, NULL, NULL);
    269   EXPECT_EQ(net::ERR_IO_PENDING, rv);
    270 
    271   rv = callback.WaitForResult();
    272   EXPECT_EQ(net::OK, rv);
    273 
    274   net::ClientSocket *transport = new net::TCPClientSocket(addr);
    275   rv = transport->Connect(&callback, NULL);
    276   if (rv == net::ERR_IO_PENDING)
    277     rv = callback.WaitForResult();
    278   EXPECT_EQ(net::OK, rv);
    279 
    280   scoped_ptr<net::SSLClientSocket> sock(
    281       socket_factory_->CreateSSLClientSocket(transport,
    282                                              server_.kHostName,
    283                                              kDefaultSSLConfig));
    284 
    285   rv = sock->Connect(&callback, NULL);
    286   if (rv != net::OK) {
    287     ASSERT_EQ(net::ERR_IO_PENDING, rv);
    288 
    289     rv = callback.WaitForResult();
    290     EXPECT_EQ(net::OK, rv);
    291   }
    292   EXPECT_TRUE(sock->IsConnected());
    293 
    294   // Issue a "hanging" Read first.
    295   scoped_refptr<net::IOBuffer> buf = new net::IOBuffer(4096);
    296   rv = sock->Read(buf, 4096, &callback);
    297   // We haven't written the request, so there should be no response yet.
    298   ASSERT_EQ(net::ERR_IO_PENDING, rv);
    299 
    300   // Write the request.
    301   // The request is padded with a User-Agent header to a size that causes the
    302   // memio circular buffer (4k bytes) in SSLClientSocketNSS to wrap around.
    303   // This tests the fix for http://crbug.com/29815.
    304   std::string request_text = "GET / HTTP/1.1\r\nUser-Agent: long browser name ";
    305   for (int i = 0; i < 3800; ++i)
    306     request_text.push_back('*');
    307   request_text.append("\r\n\r\n");
    308   scoped_refptr<net::IOBuffer> request_buffer =
    309       new net::StringIOBuffer(request_text);
    310 
    311   rv = sock->Write(request_buffer, request_text.size(), &callback2);
    312   EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING);
    313 
    314   if (rv == net::ERR_IO_PENDING)
    315     rv = callback2.WaitForResult();
    316   EXPECT_EQ(static_cast<int>(request_text.size()), rv);
    317 
    318   // Now get the Read result.
    319   rv = callback.WaitForResult();
    320   EXPECT_GT(rv, 0);
    321 }
    322 
    323 TEST_F(SSLClientSocketTest, Read_SmallChunks) {
    324   StartOKServer();
    325 
    326   net::AddressList addr;
    327   TestCompletionCallback callback;
    328 
    329   net::HostResolver::RequestInfo info(server_.kHostName, server_.kOKHTTPSPort);
    330   int rv = resolver_->Resolve(info, &addr, NULL, NULL, NULL);
    331   EXPECT_EQ(net::OK, rv);
    332 
    333   net::ClientSocket *transport = new net::TCPClientSocket(addr);
    334   rv = transport->Connect(&callback, NULL);
    335   if (rv == net::ERR_IO_PENDING)
    336     rv = callback.WaitForResult();
    337   EXPECT_EQ(net::OK, rv);
    338 
    339   scoped_ptr<net::SSLClientSocket> sock(
    340       socket_factory_->CreateSSLClientSocket(transport,
    341           server_.kHostName, kDefaultSSLConfig));
    342 
    343   rv = sock->Connect(&callback, NULL);
    344   if (rv != net::OK) {
    345     ASSERT_EQ(net::ERR_IO_PENDING, rv);
    346 
    347     rv = callback.WaitForResult();
    348     EXPECT_EQ(net::OK, rv);
    349   }
    350 
    351   const char request_text[] = "GET / HTTP/1.0\r\n\r\n";
    352   scoped_refptr<net::IOBuffer> request_buffer =
    353       new net::IOBuffer(arraysize(request_text) - 1);
    354   memcpy(request_buffer->data(), request_text, arraysize(request_text) - 1);
    355 
    356   rv = sock->Write(request_buffer, arraysize(request_text) - 1, &callback);
    357   EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING);
    358 
    359   if (rv == net::ERR_IO_PENDING)
    360     rv = callback.WaitForResult();
    361   EXPECT_EQ(static_cast<int>(arraysize(request_text) - 1), rv);
    362 
    363   scoped_refptr<net::IOBuffer> buf = new net::IOBuffer(1);
    364   for (;;) {
    365     rv = sock->Read(buf, 1, &callback);
    366     EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING);
    367 
    368     if (rv == net::ERR_IO_PENDING)
    369       rv = callback.WaitForResult();
    370 
    371     EXPECT_GE(rv, 0);
    372     if (rv <= 0)
    373       break;
    374   }
    375 }
    376 
    377 TEST_F(SSLClientSocketTest, Read_Interrupted) {
    378   StartOKServer();
    379 
    380   net::AddressList addr;
    381   TestCompletionCallback callback;
    382 
    383   net::HostResolver::RequestInfo info(server_.kHostName, server_.kOKHTTPSPort);
    384   int rv = resolver_->Resolve(info, &addr, NULL, NULL, NULL);
    385   EXPECT_EQ(net::OK, rv);
    386 
    387   net::ClientSocket *transport = new net::TCPClientSocket(addr);
    388   rv = transport->Connect(&callback, NULL);
    389   if (rv == net::ERR_IO_PENDING)
    390     rv = callback.WaitForResult();
    391   EXPECT_EQ(net::OK, rv);
    392 
    393   scoped_ptr<net::SSLClientSocket> sock(
    394       socket_factory_->CreateSSLClientSocket(transport,
    395           server_.kHostName, kDefaultSSLConfig));
    396 
    397   rv = sock->Connect(&callback, NULL);
    398   if (rv != net::OK) {
    399     ASSERT_EQ(net::ERR_IO_PENDING, rv);
    400 
    401     rv = callback.WaitForResult();
    402     EXPECT_EQ(net::OK, rv);
    403   }
    404 
    405   const char request_text[] = "GET / HTTP/1.0\r\n\r\n";
    406   scoped_refptr<net::IOBuffer> request_buffer =
    407       new net::IOBuffer(arraysize(request_text) - 1);
    408   memcpy(request_buffer->data(), request_text, arraysize(request_text) - 1);
    409 
    410   rv = sock->Write(request_buffer, arraysize(request_text) - 1, &callback);
    411   EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING);
    412 
    413   if (rv == net::ERR_IO_PENDING)
    414     rv = callback.WaitForResult();
    415   EXPECT_EQ(static_cast<int>(arraysize(request_text) - 1), rv);
    416 
    417   // Do a partial read and then exit.  This test should not crash!
    418   scoped_refptr<net::IOBuffer> buf = new net::IOBuffer(512);
    419   rv = sock->Read(buf, 512, &callback);
    420   EXPECT_TRUE(rv > 0 || rv == net::ERR_IO_PENDING);
    421 
    422   if (rv == net::ERR_IO_PENDING)
    423     rv = callback.WaitForResult();
    424 
    425   EXPECT_GT(rv, 0);
    426 }
    427