1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "remoting/protocol/channel_multiplexer.h" 6 7 #include "base/bind.h" 8 #include "base/message_loop/message_loop.h" 9 #include "net/base/net_errors.h" 10 #include "net/socket/socket.h" 11 #include "net/socket/stream_socket.h" 12 #include "remoting/base/constants.h" 13 #include "remoting/protocol/connection_tester.h" 14 #include "remoting/protocol/fake_session.h" 15 #include "testing/gmock/include/gmock/gmock.h" 16 #include "testing/gtest/include/gtest/gtest.h" 17 18 using testing::_; 19 using testing::AtMost; 20 using testing::InvokeWithoutArgs; 21 22 namespace remoting { 23 namespace protocol { 24 25 namespace { 26 27 const int kMessageSize = 1024; 28 const int kMessages = 100; 29 const char kMuxChannelName[] = "mux"; 30 31 const char kTestChannelName[] = "test"; 32 const char kTestChannelName2[] = "test2"; 33 34 35 void QuitCurrentThread() { 36 base::MessageLoop::current()->PostTask(FROM_HERE, 37 base::MessageLoop::QuitClosure()); 38 } 39 40 class MockSocketCallback { 41 public: 42 MOCK_METHOD1(OnDone, void(int result)); 43 }; 44 45 class MockConnectCallback { 46 public: 47 MOCK_METHOD1(OnConnectedPtr, void(net::StreamSocket* socket)); 48 void OnConnected(scoped_ptr<net::StreamSocket> socket) { 49 OnConnectedPtr(socket.release()); 50 } 51 }; 52 53 } // namespace 54 55 class ChannelMultiplexerTest : public testing::Test { 56 public: 57 void DeleteAll() { 58 host_socket1_.reset(); 59 host_socket2_.reset(); 60 client_socket1_.reset(); 61 client_socket2_.reset(); 62 host_mux_.reset(); 63 client_mux_.reset(); 64 } 65 66 void DeleteAfterSessionFail() { 67 host_mux_->CancelChannelCreation(kTestChannelName2); 68 DeleteAll(); 69 } 70 71 protected: 72 virtual void SetUp() OVERRIDE { 73 // Create pair of multiplexers and connect them to each other. 74 host_mux_.reset(new ChannelMultiplexer(&host_session_, kMuxChannelName)); 75 client_mux_.reset(new ChannelMultiplexer(&client_session_, 76 kMuxChannelName)); 77 } 78 79 // Connect sockets to each other. Must be called after we've created at least 80 // one channel with each multiplexer. 81 void ConnectSockets() { 82 FakeSocket* host_socket = 83 host_session_.GetStreamChannel(ChannelMultiplexer::kMuxChannelName); 84 FakeSocket* client_socket = 85 client_session_.GetStreamChannel(ChannelMultiplexer::kMuxChannelName); 86 host_socket->PairWith(client_socket); 87 88 // Make writes asynchronous in one direction. 89 host_socket->set_async_write(true); 90 } 91 92 void CreateChannel(const std::string& name, 93 scoped_ptr<net::StreamSocket>* host_socket, 94 scoped_ptr<net::StreamSocket>* client_socket) { 95 int counter = 2; 96 host_mux_->CreateStreamChannel(name, base::Bind( 97 &ChannelMultiplexerTest::OnChannelConnected, base::Unretained(this), 98 host_socket, &counter)); 99 client_mux_->CreateStreamChannel(name, base::Bind( 100 &ChannelMultiplexerTest::OnChannelConnected, base::Unretained(this), 101 client_socket, &counter)); 102 103 message_loop_.Run(); 104 105 EXPECT_TRUE(host_socket->get()); 106 EXPECT_TRUE(client_socket->get()); 107 } 108 109 void OnChannelConnected( 110 scoped_ptr<net::StreamSocket>* storage, 111 int* counter, 112 scoped_ptr<net::StreamSocket> socket) { 113 *storage = socket.Pass(); 114 --(*counter); 115 EXPECT_GE(*counter, 0); 116 if (*counter == 0) 117 QuitCurrentThread(); 118 } 119 120 scoped_refptr<net::IOBufferWithSize> CreateTestBuffer(int size) { 121 scoped_refptr<net::IOBufferWithSize> result = 122 new net::IOBufferWithSize(size); 123 for (int i = 0; i< size; ++i) { 124 result->data()[i] = rand() % 256; 125 } 126 return result; 127 } 128 129 base::MessageLoop message_loop_; 130 131 FakeSession host_session_; 132 FakeSession client_session_; 133 134 scoped_ptr<ChannelMultiplexer> host_mux_; 135 scoped_ptr<ChannelMultiplexer> client_mux_; 136 137 scoped_ptr<net::StreamSocket> host_socket1_; 138 scoped_ptr<net::StreamSocket> client_socket1_; 139 scoped_ptr<net::StreamSocket> host_socket2_; 140 scoped_ptr<net::StreamSocket> client_socket2_; 141 }; 142 143 144 TEST_F(ChannelMultiplexerTest, OneChannel) { 145 scoped_ptr<net::StreamSocket> host_socket; 146 scoped_ptr<net::StreamSocket> client_socket; 147 ASSERT_NO_FATAL_FAILURE( 148 CreateChannel(kTestChannelName, &host_socket, &client_socket)); 149 150 ConnectSockets(); 151 152 StreamConnectionTester tester(host_socket.get(), client_socket.get(), 153 kMessageSize, kMessages); 154 tester.Start(); 155 message_loop_.Run(); 156 tester.CheckResults(); 157 } 158 159 TEST_F(ChannelMultiplexerTest, TwoChannels) { 160 scoped_ptr<net::StreamSocket> host_socket1_; 161 scoped_ptr<net::StreamSocket> client_socket1_; 162 ASSERT_NO_FATAL_FAILURE( 163 CreateChannel(kTestChannelName, &host_socket1_, &client_socket1_)); 164 165 scoped_ptr<net::StreamSocket> host_socket2_; 166 scoped_ptr<net::StreamSocket> client_socket2_; 167 ASSERT_NO_FATAL_FAILURE( 168 CreateChannel(kTestChannelName2, &host_socket2_, &client_socket2_)); 169 170 ConnectSockets(); 171 172 StreamConnectionTester tester1(host_socket1_.get(), client_socket1_.get(), 173 kMessageSize, kMessages); 174 StreamConnectionTester tester2(host_socket2_.get(), client_socket2_.get(), 175 kMessageSize, kMessages); 176 tester1.Start(); 177 tester2.Start(); 178 while (!tester1.done() || !tester2.done()) { 179 message_loop_.Run(); 180 } 181 tester1.CheckResults(); 182 tester2.CheckResults(); 183 } 184 185 // Four channels, two in each direction 186 TEST_F(ChannelMultiplexerTest, FourChannels) { 187 scoped_ptr<net::StreamSocket> host_socket1_; 188 scoped_ptr<net::StreamSocket> client_socket1_; 189 ASSERT_NO_FATAL_FAILURE( 190 CreateChannel(kTestChannelName, &host_socket1_, &client_socket1_)); 191 192 scoped_ptr<net::StreamSocket> host_socket2_; 193 scoped_ptr<net::StreamSocket> client_socket2_; 194 ASSERT_NO_FATAL_FAILURE( 195 CreateChannel(kTestChannelName2, &host_socket2_, &client_socket2_)); 196 197 scoped_ptr<net::StreamSocket> host_socket3; 198 scoped_ptr<net::StreamSocket> client_socket3; 199 ASSERT_NO_FATAL_FAILURE( 200 CreateChannel("test3", &host_socket3, &client_socket3)); 201 202 scoped_ptr<net::StreamSocket> host_socket4; 203 scoped_ptr<net::StreamSocket> client_socket4; 204 ASSERT_NO_FATAL_FAILURE( 205 CreateChannel("ch4", &host_socket4, &client_socket4)); 206 207 ConnectSockets(); 208 209 StreamConnectionTester tester1(host_socket1_.get(), client_socket1_.get(), 210 kMessageSize, kMessages); 211 StreamConnectionTester tester2(host_socket2_.get(), client_socket2_.get(), 212 kMessageSize, kMessages); 213 StreamConnectionTester tester3(client_socket3.get(), host_socket3.get(), 214 kMessageSize, kMessages); 215 StreamConnectionTester tester4(client_socket4.get(), host_socket4.get(), 216 kMessageSize, kMessages); 217 tester1.Start(); 218 tester2.Start(); 219 tester3.Start(); 220 tester4.Start(); 221 while (!tester1.done() || !tester2.done() || 222 !tester3.done() || !tester4.done()) { 223 message_loop_.Run(); 224 } 225 tester1.CheckResults(); 226 tester2.CheckResults(); 227 tester3.CheckResults(); 228 tester4.CheckResults(); 229 } 230 231 TEST_F(ChannelMultiplexerTest, WriteFailSync) { 232 scoped_ptr<net::StreamSocket> host_socket1_; 233 scoped_ptr<net::StreamSocket> client_socket1_; 234 ASSERT_NO_FATAL_FAILURE( 235 CreateChannel(kTestChannelName, &host_socket1_, &client_socket1_)); 236 237 scoped_ptr<net::StreamSocket> host_socket2_; 238 scoped_ptr<net::StreamSocket> client_socket2_; 239 ASSERT_NO_FATAL_FAILURE( 240 CreateChannel(kTestChannelName2, &host_socket2_, &client_socket2_)); 241 242 ConnectSockets(); 243 244 host_session_.GetStreamChannel(kMuxChannelName)-> 245 set_next_write_error(net::ERR_FAILED); 246 host_session_.GetStreamChannel(kMuxChannelName)-> 247 set_async_write(false); 248 249 scoped_refptr<net::IOBufferWithSize> buf = CreateTestBuffer(100); 250 251 MockSocketCallback cb1; 252 MockSocketCallback cb2; 253 254 EXPECT_CALL(cb1, OnDone(_)) 255 .Times(0); 256 EXPECT_CALL(cb2, OnDone(_)) 257 .Times(0); 258 259 EXPECT_EQ(net::ERR_FAILED, 260 host_socket1_->Write(buf.get(), 261 buf->size(), 262 base::Bind(&MockSocketCallback::OnDone, 263 base::Unretained(&cb1)))); 264 EXPECT_EQ(net::ERR_FAILED, 265 host_socket2_->Write(buf.get(), 266 buf->size(), 267 base::Bind(&MockSocketCallback::OnDone, 268 base::Unretained(&cb2)))); 269 270 message_loop_.RunUntilIdle(); 271 } 272 273 TEST_F(ChannelMultiplexerTest, WriteFailAsync) { 274 ASSERT_NO_FATAL_FAILURE( 275 CreateChannel(kTestChannelName, &host_socket1_, &client_socket1_)); 276 277 ASSERT_NO_FATAL_FAILURE( 278 CreateChannel(kTestChannelName2, &host_socket2_, &client_socket2_)); 279 280 ConnectSockets(); 281 282 host_session_.GetStreamChannel(kMuxChannelName)-> 283 set_next_write_error(net::ERR_FAILED); 284 host_session_.GetStreamChannel(kMuxChannelName)-> 285 set_async_write(true); 286 287 scoped_refptr<net::IOBufferWithSize> buf = CreateTestBuffer(100); 288 289 MockSocketCallback cb1; 290 MockSocketCallback cb2; 291 EXPECT_CALL(cb1, OnDone(net::ERR_FAILED)); 292 EXPECT_CALL(cb2, OnDone(net::ERR_FAILED)); 293 294 EXPECT_EQ(net::ERR_IO_PENDING, 295 host_socket1_->Write(buf.get(), 296 buf->size(), 297 base::Bind(&MockSocketCallback::OnDone, 298 base::Unretained(&cb1)))); 299 EXPECT_EQ(net::ERR_IO_PENDING, 300 host_socket2_->Write(buf.get(), 301 buf->size(), 302 base::Bind(&MockSocketCallback::OnDone, 303 base::Unretained(&cb2)))); 304 305 message_loop_.RunUntilIdle(); 306 } 307 308 TEST_F(ChannelMultiplexerTest, DeleteWhenFailed) { 309 ASSERT_NO_FATAL_FAILURE( 310 CreateChannel(kTestChannelName, &host_socket1_, &client_socket1_)); 311 ASSERT_NO_FATAL_FAILURE( 312 CreateChannel(kTestChannelName2, &host_socket2_, &client_socket2_)); 313 314 ConnectSockets(); 315 316 host_session_.GetStreamChannel(kMuxChannelName)-> 317 set_next_write_error(net::ERR_FAILED); 318 host_session_.GetStreamChannel(kMuxChannelName)-> 319 set_async_write(true); 320 321 scoped_refptr<net::IOBufferWithSize> buf = CreateTestBuffer(100); 322 323 MockSocketCallback cb1; 324 MockSocketCallback cb2; 325 326 EXPECT_CALL(cb1, OnDone(net::ERR_FAILED)) 327 .Times(AtMost(1)) 328 .WillOnce(InvokeWithoutArgs(this, &ChannelMultiplexerTest::DeleteAll)); 329 EXPECT_CALL(cb2, OnDone(net::ERR_FAILED)) 330 .Times(AtMost(1)) 331 .WillOnce(InvokeWithoutArgs(this, &ChannelMultiplexerTest::DeleteAll)); 332 333 EXPECT_EQ(net::ERR_IO_PENDING, 334 host_socket1_->Write(buf.get(), 335 buf->size(), 336 base::Bind(&MockSocketCallback::OnDone, 337 base::Unretained(&cb1)))); 338 EXPECT_EQ(net::ERR_IO_PENDING, 339 host_socket2_->Write(buf.get(), 340 buf->size(), 341 base::Bind(&MockSocketCallback::OnDone, 342 base::Unretained(&cb2)))); 343 344 message_loop_.RunUntilIdle(); 345 346 // Check that the sockets were destroyed. 347 EXPECT_FALSE(host_mux_.get()); 348 } 349 350 TEST_F(ChannelMultiplexerTest, SessionFail) { 351 host_session_.set_async_creation(true); 352 host_session_.set_error(AUTHENTICATION_FAILED); 353 354 MockConnectCallback cb1; 355 MockConnectCallback cb2; 356 357 host_mux_->CreateStreamChannel(kTestChannelName, base::Bind( 358 &MockConnectCallback::OnConnected, base::Unretained(&cb1))); 359 host_mux_->CreateStreamChannel(kTestChannelName2, base::Bind( 360 &MockConnectCallback::OnConnected, base::Unretained(&cb2))); 361 362 EXPECT_CALL(cb1, OnConnectedPtr(NULL)) 363 .Times(AtMost(1)) 364 .WillOnce(InvokeWithoutArgs( 365 this, &ChannelMultiplexerTest::DeleteAfterSessionFail)); 366 EXPECT_CALL(cb2, OnConnectedPtr(_)) 367 .Times(0); 368 369 message_loop_.RunUntilIdle(); 370 } 371 372 } // namespace protocol 373 } // namespace remoting 374