Home | History | Annotate | Download | only in base
      1 /*
      2  *  Copyright 2014 The WebRTC Project Authors. All rights reserved.
      3  *
      4  *  Use of this source code is governed by a BSD-style license
      5  *  that can be found in the LICENSE file in the root of the source
      6  *  tree. An additional intellectual property rights grant can be found
      7  *  in the file PATENTS.  All contributing project authors may
      8  *  be found in the AUTHORS file in the root of the source tree.
      9  */
     10 
     11 #include <string>
     12 
     13 #include "webrtc/base/gunit.h"
     14 #include "webrtc/base/ipaddress.h"
     15 #include "webrtc/base/socketstream.h"
     16 #include "webrtc/base/ssladapter.h"
     17 #include "webrtc/base/sslstreamadapter.h"
     18 #include "webrtc/base/sslidentity.h"
     19 #include "webrtc/base/stream.h"
     20 #include "webrtc/base/virtualsocketserver.h"
     21 
     22 static const int kTimeout = 5000;
     23 
     24 static rtc::AsyncSocket* CreateSocket(const rtc::SSLMode& ssl_mode) {
     25   rtc::SocketAddress address(rtc::IPAddress(INADDR_ANY), 0);
     26 
     27   rtc::AsyncSocket* socket = rtc::Thread::Current()->
     28       socketserver()->CreateAsyncSocket(
     29       address.family(), (ssl_mode == rtc::SSL_MODE_DTLS) ?
     30       SOCK_DGRAM : SOCK_STREAM);
     31   socket->Bind(address);
     32 
     33   return socket;
     34 }
     35 
     36 static std::string GetSSLProtocolName(const rtc::SSLMode& ssl_mode) {
     37   return (ssl_mode == rtc::SSL_MODE_DTLS) ? "DTLS" : "TLS";
     38 }
     39 
     40 class SSLAdapterTestDummyClient : public sigslot::has_slots<> {
     41  public:
     42   explicit SSLAdapterTestDummyClient(const rtc::SSLMode& ssl_mode)
     43       : ssl_mode_(ssl_mode) {
     44     rtc::AsyncSocket* socket = CreateSocket(ssl_mode_);
     45 
     46     ssl_adapter_.reset(rtc::SSLAdapter::Create(socket));
     47 
     48     ssl_adapter_->SetMode(ssl_mode_);
     49 
     50     // Ignore any certificate errors for the purpose of testing.
     51     // Note: We do this only because we don't have a real certificate.
     52     // NEVER USE THIS IN PRODUCTION CODE!
     53     ssl_adapter_->set_ignore_bad_cert(true);
     54 
     55     ssl_adapter_->SignalReadEvent.connect(this,
     56         &SSLAdapterTestDummyClient::OnSSLAdapterReadEvent);
     57     ssl_adapter_->SignalCloseEvent.connect(this,
     58         &SSLAdapterTestDummyClient::OnSSLAdapterCloseEvent);
     59   }
     60 
     61   rtc::SocketAddress GetAddress() const {
     62     return ssl_adapter_->GetLocalAddress();
     63   }
     64 
     65   rtc::AsyncSocket::ConnState GetState() const {
     66     return ssl_adapter_->GetState();
     67   }
     68 
     69   const std::string& GetReceivedData() const {
     70     return data_;
     71   }
     72 
     73   int Connect(const std::string& hostname, const rtc::SocketAddress& address) {
     74     LOG(LS_INFO) << "Initiating connection with " << address;
     75 
     76     int rv = ssl_adapter_->Connect(address);
     77 
     78     if (rv == 0) {
     79       LOG(LS_INFO) << "Starting " << GetSSLProtocolName(ssl_mode_)
     80           << " handshake with " << hostname;
     81 
     82       if (ssl_adapter_->StartSSL(hostname.c_str(), false) != 0) {
     83         return -1;
     84       }
     85     }
     86 
     87     return rv;
     88   }
     89 
     90   int Close() {
     91     return ssl_adapter_->Close();
     92   }
     93 
     94   int Send(const std::string& message) {
     95     LOG(LS_INFO) << "Client sending '" << message << "'";
     96 
     97     return ssl_adapter_->Send(message.data(), message.length());
     98   }
     99 
    100   void OnSSLAdapterReadEvent(rtc::AsyncSocket* socket) {
    101     char buffer[4096] = "";
    102 
    103     // Read data received from the server and store it in our internal buffer.
    104     int read = socket->Recv(buffer, sizeof(buffer) - 1);
    105     if (read != -1) {
    106       buffer[read] = '\0';
    107 
    108       LOG(LS_INFO) << "Client received '" << buffer << "'";
    109 
    110       data_ += buffer;
    111     }
    112   }
    113 
    114   void OnSSLAdapterCloseEvent(rtc::AsyncSocket* socket, int error) {
    115     // OpenSSLAdapter signals handshake failure with a close event, but without
    116     // closing the socket! Let's close the socket here. This way GetState() can
    117     // return CS_CLOSED after failure.
    118     if (socket->GetState() != rtc::AsyncSocket::CS_CLOSED) {
    119       socket->Close();
    120     }
    121   }
    122 
    123  private:
    124   const rtc::SSLMode ssl_mode_;
    125 
    126   rtc::scoped_ptr<rtc::SSLAdapter> ssl_adapter_;
    127 
    128   std::string data_;
    129 };
    130 
    131 class SSLAdapterTestDummyServer : public sigslot::has_slots<> {
    132  public:
    133   explicit SSLAdapterTestDummyServer(const rtc::SSLMode& ssl_mode,
    134                                      const rtc::KeyParams& key_params)
    135       : ssl_mode_(ssl_mode) {
    136     // Generate a key pair and a certificate for this host.
    137     ssl_identity_.reset(rtc::SSLIdentity::Generate(GetHostname(), key_params));
    138 
    139     server_socket_.reset(CreateSocket(ssl_mode_));
    140 
    141     if (ssl_mode_ == rtc::SSL_MODE_TLS) {
    142       server_socket_->SignalReadEvent.connect(this,
    143           &SSLAdapterTestDummyServer::OnServerSocketReadEvent);
    144 
    145       server_socket_->Listen(1);
    146     }
    147 
    148     LOG(LS_INFO) << ((ssl_mode_ == rtc::SSL_MODE_DTLS) ? "UDP" : "TCP")
    149         << " server listening on " << server_socket_->GetLocalAddress();
    150   }
    151 
    152   rtc::SocketAddress GetAddress() const {
    153     return server_socket_->GetLocalAddress();
    154   }
    155 
    156   std::string GetHostname() const {
    157     // Since we don't have a real certificate anyway, the value here doesn't
    158     // really matter.
    159     return "example.com";
    160   }
    161 
    162   const std::string& GetReceivedData() const {
    163     return data_;
    164   }
    165 
    166   int Send(const std::string& message) {
    167     if (ssl_stream_adapter_ == NULL
    168         || ssl_stream_adapter_->GetState() != rtc::SS_OPEN) {
    169       // No connection yet.
    170       return -1;
    171     }
    172 
    173     LOG(LS_INFO) << "Server sending '" << message << "'";
    174 
    175     size_t written;
    176     int error;
    177 
    178     rtc::StreamResult r = ssl_stream_adapter_->Write(message.data(),
    179         message.length(), &written, &error);
    180     if (r == rtc::SR_SUCCESS) {
    181       return written;
    182     } else {
    183       return -1;
    184     }
    185   }
    186 
    187   void AcceptConnection(const rtc::SocketAddress& address) {
    188     // Only a single connection is supported.
    189     ASSERT_TRUE(ssl_stream_adapter_ == NULL);
    190 
    191     // This is only for DTLS.
    192     ASSERT_EQ(rtc::SSL_MODE_DTLS, ssl_mode_);
    193 
    194     // Transfer ownership of the socket to the SSLStreamAdapter object.
    195     rtc::AsyncSocket* socket = server_socket_.release();
    196 
    197     socket->Connect(address);
    198 
    199     DoHandshake(socket);
    200   }
    201 
    202   void OnServerSocketReadEvent(rtc::AsyncSocket* socket) {
    203     // Only a single connection is supported.
    204     ASSERT_TRUE(ssl_stream_adapter_ == NULL);
    205 
    206     DoHandshake(server_socket_->Accept(NULL));
    207   }
    208 
    209   void OnSSLStreamAdapterEvent(rtc::StreamInterface* stream, int sig, int err) {
    210     if (sig & rtc::SE_READ) {
    211       char buffer[4096] = "";
    212 
    213       size_t read;
    214       int error;
    215 
    216       // Read data received from the client and store it in our internal
    217       // buffer.
    218       rtc::StreamResult r = stream->Read(buffer,
    219           sizeof(buffer) - 1, &read, &error);
    220       if (r == rtc::SR_SUCCESS) {
    221         buffer[read] = '\0';
    222 
    223         LOG(LS_INFO) << "Server received '" << buffer << "'";
    224 
    225         data_ += buffer;
    226       }
    227     }
    228   }
    229 
    230  private:
    231   void DoHandshake(rtc::AsyncSocket* socket) {
    232     rtc::SocketStream* stream = new rtc::SocketStream(socket);
    233 
    234     ssl_stream_adapter_.reset(rtc::SSLStreamAdapter::Create(stream));
    235 
    236     ssl_stream_adapter_->SetMode(ssl_mode_);
    237     ssl_stream_adapter_->SetServerRole();
    238 
    239     // SSLStreamAdapter is normally used for peer-to-peer communication, but
    240     // here we're testing communication between a client and a server
    241     // (e.g. a WebRTC-based application and an RFC 5766 TURN server), where
    242     // clients are not required to provide a certificate during handshake.
    243     // Accordingly, we must disable client authentication here.
    244     ssl_stream_adapter_->set_client_auth_enabled(false);
    245 
    246     ssl_stream_adapter_->SetIdentity(ssl_identity_->GetReference());
    247 
    248     // Set a bogus peer certificate digest.
    249     unsigned char digest[20];
    250     size_t digest_len = sizeof(digest);
    251     ssl_stream_adapter_->SetPeerCertificateDigest(rtc::DIGEST_SHA_1, digest,
    252         digest_len);
    253 
    254     ssl_stream_adapter_->StartSSLWithPeer();
    255 
    256     ssl_stream_adapter_->SignalEvent.connect(this,
    257         &SSLAdapterTestDummyServer::OnSSLStreamAdapterEvent);
    258   }
    259 
    260   const rtc::SSLMode ssl_mode_;
    261 
    262   rtc::scoped_ptr<rtc::AsyncSocket> server_socket_;
    263   rtc::scoped_ptr<rtc::SSLStreamAdapter> ssl_stream_adapter_;
    264 
    265   rtc::scoped_ptr<rtc::SSLIdentity> ssl_identity_;
    266 
    267   std::string data_;
    268 };
    269 
    270 class SSLAdapterTestBase : public testing::Test,
    271                            public sigslot::has_slots<> {
    272  public:
    273   explicit SSLAdapterTestBase(const rtc::SSLMode& ssl_mode,
    274                               const rtc::KeyParams& key_params)
    275       : ssl_mode_(ssl_mode),
    276         ss_scope_(new rtc::VirtualSocketServer(NULL)),
    277         server_(new SSLAdapterTestDummyServer(ssl_mode_, key_params)),
    278         client_(new SSLAdapterTestDummyClient(ssl_mode_)),
    279         handshake_wait_(kTimeout) {}
    280 
    281   void SetHandshakeWait(int wait) {
    282     handshake_wait_ = wait;
    283   }
    284 
    285   void TestHandshake(bool expect_success) {
    286     int rv;
    287 
    288     // The initial state is CS_CLOSED
    289     ASSERT_EQ(rtc::AsyncSocket::CS_CLOSED, client_->GetState());
    290 
    291     rv = client_->Connect(server_->GetHostname(), server_->GetAddress());
    292     ASSERT_EQ(0, rv);
    293 
    294     // Now the state should be CS_CONNECTING
    295     ASSERT_EQ(rtc::AsyncSocket::CS_CONNECTING, client_->GetState());
    296 
    297     if (ssl_mode_ == rtc::SSL_MODE_DTLS) {
    298       // For DTLS, call AcceptConnection() with the client's address.
    299       server_->AcceptConnection(client_->GetAddress());
    300     }
    301 
    302     if (expect_success) {
    303       // If expecting success, the client should end up in the CS_CONNECTED
    304       // state after handshake.
    305       EXPECT_EQ_WAIT(rtc::AsyncSocket::CS_CONNECTED, client_->GetState(),
    306           handshake_wait_);
    307 
    308       LOG(LS_INFO) << GetSSLProtocolName(ssl_mode_) << " handshake complete.";
    309 
    310     } else {
    311       // On handshake failure the client should end up in the CS_CLOSED state.
    312       EXPECT_EQ_WAIT(rtc::AsyncSocket::CS_CLOSED, client_->GetState(),
    313           handshake_wait_);
    314 
    315       LOG(LS_INFO) << GetSSLProtocolName(ssl_mode_) << " handshake failed.";
    316     }
    317   }
    318 
    319   void TestTransfer(const std::string& message) {
    320     int rv;
    321 
    322     rv = client_->Send(message);
    323     ASSERT_EQ(static_cast<int>(message.length()), rv);
    324 
    325     // The server should have received the client's message.
    326     EXPECT_EQ_WAIT(message, server_->GetReceivedData(), kTimeout);
    327 
    328     rv = server_->Send(message);
    329     ASSERT_EQ(static_cast<int>(message.length()), rv);
    330 
    331     // The client should have received the server's message.
    332     EXPECT_EQ_WAIT(message, client_->GetReceivedData(), kTimeout);
    333 
    334     LOG(LS_INFO) << "Transfer complete.";
    335   }
    336 
    337  private:
    338   const rtc::SSLMode ssl_mode_;
    339 
    340   const rtc::SocketServerScope ss_scope_;
    341 
    342   rtc::scoped_ptr<SSLAdapterTestDummyServer> server_;
    343   rtc::scoped_ptr<SSLAdapterTestDummyClient> client_;
    344 
    345   int handshake_wait_;
    346 };
    347 
    348 class SSLAdapterTestTLS_RSA : public SSLAdapterTestBase {
    349  public:
    350   SSLAdapterTestTLS_RSA()
    351       : SSLAdapterTestBase(rtc::SSL_MODE_TLS, rtc::KeyParams::RSA()) {}
    352 };
    353 
    354 class SSLAdapterTestTLS_ECDSA : public SSLAdapterTestBase {
    355  public:
    356   SSLAdapterTestTLS_ECDSA()
    357       : SSLAdapterTestBase(rtc::SSL_MODE_TLS, rtc::KeyParams::ECDSA()) {}
    358 };
    359 
    360 class SSLAdapterTestDTLS_RSA : public SSLAdapterTestBase {
    361  public:
    362   SSLAdapterTestDTLS_RSA()
    363       : SSLAdapterTestBase(rtc::SSL_MODE_DTLS, rtc::KeyParams::RSA()) {}
    364 };
    365 
    366 class SSLAdapterTestDTLS_ECDSA : public SSLAdapterTestBase {
    367  public:
    368   SSLAdapterTestDTLS_ECDSA()
    369       : SSLAdapterTestBase(rtc::SSL_MODE_DTLS, rtc::KeyParams::ECDSA()) {}
    370 };
    371 
    372 #if SSL_USE_OPENSSL
    373 
    374 // Basic tests: TLS
    375 
    376 // Test that handshake works, using RSA
    377 TEST_F(SSLAdapterTestTLS_RSA, TestTLSConnect) {
    378   TestHandshake(true);
    379 }
    380 
    381 // Test that handshake works, using ECDSA
    382 TEST_F(SSLAdapterTestTLS_ECDSA, TestTLSConnect) {
    383   TestHandshake(true);
    384 }
    385 
    386 // Test transfer between client and server, using RSA
    387 TEST_F(SSLAdapterTestTLS_RSA, TestTLSTransfer) {
    388   TestHandshake(true);
    389   TestTransfer("Hello, world!");
    390 }
    391 
    392 // Test transfer between client and server, using ECDSA
    393 TEST_F(SSLAdapterTestTLS_ECDSA, TestTLSTransfer) {
    394   TestHandshake(true);
    395   TestTransfer("Hello, world!");
    396 }
    397 
    398 // Basic tests: DTLS
    399 
    400 // Test that handshake works, using RSA
    401 TEST_F(SSLAdapterTestDTLS_RSA, TestDTLSConnect) {
    402   TestHandshake(true);
    403 }
    404 
    405 // Test that handshake works, using ECDSA
    406 TEST_F(SSLAdapterTestDTLS_ECDSA, TestDTLSConnect) {
    407   TestHandshake(true);
    408 }
    409 
    410 // Test transfer between client and server, using RSA
    411 TEST_F(SSLAdapterTestDTLS_RSA, TestDTLSTransfer) {
    412   TestHandshake(true);
    413   TestTransfer("Hello, world!");
    414 }
    415 
    416 // Test transfer between client and server, using ECDSA
    417 TEST_F(SSLAdapterTestDTLS_ECDSA, TestDTLSTransfer) {
    418   TestHandshake(true);
    419   TestTransfer("Hello, world!");
    420 }
    421 
    422 #endif  // SSL_USE_OPENSSL
    423