Home | History | Annotate | Download | only in protocol
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "remoting/protocol/authenticator_test_base.h"
      6 
      7 #include "base/base64.h"
      8 #include "base/files/file_path.h"
      9 #include "base/files/file_util.h"
     10 #include "base/path_service.h"
     11 #include "base/test/test_timeouts.h"
     12 #include "base/timer/timer.h"
     13 #include "net/base/net_errors.h"
     14 #include "net/base/test_data_directory.h"
     15 #include "remoting/base/rsa_key_pair.h"
     16 #include "remoting/protocol/authenticator.h"
     17 #include "remoting/protocol/channel_authenticator.h"
     18 #include "remoting/protocol/fake_stream_socket.h"
     19 #include "testing/gtest/include/gtest/gtest.h"
     20 #include "third_party/webrtc/libjingle/xmllite/xmlelement.h"
     21 
     22 using testing::_;
     23 using testing::SaveArg;
     24 
     25 namespace remoting {
     26 namespace protocol {
     27 
     28 namespace {
     29 
     30 ACTION_P(QuitThreadOnCounter, counter) {
     31   --(*counter);
     32   EXPECT_GE(*counter, 0);
     33   if (*counter == 0)
     34     base::MessageLoop::current()->Quit();
     35 }
     36 
     37 }  // namespace
     38 
     39 AuthenticatorTestBase::MockChannelDoneCallback::MockChannelDoneCallback() {}
     40 
     41 AuthenticatorTestBase::MockChannelDoneCallback::~MockChannelDoneCallback() {}
     42 
     43 AuthenticatorTestBase::AuthenticatorTestBase() {}
     44 
     45 AuthenticatorTestBase::~AuthenticatorTestBase() {}
     46 
     47 void AuthenticatorTestBase::SetUp() {
     48   base::FilePath certs_dir(net::GetTestCertsDirectory());
     49 
     50   base::FilePath cert_path = certs_dir.AppendASCII("unittest.selfsigned.der");
     51   ASSERT_TRUE(base::ReadFileToString(cert_path, &host_cert_));
     52 
     53   base::FilePath key_path = certs_dir.AppendASCII("unittest.key.bin");
     54   std::string key_string;
     55   ASSERT_TRUE(base::ReadFileToString(key_path, &key_string));
     56   std::string key_base64;
     57   base::Base64Encode(key_string, &key_base64);
     58   key_pair_ = RsaKeyPair::FromString(key_base64);
     59   ASSERT_TRUE(key_pair_.get());
     60   host_public_key_ = key_pair_->GetPublicKey();
     61 }
     62 
     63 void AuthenticatorTestBase::RunAuthExchange() {
     64   ContinueAuthExchangeWith(client_.get(),
     65                            host_.get(),
     66                            client_->started(),
     67                            host_->started());
     68 }
     69 
     70 void AuthenticatorTestBase::RunHostInitiatedAuthExchange() {
     71   ContinueAuthExchangeWith(host_.get(),
     72                            client_.get(),
     73                            host_->started(),
     74                            client_->started());
     75 }
     76 
     77 // static
     78 // This function sends a message from the sender and receiver and recursively
     79 // calls itself to the send the next message from the receiver to the sender
     80 // untils the authentication completes.
     81 void AuthenticatorTestBase::ContinueAuthExchangeWith(Authenticator* sender,
     82                                                      Authenticator* receiver,
     83                                                      bool sender_started,
     84                                                      bool receiver_started) {
     85   scoped_ptr<buzz::XmlElement> message;
     86   ASSERT_NE(Authenticator::WAITING_MESSAGE, sender->state());
     87   if (sender->state() == Authenticator::ACCEPTED ||
     88       sender->state() == Authenticator::REJECTED)
     89     return;
     90 
     91   // Verify that once the started flag for either party is set to true,
     92   // it should always stay true.
     93   if (receiver_started) {
     94     ASSERT_TRUE(receiver->started());
     95   }
     96 
     97   if (sender_started) {
     98     ASSERT_TRUE(sender->started());
     99   }
    100 
    101   ASSERT_EQ(Authenticator::MESSAGE_READY, sender->state());
    102   message = sender->GetNextMessage();
    103   ASSERT_TRUE(message.get());
    104   ASSERT_NE(Authenticator::MESSAGE_READY, sender->state());
    105 
    106   ASSERT_EQ(Authenticator::WAITING_MESSAGE, receiver->state());
    107   receiver->ProcessMessage(message.get(), base::Bind(
    108       &AuthenticatorTestBase::ContinueAuthExchangeWith,
    109       base::Unretained(receiver), base::Unretained(sender),
    110       receiver->started(), sender->started()));
    111 }
    112 
    113 void AuthenticatorTestBase::RunChannelAuth(bool expected_fail) {
    114   client_fake_socket_.reset(new FakeStreamSocket());
    115   host_fake_socket_.reset(new FakeStreamSocket());
    116   client_fake_socket_->PairWith(host_fake_socket_.get());
    117 
    118   client_auth_->SecureAndAuthenticate(
    119       client_fake_socket_.PassAs<net::StreamSocket>(),
    120       base::Bind(&AuthenticatorTestBase::OnClientConnected,
    121                  base::Unretained(this)));
    122 
    123   host_auth_->SecureAndAuthenticate(
    124       host_fake_socket_.PassAs<net::StreamSocket>(),
    125       base::Bind(&AuthenticatorTestBase::OnHostConnected,
    126                  base::Unretained(this)));
    127 
    128   // Expect two callbacks to be called - the client callback and the host
    129   // callback.
    130   int callback_counter = 2;
    131 
    132   EXPECT_CALL(client_callback_, OnDone(net::OK))
    133       .WillOnce(QuitThreadOnCounter(&callback_counter));
    134   if (expected_fail) {
    135     EXPECT_CALL(host_callback_, OnDone(net::ERR_FAILED))
    136          .WillOnce(QuitThreadOnCounter(&callback_counter));
    137   } else {
    138     EXPECT_CALL(host_callback_, OnDone(net::OK))
    139         .WillOnce(QuitThreadOnCounter(&callback_counter));
    140   }
    141 
    142   // Ensure that .Run() does not run unbounded if the callbacks are never
    143   // called.
    144   base::Timer shutdown_timer(false, false);
    145   shutdown_timer.Start(FROM_HERE,
    146                        TestTimeouts::action_timeout(),
    147                        base::MessageLoop::QuitClosure());
    148   message_loop_.Run();
    149   shutdown_timer.Stop();
    150 
    151   testing::Mock::VerifyAndClearExpectations(&client_callback_);
    152   testing::Mock::VerifyAndClearExpectations(&host_callback_);
    153 
    154   if (!expected_fail) {
    155     ASSERT_TRUE(client_socket_.get() != NULL);
    156     ASSERT_TRUE(host_socket_.get() != NULL);
    157   }
    158 }
    159 
    160 void AuthenticatorTestBase::OnHostConnected(
    161     int error,
    162     scoped_ptr<net::StreamSocket> socket) {
    163   host_callback_.OnDone(error);
    164   host_socket_ = socket.Pass();
    165 }
    166 
    167 void AuthenticatorTestBase::OnClientConnected(
    168     int error,
    169     scoped_ptr<net::StreamSocket> socket) {
    170   client_callback_.OnDone(error);
    171   client_socket_ = socket.Pass();
    172 }
    173 
    174 }  // namespace protocol
    175 }  // namespace remoting
    176