Home | History | Annotate | Download | only in protocol
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "remoting/protocol/channel_multiplexer.h"
      6 
      7 #include "base/bind.h"
      8 #include "base/message_loop/message_loop.h"
      9 #include "net/base/net_errors.h"
     10 #include "net/socket/socket.h"
     11 #include "net/socket/stream_socket.h"
     12 #include "remoting/base/constants.h"
     13 #include "remoting/protocol/connection_tester.h"
     14 #include "remoting/protocol/fake_session.h"
     15 #include "testing/gmock/include/gmock/gmock.h"
     16 #include "testing/gtest/include/gtest/gtest.h"
     17 
     18 using testing::_;
     19 using testing::AtMost;
     20 using testing::InvokeWithoutArgs;
     21 
     22 namespace remoting {
     23 namespace protocol {
     24 
     25 namespace {
     26 
     27 const int kMessageSize = 1024;
     28 const int kMessages = 100;
     29 const char kMuxChannelName[] = "mux";
     30 
     31 const char kTestChannelName[] = "test";
     32 const char kTestChannelName2[] = "test2";
     33 
     34 
     35 void QuitCurrentThread() {
     36   base::MessageLoop::current()->PostTask(FROM_HERE,
     37                                          base::MessageLoop::QuitClosure());
     38 }
     39 
     40 class MockSocketCallback {
     41  public:
     42   MOCK_METHOD1(OnDone, void(int result));
     43 };
     44 
     45 class MockConnectCallback {
     46  public:
     47   MOCK_METHOD1(OnConnectedPtr, void(net::StreamSocket* socket));
     48   void OnConnected(scoped_ptr<net::StreamSocket> socket) {
     49     OnConnectedPtr(socket.release());
     50   }
     51 };
     52 
     53 }  // namespace
     54 
     55 class ChannelMultiplexerTest : public testing::Test {
     56  public:
     57   void DeleteAll() {
     58     host_socket1_.reset();
     59     host_socket2_.reset();
     60     client_socket1_.reset();
     61     client_socket2_.reset();
     62     host_mux_.reset();
     63     client_mux_.reset();
     64   }
     65 
     66   void DeleteAfterSessionFail() {
     67     host_mux_->CancelChannelCreation(kTestChannelName2);
     68     DeleteAll();
     69   }
     70 
     71  protected:
     72   virtual void SetUp() OVERRIDE {
     73     // Create pair of multiplexers and connect them to each other.
     74     host_mux_.reset(new ChannelMultiplexer(&host_session_, kMuxChannelName));
     75     client_mux_.reset(new ChannelMultiplexer(&client_session_,
     76                                              kMuxChannelName));
     77   }
     78 
     79   // Connect sockets to each other. Must be called after we've created at least
     80   // one channel with each multiplexer.
     81   void ConnectSockets() {
     82     FakeSocket* host_socket =
     83         host_session_.GetStreamChannel(ChannelMultiplexer::kMuxChannelName);
     84     FakeSocket* client_socket =
     85         client_session_.GetStreamChannel(ChannelMultiplexer::kMuxChannelName);
     86     host_socket->PairWith(client_socket);
     87 
     88     // Make writes asynchronous in one direction.
     89     host_socket->set_async_write(true);
     90   }
     91 
     92   void CreateChannel(const std::string& name,
     93                      scoped_ptr<net::StreamSocket>* host_socket,
     94                      scoped_ptr<net::StreamSocket>* client_socket) {
     95     int counter = 2;
     96     host_mux_->CreateStreamChannel(name, base::Bind(
     97         &ChannelMultiplexerTest::OnChannelConnected, base::Unretained(this),
     98         host_socket, &counter));
     99     client_mux_->CreateStreamChannel(name, base::Bind(
    100         &ChannelMultiplexerTest::OnChannelConnected, base::Unretained(this),
    101         client_socket, &counter));
    102 
    103     message_loop_.Run();
    104 
    105     EXPECT_TRUE(host_socket->get());
    106     EXPECT_TRUE(client_socket->get());
    107   }
    108 
    109   void OnChannelConnected(
    110       scoped_ptr<net::StreamSocket>* storage,
    111       int* counter,
    112       scoped_ptr<net::StreamSocket> socket) {
    113     *storage = socket.Pass();
    114     --(*counter);
    115     EXPECT_GE(*counter, 0);
    116     if (*counter == 0)
    117       QuitCurrentThread();
    118   }
    119 
    120   scoped_refptr<net::IOBufferWithSize> CreateTestBuffer(int size) {
    121     scoped_refptr<net::IOBufferWithSize> result =
    122         new net::IOBufferWithSize(size);
    123     for (int i = 0; i< size; ++i) {
    124       result->data()[i] = rand() % 256;
    125     }
    126     return result;
    127   }
    128 
    129   base::MessageLoop message_loop_;
    130 
    131   FakeSession host_session_;
    132   FakeSession client_session_;
    133 
    134   scoped_ptr<ChannelMultiplexer> host_mux_;
    135   scoped_ptr<ChannelMultiplexer> client_mux_;
    136 
    137   scoped_ptr<net::StreamSocket> host_socket1_;
    138   scoped_ptr<net::StreamSocket> client_socket1_;
    139   scoped_ptr<net::StreamSocket> host_socket2_;
    140   scoped_ptr<net::StreamSocket> client_socket2_;
    141 };
    142 
    143 
    144 TEST_F(ChannelMultiplexerTest, OneChannel) {
    145   scoped_ptr<net::StreamSocket> host_socket;
    146   scoped_ptr<net::StreamSocket> client_socket;
    147   ASSERT_NO_FATAL_FAILURE(
    148       CreateChannel(kTestChannelName, &host_socket, &client_socket));
    149 
    150   ConnectSockets();
    151 
    152   StreamConnectionTester tester(host_socket.get(), client_socket.get(),
    153                                 kMessageSize, kMessages);
    154   tester.Start();
    155   message_loop_.Run();
    156   tester.CheckResults();
    157 }
    158 
    159 TEST_F(ChannelMultiplexerTest, TwoChannels) {
    160   scoped_ptr<net::StreamSocket> host_socket1_;
    161   scoped_ptr<net::StreamSocket> client_socket1_;
    162   ASSERT_NO_FATAL_FAILURE(
    163       CreateChannel(kTestChannelName, &host_socket1_, &client_socket1_));
    164 
    165   scoped_ptr<net::StreamSocket> host_socket2_;
    166   scoped_ptr<net::StreamSocket> client_socket2_;
    167   ASSERT_NO_FATAL_FAILURE(
    168       CreateChannel(kTestChannelName2, &host_socket2_, &client_socket2_));
    169 
    170   ConnectSockets();
    171 
    172   StreamConnectionTester tester1(host_socket1_.get(), client_socket1_.get(),
    173                                 kMessageSize, kMessages);
    174   StreamConnectionTester tester2(host_socket2_.get(), client_socket2_.get(),
    175                                  kMessageSize, kMessages);
    176   tester1.Start();
    177   tester2.Start();
    178   while (!tester1.done() || !tester2.done()) {
    179     message_loop_.Run();
    180   }
    181   tester1.CheckResults();
    182   tester2.CheckResults();
    183 }
    184 
    185 // Four channels, two in each direction
    186 TEST_F(ChannelMultiplexerTest, FourChannels) {
    187   scoped_ptr<net::StreamSocket> host_socket1_;
    188   scoped_ptr<net::StreamSocket> client_socket1_;
    189   ASSERT_NO_FATAL_FAILURE(
    190       CreateChannel(kTestChannelName, &host_socket1_, &client_socket1_));
    191 
    192   scoped_ptr<net::StreamSocket> host_socket2_;
    193   scoped_ptr<net::StreamSocket> client_socket2_;
    194   ASSERT_NO_FATAL_FAILURE(
    195       CreateChannel(kTestChannelName2, &host_socket2_, &client_socket2_));
    196 
    197   scoped_ptr<net::StreamSocket> host_socket3;
    198   scoped_ptr<net::StreamSocket> client_socket3;
    199   ASSERT_NO_FATAL_FAILURE(
    200       CreateChannel("test3", &host_socket3, &client_socket3));
    201 
    202   scoped_ptr<net::StreamSocket> host_socket4;
    203   scoped_ptr<net::StreamSocket> client_socket4;
    204   ASSERT_NO_FATAL_FAILURE(
    205       CreateChannel("ch4", &host_socket4, &client_socket4));
    206 
    207   ConnectSockets();
    208 
    209   StreamConnectionTester tester1(host_socket1_.get(), client_socket1_.get(),
    210                                 kMessageSize, kMessages);
    211   StreamConnectionTester tester2(host_socket2_.get(), client_socket2_.get(),
    212                                  kMessageSize, kMessages);
    213   StreamConnectionTester tester3(client_socket3.get(), host_socket3.get(),
    214                                  kMessageSize, kMessages);
    215   StreamConnectionTester tester4(client_socket4.get(), host_socket4.get(),
    216                                  kMessageSize, kMessages);
    217   tester1.Start();
    218   tester2.Start();
    219   tester3.Start();
    220   tester4.Start();
    221   while (!tester1.done() || !tester2.done() ||
    222          !tester3.done() || !tester4.done()) {
    223     message_loop_.Run();
    224   }
    225   tester1.CheckResults();
    226   tester2.CheckResults();
    227   tester3.CheckResults();
    228   tester4.CheckResults();
    229 }
    230 
    231 TEST_F(ChannelMultiplexerTest, WriteFailSync) {
    232   scoped_ptr<net::StreamSocket> host_socket1_;
    233   scoped_ptr<net::StreamSocket> client_socket1_;
    234   ASSERT_NO_FATAL_FAILURE(
    235       CreateChannel(kTestChannelName, &host_socket1_, &client_socket1_));
    236 
    237   scoped_ptr<net::StreamSocket> host_socket2_;
    238   scoped_ptr<net::StreamSocket> client_socket2_;
    239   ASSERT_NO_FATAL_FAILURE(
    240       CreateChannel(kTestChannelName2, &host_socket2_, &client_socket2_));
    241 
    242   ConnectSockets();
    243 
    244   host_session_.GetStreamChannel(kMuxChannelName)->
    245       set_next_write_error(net::ERR_FAILED);
    246   host_session_.GetStreamChannel(kMuxChannelName)->
    247       set_async_write(false);
    248 
    249   scoped_refptr<net::IOBufferWithSize> buf = CreateTestBuffer(100);
    250 
    251   MockSocketCallback cb1;
    252   MockSocketCallback cb2;
    253 
    254   EXPECT_CALL(cb1, OnDone(_))
    255       .Times(0);
    256   EXPECT_CALL(cb2, OnDone(_))
    257       .Times(0);
    258 
    259   EXPECT_EQ(net::ERR_FAILED,
    260             host_socket1_->Write(buf.get(),
    261                                  buf->size(),
    262                                  base::Bind(&MockSocketCallback::OnDone,
    263                                             base::Unretained(&cb1))));
    264   EXPECT_EQ(net::ERR_FAILED,
    265             host_socket2_->Write(buf.get(),
    266                                  buf->size(),
    267                                  base::Bind(&MockSocketCallback::OnDone,
    268                                             base::Unretained(&cb2))));
    269 
    270   message_loop_.RunUntilIdle();
    271 }
    272 
    273 TEST_F(ChannelMultiplexerTest, WriteFailAsync) {
    274   ASSERT_NO_FATAL_FAILURE(
    275       CreateChannel(kTestChannelName, &host_socket1_, &client_socket1_));
    276 
    277   ASSERT_NO_FATAL_FAILURE(
    278       CreateChannel(kTestChannelName2, &host_socket2_, &client_socket2_));
    279 
    280   ConnectSockets();
    281 
    282   host_session_.GetStreamChannel(kMuxChannelName)->
    283       set_next_write_error(net::ERR_FAILED);
    284   host_session_.GetStreamChannel(kMuxChannelName)->
    285       set_async_write(true);
    286 
    287   scoped_refptr<net::IOBufferWithSize> buf = CreateTestBuffer(100);
    288 
    289   MockSocketCallback cb1;
    290   MockSocketCallback cb2;
    291   EXPECT_CALL(cb1, OnDone(net::ERR_FAILED));
    292   EXPECT_CALL(cb2, OnDone(net::ERR_FAILED));
    293 
    294   EXPECT_EQ(net::ERR_IO_PENDING,
    295             host_socket1_->Write(buf.get(),
    296                                  buf->size(),
    297                                  base::Bind(&MockSocketCallback::OnDone,
    298                                             base::Unretained(&cb1))));
    299   EXPECT_EQ(net::ERR_IO_PENDING,
    300             host_socket2_->Write(buf.get(),
    301                                  buf->size(),
    302                                  base::Bind(&MockSocketCallback::OnDone,
    303                                             base::Unretained(&cb2))));
    304 
    305   message_loop_.RunUntilIdle();
    306 }
    307 
    308 TEST_F(ChannelMultiplexerTest, DeleteWhenFailed) {
    309   ASSERT_NO_FATAL_FAILURE(
    310       CreateChannel(kTestChannelName, &host_socket1_, &client_socket1_));
    311   ASSERT_NO_FATAL_FAILURE(
    312       CreateChannel(kTestChannelName2, &host_socket2_, &client_socket2_));
    313 
    314   ConnectSockets();
    315 
    316   host_session_.GetStreamChannel(kMuxChannelName)->
    317       set_next_write_error(net::ERR_FAILED);
    318   host_session_.GetStreamChannel(kMuxChannelName)->
    319       set_async_write(true);
    320 
    321   scoped_refptr<net::IOBufferWithSize> buf = CreateTestBuffer(100);
    322 
    323   MockSocketCallback cb1;
    324   MockSocketCallback cb2;
    325 
    326   EXPECT_CALL(cb1, OnDone(net::ERR_FAILED))
    327       .Times(AtMost(1))
    328       .WillOnce(InvokeWithoutArgs(this, &ChannelMultiplexerTest::DeleteAll));
    329   EXPECT_CALL(cb2, OnDone(net::ERR_FAILED))
    330       .Times(AtMost(1))
    331       .WillOnce(InvokeWithoutArgs(this, &ChannelMultiplexerTest::DeleteAll));
    332 
    333   EXPECT_EQ(net::ERR_IO_PENDING,
    334             host_socket1_->Write(buf.get(),
    335                                  buf->size(),
    336                                  base::Bind(&MockSocketCallback::OnDone,
    337                                             base::Unretained(&cb1))));
    338   EXPECT_EQ(net::ERR_IO_PENDING,
    339             host_socket2_->Write(buf.get(),
    340                                  buf->size(),
    341                                  base::Bind(&MockSocketCallback::OnDone,
    342                                             base::Unretained(&cb2))));
    343 
    344   message_loop_.RunUntilIdle();
    345 
    346   // Check that the sockets were destroyed.
    347   EXPECT_FALSE(host_mux_.get());
    348 }
    349 
    350 TEST_F(ChannelMultiplexerTest, SessionFail) {
    351   host_session_.set_async_creation(true);
    352   host_session_.set_error(AUTHENTICATION_FAILED);
    353 
    354   MockConnectCallback cb1;
    355   MockConnectCallback cb2;
    356 
    357   host_mux_->CreateStreamChannel(kTestChannelName, base::Bind(
    358       &MockConnectCallback::OnConnected, base::Unretained(&cb1)));
    359   host_mux_->CreateStreamChannel(kTestChannelName2, base::Bind(
    360       &MockConnectCallback::OnConnected, base::Unretained(&cb2)));
    361 
    362   EXPECT_CALL(cb1, OnConnectedPtr(NULL))
    363       .Times(AtMost(1))
    364       .WillOnce(InvokeWithoutArgs(
    365           this, &ChannelMultiplexerTest::DeleteAfterSessionFail));
    366   EXPECT_CALL(cb2, OnConnectedPtr(_))
    367       .Times(0);
    368 
    369   message_loop_.RunUntilIdle();
    370 }
    371 
    372 }  // namespace protocol
    373 }  // namespace remoting
    374