Home | History | Annotate | Download | only in protocol
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "remoting/protocol/authenticator_test_base.h"
      6 
      7 #include "base/base64.h"
      8 #include "base/file_util.h"
      9 #include "base/files/file_path.h"
     10 #include "base/path_service.h"
     11 #include "base/test/test_timeouts.h"
     12 #include "base/timer/timer.h"
     13 #include "net/base/test_data_directory.h"
     14 #include "remoting/base/rsa_key_pair.h"
     15 #include "remoting/protocol/authenticator.h"
     16 #include "remoting/protocol/channel_authenticator.h"
     17 #include "remoting/protocol/fake_session.h"
     18 #include "testing/gtest/include/gtest/gtest.h"
     19 #include "third_party/libjingle/source/talk/xmllite/xmlelement.h"
     20 
     21 using testing::_;
     22 using testing::SaveArg;
     23 
     24 namespace remoting {
     25 namespace protocol {
     26 
     27 namespace {
     28 
     29 ACTION_P(QuitThreadOnCounter, counter) {
     30   --(*counter);
     31   EXPECT_GE(*counter, 0);
     32   if (*counter == 0)
     33     base::MessageLoop::current()->Quit();
     34 }
     35 
     36 }  // namespace
     37 
     38 AuthenticatorTestBase::MockChannelDoneCallback::MockChannelDoneCallback() {}
     39 
     40 AuthenticatorTestBase::MockChannelDoneCallback::~MockChannelDoneCallback() {}
     41 
     42 AuthenticatorTestBase::AuthenticatorTestBase() {}
     43 
     44 AuthenticatorTestBase::~AuthenticatorTestBase() {}
     45 
     46 void AuthenticatorTestBase::SetUp() {
     47   base::FilePath certs_dir(net::GetTestCertsDirectory());
     48 
     49   base::FilePath cert_path = certs_dir.AppendASCII("unittest.selfsigned.der");
     50   ASSERT_TRUE(base::ReadFileToString(cert_path, &host_cert_));
     51 
     52   base::FilePath key_path = certs_dir.AppendASCII("unittest.key.bin");
     53   std::string key_string;
     54   ASSERT_TRUE(base::ReadFileToString(key_path, &key_string));
     55   std::string key_base64;
     56   base::Base64Encode(key_string, &key_base64);
     57   key_pair_ = RsaKeyPair::FromString(key_base64);
     58   ASSERT_TRUE(key_pair_.get());
     59   host_public_key_ = key_pair_->GetPublicKey();
     60 }
     61 
     62 void AuthenticatorTestBase::RunAuthExchange() {
     63   ContinueAuthExchangeWith(client_.get(), host_.get());
     64 }
     65 
     66 void AuthenticatorTestBase::RunHostInitiatedAuthExchange() {
     67   ContinueAuthExchangeWith(host_.get(), client_.get());
     68 }
     69 
     70 // static
     71 void AuthenticatorTestBase::ContinueAuthExchangeWith(Authenticator* sender,
     72                                                      Authenticator* receiver) {
     73   scoped_ptr<buzz::XmlElement> message;
     74   ASSERT_NE(Authenticator::WAITING_MESSAGE, sender->state());
     75   if (sender->state() == Authenticator::ACCEPTED ||
     76       sender->state() == Authenticator::REJECTED)
     77     return;
     78   // Pass message from client to host.
     79   ASSERT_EQ(Authenticator::MESSAGE_READY, sender->state());
     80   message = sender->GetNextMessage();
     81   ASSERT_TRUE(message.get());
     82   ASSERT_NE(Authenticator::MESSAGE_READY, sender->state());
     83 
     84   ASSERT_EQ(Authenticator::WAITING_MESSAGE, receiver->state());
     85   receiver->ProcessMessage(message.get(), base::Bind(
     86       &AuthenticatorTestBase::ContinueAuthExchangeWith,
     87       base::Unretained(receiver), base::Unretained(sender)));
     88 }
     89 
     90 void AuthenticatorTestBase::RunChannelAuth(bool expected_fail) {
     91   client_fake_socket_.reset(new FakeSocket());
     92   host_fake_socket_.reset(new FakeSocket());
     93   client_fake_socket_->PairWith(host_fake_socket_.get());
     94 
     95   client_auth_->SecureAndAuthenticate(
     96       client_fake_socket_.PassAs<net::StreamSocket>(),
     97       base::Bind(&AuthenticatorTestBase::OnClientConnected,
     98                  base::Unretained(this)));
     99 
    100   host_auth_->SecureAndAuthenticate(
    101       host_fake_socket_.PassAs<net::StreamSocket>(),
    102       base::Bind(&AuthenticatorTestBase::OnHostConnected,
    103                  base::Unretained(this)));
    104 
    105   // Expect two callbacks to be called - the client callback and the host
    106   // callback.
    107   int callback_counter = 2;
    108 
    109   EXPECT_CALL(client_callback_, OnDone(net::OK))
    110       .WillOnce(QuitThreadOnCounter(&callback_counter));
    111   if (expected_fail) {
    112     EXPECT_CALL(host_callback_, OnDone(net::ERR_FAILED))
    113          .WillOnce(QuitThreadOnCounter(&callback_counter));
    114   } else {
    115     EXPECT_CALL(host_callback_, OnDone(net::OK))
    116         .WillOnce(QuitThreadOnCounter(&callback_counter));
    117   }
    118 
    119   // Ensure that .Run() does not run unbounded if the callbacks are never
    120   // called.
    121   base::Timer shutdown_timer(false, false);
    122   shutdown_timer.Start(FROM_HERE,
    123                        TestTimeouts::action_timeout(),
    124                        base::MessageLoop::QuitClosure());
    125   message_loop_.Run();
    126   shutdown_timer.Stop();
    127 
    128   testing::Mock::VerifyAndClearExpectations(&client_callback_);
    129   testing::Mock::VerifyAndClearExpectations(&host_callback_);
    130 
    131   if (!expected_fail) {
    132     ASSERT_TRUE(client_socket_.get() != NULL);
    133     ASSERT_TRUE(host_socket_.get() != NULL);
    134   }
    135 }
    136 
    137 void AuthenticatorTestBase::OnHostConnected(
    138     net::Error error,
    139     scoped_ptr<net::StreamSocket> socket) {
    140   host_callback_.OnDone(error);
    141   host_socket_ = socket.Pass();
    142 }
    143 
    144 void AuthenticatorTestBase::OnClientConnected(
    145     net::Error error,
    146     scoped_ptr<net::StreamSocket> socket) {
    147   client_callback_.OnDone(error);
    148   client_socket_ = socket.Pass();
    149 }
    150 
    151 }  // namespace protocol
    152 }  // namespace remoting
    153