1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "extensions/browser/api/socket/tcp_socket.h" 6 7 #include "base/memory/scoped_ptr.h" 8 #include "net/base/address_list.h" 9 #include "net/base/completion_callback.h" 10 #include "net/base/io_buffer.h" 11 #include "net/base/net_errors.h" 12 #include "net/base/rand_callback.h" 13 #include "net/socket/tcp_client_socket.h" 14 #include "net/socket/tcp_server_socket.h" 15 #include "testing/gmock/include/gmock/gmock.h" 16 17 using testing::_; 18 using testing::DoAll; 19 using testing::Return; 20 using testing::SaveArg; 21 22 namespace extensions { 23 24 class MockTCPSocket : public net::TCPClientSocket { 25 public: 26 explicit MockTCPSocket(const net::AddressList& address_list) 27 : net::TCPClientSocket(address_list, NULL, net::NetLog::Source()) { 28 } 29 30 MOCK_METHOD3(Read, int(net::IOBuffer* buf, int buf_len, 31 const net::CompletionCallback& callback)); 32 MOCK_METHOD3(Write, int(net::IOBuffer* buf, int buf_len, 33 const net::CompletionCallback& callback)); 34 MOCK_METHOD2(SetKeepAlive, bool(bool enable, int delay)); 35 MOCK_METHOD1(SetNoDelay, bool(bool no_delay)); 36 virtual bool IsConnected() const OVERRIDE { 37 return true; 38 } 39 40 private: 41 DISALLOW_COPY_AND_ASSIGN(MockTCPSocket); 42 }; 43 44 class MockTCPServerSocket : public net::TCPServerSocket { 45 public: 46 explicit MockTCPServerSocket() 47 : net::TCPServerSocket(NULL, net::NetLog::Source()) { 48 } 49 MOCK_METHOD2(Listen, int(const net::IPEndPoint& address, int backlog)); 50 MOCK_METHOD2(Accept, int(scoped_ptr<net::StreamSocket>* socket, 51 const net::CompletionCallback& callback)); 52 53 private: 54 DISALLOW_COPY_AND_ASSIGN(MockTCPServerSocket); 55 }; 56 57 class CompleteHandler { 58 public: 59 CompleteHandler() {} 60 MOCK_METHOD1(OnComplete, void(int result_code)); 61 MOCK_METHOD2(OnReadComplete, void(int result_code, 62 scoped_refptr<net::IOBuffer> io_buffer)); 63 MOCK_METHOD2(OnAccept, void(int, net::TCPClientSocket*)); 64 private: 65 DISALLOW_COPY_AND_ASSIGN(CompleteHandler); 66 }; 67 68 const std::string FAKE_ID = "abcdefghijklmnopqrst"; 69 70 TEST(SocketTest, TestTCPSocketRead) { 71 net::AddressList address_list; 72 MockTCPSocket* tcp_client_socket = new MockTCPSocket(address_list); 73 CompleteHandler handler; 74 75 scoped_ptr<TCPSocket> socket(TCPSocket::CreateSocketForTesting( 76 tcp_client_socket, FAKE_ID, true)); 77 78 EXPECT_CALL(*tcp_client_socket, Read(_, _, _)) 79 .Times(1); 80 EXPECT_CALL(handler, OnReadComplete(_, _)) 81 .Times(1); 82 83 const int count = 512; 84 socket->Read(count, base::Bind(&CompleteHandler::OnReadComplete, 85 base::Unretained(&handler))); 86 } 87 88 TEST(SocketTest, TestTCPSocketWrite) { 89 net::AddressList address_list; 90 MockTCPSocket* tcp_client_socket = new MockTCPSocket(address_list); 91 CompleteHandler handler; 92 93 scoped_ptr<TCPSocket> socket(TCPSocket::CreateSocketForTesting( 94 tcp_client_socket, FAKE_ID, true)); 95 96 net::CompletionCallback callback; 97 EXPECT_CALL(*tcp_client_socket, Write(_, _, _)) 98 .Times(2) 99 .WillRepeatedly(testing::DoAll(SaveArg<2>(&callback), 100 Return(128))); 101 EXPECT_CALL(handler, OnComplete(_)) 102 .Times(1); 103 104 scoped_refptr<net::IOBufferWithSize> io_buffer( 105 new net::IOBufferWithSize(256)); 106 socket->Write(io_buffer.get(), io_buffer->size(), 107 base::Bind(&CompleteHandler::OnComplete, base::Unretained(&handler))); 108 } 109 110 TEST(SocketTest, TestTCPSocketBlockedWrite) { 111 net::AddressList address_list; 112 MockTCPSocket* tcp_client_socket = new MockTCPSocket(address_list); 113 CompleteHandler handler; 114 115 scoped_ptr<TCPSocket> socket(TCPSocket::CreateSocketForTesting( 116 tcp_client_socket, FAKE_ID, true)); 117 118 net::CompletionCallback callback; 119 EXPECT_CALL(*tcp_client_socket, Write(_, _, _)) 120 .Times(2) 121 .WillRepeatedly(testing::DoAll(SaveArg<2>(&callback), 122 Return(net::ERR_IO_PENDING))); 123 scoped_refptr<net::IOBufferWithSize> io_buffer(new net::IOBufferWithSize(42)); 124 socket->Write(io_buffer.get(), io_buffer->size(), 125 base::Bind(&CompleteHandler::OnComplete, base::Unretained(&handler))); 126 127 // Good. Original call came back unable to complete. Now pretend the socket 128 // finished, and confirm that we passed the error back. 129 EXPECT_CALL(handler, OnComplete(42)) 130 .Times(1); 131 callback.Run(40); 132 callback.Run(2); 133 } 134 135 TEST(SocketTest, TestTCPSocketBlockedWriteReentry) { 136 net::AddressList address_list; 137 MockTCPSocket* tcp_client_socket = new MockTCPSocket(address_list); 138 CompleteHandler handlers[5]; 139 140 scoped_ptr<TCPSocket> socket(TCPSocket::CreateSocketForTesting( 141 tcp_client_socket, FAKE_ID, true)); 142 143 net::CompletionCallback callback; 144 EXPECT_CALL(*tcp_client_socket, Write(_, _, _)) 145 .Times(5) 146 .WillRepeatedly(testing::DoAll(SaveArg<2>(&callback), 147 Return(net::ERR_IO_PENDING))); 148 scoped_refptr<net::IOBufferWithSize> io_buffers[5]; 149 int i; 150 for (i = 0; i < 5; i++) { 151 io_buffers[i] = new net::IOBufferWithSize(128 + i * 50); 152 scoped_refptr<net::IOBufferWithSize> io_buffer1( 153 new net::IOBufferWithSize(42)); 154 socket->Write(io_buffers[i].get(), io_buffers[i]->size(), 155 base::Bind(&CompleteHandler::OnComplete, 156 base::Unretained(&handlers[i]))); 157 158 EXPECT_CALL(handlers[i], OnComplete(io_buffers[i]->size())) 159 .Times(1); 160 } 161 162 for (i = 0; i < 5; i++) { 163 callback.Run(128 + i * 50); 164 } 165 } 166 167 TEST(SocketTest, TestTCPSocketSetNoDelay) { 168 net::AddressList address_list; 169 MockTCPSocket* tcp_client_socket = new MockTCPSocket(address_list); 170 171 scoped_ptr<TCPSocket> socket(TCPSocket::CreateSocketForTesting( 172 tcp_client_socket, FAKE_ID)); 173 174 bool no_delay = false; 175 EXPECT_CALL(*tcp_client_socket, SetNoDelay(_)) 176 .WillOnce(testing::DoAll(SaveArg<0>(&no_delay), Return(true))); 177 int result = socket->SetNoDelay(true); 178 EXPECT_TRUE(result); 179 EXPECT_TRUE(no_delay); 180 181 EXPECT_CALL(*tcp_client_socket, SetNoDelay(_)) 182 .WillOnce(testing::DoAll(SaveArg<0>(&no_delay), Return(false))); 183 184 result = socket->SetNoDelay(false); 185 EXPECT_FALSE(result); 186 EXPECT_FALSE(no_delay); 187 } 188 189 TEST(SocketTest, TestTCPSocketSetKeepAlive) { 190 net::AddressList address_list; 191 MockTCPSocket* tcp_client_socket = new MockTCPSocket(address_list); 192 193 scoped_ptr<TCPSocket> socket(TCPSocket::CreateSocketForTesting( 194 tcp_client_socket, FAKE_ID)); 195 196 bool enable = false; 197 int delay = 0; 198 EXPECT_CALL(*tcp_client_socket, SetKeepAlive(_, _)) 199 .WillOnce(testing::DoAll(SaveArg<0>(&enable), 200 SaveArg<1>(&delay), 201 Return(true))); 202 int result = socket->SetKeepAlive(true, 4500); 203 EXPECT_TRUE(result); 204 EXPECT_TRUE(enable); 205 EXPECT_EQ(4500, delay); 206 207 EXPECT_CALL(*tcp_client_socket, SetKeepAlive(_, _)) 208 .WillOnce(testing::DoAll(SaveArg<0>(&enable), 209 SaveArg<1>(&delay), 210 Return(false))); 211 result = socket->SetKeepAlive(false, 0); 212 EXPECT_FALSE(result); 213 EXPECT_FALSE(enable); 214 EXPECT_EQ(0, delay); 215 } 216 217 TEST(SocketTest, TestTCPServerSocketListenAccept) { 218 MockTCPServerSocket* tcp_server_socket = new MockTCPServerSocket(); 219 CompleteHandler handler; 220 221 scoped_ptr<TCPSocket> socket(TCPSocket::CreateServerSocketForTesting( 222 tcp_server_socket, FAKE_ID)); 223 224 EXPECT_CALL(*tcp_server_socket, Accept(_, _)).Times(1); 225 EXPECT_CALL(*tcp_server_socket, Listen(_, _)).Times(1); 226 EXPECT_CALL(handler, OnAccept(_, _)); 227 228 std::string err_msg; 229 EXPECT_EQ(net::OK, socket->Listen("127.0.0.1", 9999, 10, &err_msg)); 230 socket->Accept(base::Bind(&CompleteHandler::OnAccept, 231 base::Unretained(&handler))); 232 } 233 234 } // namespace extensions 235