1 // Copyright 2014 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "extensions/browser/api/socket/tcp_socket.h" 6 7 #include "base/lazy_instance.h" 8 #include "base/logging.h" 9 #include "base/macros.h" 10 #include "extensions/browser/api/api_resource.h" 11 #include "net/base/address_list.h" 12 #include "net/base/ip_endpoint.h" 13 #include "net/base/net_errors.h" 14 #include "net/base/rand_callback.h" 15 #include "net/socket/tcp_client_socket.h" 16 17 namespace extensions { 18 19 const char kTCPSocketTypeInvalidError[] = 20 "Cannot call both connect and listen on the same socket."; 21 const char kSocketListenError[] = "Could not listen on the specified port."; 22 23 static base::LazyInstance< 24 BrowserContextKeyedAPIFactory<ApiResourceManager<ResumableTCPSocket> > > 25 g_factory = LAZY_INSTANCE_INITIALIZER; 26 27 // static 28 template <> 29 BrowserContextKeyedAPIFactory<ApiResourceManager<ResumableTCPSocket> >* 30 ApiResourceManager<ResumableTCPSocket>::GetFactoryInstance() { 31 return g_factory.Pointer(); 32 } 33 34 static base::LazyInstance<BrowserContextKeyedAPIFactory< 35 ApiResourceManager<ResumableTCPServerSocket> > > g_server_factory = 36 LAZY_INSTANCE_INITIALIZER; 37 38 // static 39 template <> 40 BrowserContextKeyedAPIFactory<ApiResourceManager<ResumableTCPServerSocket> >* 41 ApiResourceManager<ResumableTCPServerSocket>::GetFactoryInstance() { 42 return g_server_factory.Pointer(); 43 } 44 45 TCPSocket::TCPSocket(const std::string& owner_extension_id) 46 : Socket(owner_extension_id), socket_mode_(UNKNOWN) {} 47 48 TCPSocket::TCPSocket(net::TCPClientSocket* tcp_client_socket, 49 const std::string& owner_extension_id, 50 bool is_connected) 51 : Socket(owner_extension_id), 52 socket_(tcp_client_socket), 53 socket_mode_(CLIENT) { 54 this->is_connected_ = is_connected; 55 } 56 57 TCPSocket::TCPSocket(net::TCPServerSocket* tcp_server_socket, 58 const std::string& owner_extension_id) 59 : Socket(owner_extension_id), 60 server_socket_(tcp_server_socket), 61 socket_mode_(SERVER) {} 62 63 // static 64 TCPSocket* TCPSocket::CreateSocketForTesting( 65 net::TCPClientSocket* tcp_client_socket, 66 const std::string& owner_extension_id, 67 bool is_connected) { 68 return new TCPSocket(tcp_client_socket, owner_extension_id, is_connected); 69 } 70 71 // static 72 TCPSocket* TCPSocket::CreateServerSocketForTesting( 73 net::TCPServerSocket* tcp_server_socket, 74 const std::string& owner_extension_id) { 75 return new TCPSocket(tcp_server_socket, owner_extension_id); 76 } 77 78 TCPSocket::~TCPSocket() { Disconnect(); } 79 80 void TCPSocket::Connect(const std::string& address, 81 int port, 82 const CompletionCallback& callback) { 83 DCHECK(!callback.is_null()); 84 85 if (socket_mode_ == SERVER || !connect_callback_.is_null()) { 86 callback.Run(net::ERR_CONNECTION_FAILED); 87 return; 88 } 89 DCHECK(!server_socket_.get()); 90 socket_mode_ = CLIENT; 91 connect_callback_ = callback; 92 93 int result = net::ERR_CONNECTION_FAILED; 94 do { 95 if (is_connected_) 96 break; 97 98 net::AddressList address_list; 99 if (!StringAndPortToAddressList(address, port, &address_list)) { 100 result = net::ERR_ADDRESS_INVALID; 101 break; 102 } 103 104 socket_.reset( 105 new net::TCPClientSocket(address_list, NULL, net::NetLog::Source())); 106 107 connect_callback_ = callback; 108 result = socket_->Connect( 109 base::Bind(&TCPSocket::OnConnectComplete, base::Unretained(this))); 110 } while (false); 111 112 if (result != net::ERR_IO_PENDING) 113 OnConnectComplete(result); 114 } 115 116 void TCPSocket::Disconnect() { 117 is_connected_ = false; 118 if (socket_.get()) 119 socket_->Disconnect(); 120 server_socket_.reset(NULL); 121 connect_callback_.Reset(); 122 read_callback_.Reset(); 123 accept_callback_.Reset(); 124 accept_socket_.reset(NULL); 125 } 126 127 int TCPSocket::Bind(const std::string& address, int port) { 128 return net::ERR_FAILED; 129 } 130 131 void TCPSocket::Read(int count, const ReadCompletionCallback& callback) { 132 DCHECK(!callback.is_null()); 133 134 if (socket_mode_ != CLIENT) { 135 callback.Run(net::ERR_FAILED, NULL); 136 return; 137 } 138 139 if (!read_callback_.is_null()) { 140 callback.Run(net::ERR_IO_PENDING, NULL); 141 return; 142 } 143 144 if (count < 0) { 145 callback.Run(net::ERR_INVALID_ARGUMENT, NULL); 146 return; 147 } 148 149 if (!socket_.get() || !IsConnected()) { 150 callback.Run(net::ERR_SOCKET_NOT_CONNECTED, NULL); 151 return; 152 } 153 154 read_callback_ = callback; 155 scoped_refptr<net::IOBuffer> io_buffer = new net::IOBuffer(count); 156 int result = socket_->Read( 157 io_buffer.get(), 158 count, 159 base::Bind( 160 &TCPSocket::OnReadComplete, base::Unretained(this), io_buffer)); 161 162 if (result != net::ERR_IO_PENDING) 163 OnReadComplete(io_buffer, result); 164 } 165 166 void TCPSocket::RecvFrom(int count, 167 const RecvFromCompletionCallback& callback) { 168 callback.Run(net::ERR_FAILED, NULL, NULL, 0); 169 } 170 171 void TCPSocket::SendTo(scoped_refptr<net::IOBuffer> io_buffer, 172 int byte_count, 173 const std::string& address, 174 int port, 175 const CompletionCallback& callback) { 176 callback.Run(net::ERR_FAILED); 177 } 178 179 bool TCPSocket::SetKeepAlive(bool enable, int delay) { 180 if (!socket_.get()) 181 return false; 182 return socket_->SetKeepAlive(enable, delay); 183 } 184 185 bool TCPSocket::SetNoDelay(bool no_delay) { 186 if (!socket_.get()) 187 return false; 188 return socket_->SetNoDelay(no_delay); 189 } 190 191 int TCPSocket::Listen(const std::string& address, 192 int port, 193 int backlog, 194 std::string* error_msg) { 195 if (socket_mode_ == CLIENT) { 196 *error_msg = kTCPSocketTypeInvalidError; 197 return net::ERR_NOT_IMPLEMENTED; 198 } 199 DCHECK(!socket_.get()); 200 socket_mode_ = SERVER; 201 202 if (!server_socket_.get()) { 203 server_socket_.reset(new net::TCPServerSocket(NULL, net::NetLog::Source())); 204 } 205 206 int result = server_socket_->ListenWithAddressAndPort(address, port, backlog); 207 if (result) 208 *error_msg = kSocketListenError; 209 return result; 210 } 211 212 void TCPSocket::Accept(const AcceptCompletionCallback& callback) { 213 if (socket_mode_ != SERVER || !server_socket_.get()) { 214 callback.Run(net::ERR_FAILED, NULL); 215 return; 216 } 217 218 // Limits to only 1 blocked accept call. 219 if (!accept_callback_.is_null()) { 220 callback.Run(net::ERR_FAILED, NULL); 221 return; 222 } 223 224 int result = server_socket_->Accept( 225 &accept_socket_, 226 base::Bind(&TCPSocket::OnAccept, base::Unretained(this))); 227 if (result == net::ERR_IO_PENDING) { 228 accept_callback_ = callback; 229 } else if (result == net::OK) { 230 accept_callback_ = callback; 231 this->OnAccept(result); 232 } else { 233 callback.Run(result, NULL); 234 } 235 } 236 237 bool TCPSocket::IsConnected() { 238 RefreshConnectionStatus(); 239 return is_connected_; 240 } 241 242 bool TCPSocket::GetPeerAddress(net::IPEndPoint* address) { 243 if (!socket_.get()) 244 return false; 245 return !socket_->GetPeerAddress(address); 246 } 247 248 bool TCPSocket::GetLocalAddress(net::IPEndPoint* address) { 249 if (socket_.get()) { 250 return !socket_->GetLocalAddress(address); 251 } else if (server_socket_.get()) { 252 return !server_socket_->GetLocalAddress(address); 253 } else { 254 return false; 255 } 256 } 257 258 Socket::SocketType TCPSocket::GetSocketType() const { return Socket::TYPE_TCP; } 259 260 int TCPSocket::WriteImpl(net::IOBuffer* io_buffer, 261 int io_buffer_size, 262 const net::CompletionCallback& callback) { 263 if (socket_mode_ != CLIENT) 264 return net::ERR_FAILED; 265 else if (!socket_.get() || !IsConnected()) 266 return net::ERR_SOCKET_NOT_CONNECTED; 267 else 268 return socket_->Write(io_buffer, io_buffer_size, callback); 269 } 270 271 void TCPSocket::RefreshConnectionStatus() { 272 if (!is_connected_) 273 return; 274 if (server_socket_) 275 return; 276 if (!socket_->IsConnected()) { 277 Disconnect(); 278 } 279 } 280 281 void TCPSocket::OnConnectComplete(int result) { 282 DCHECK(!connect_callback_.is_null()); 283 DCHECK(!is_connected_); 284 is_connected_ = result == net::OK; 285 connect_callback_.Run(result); 286 connect_callback_.Reset(); 287 } 288 289 void TCPSocket::OnReadComplete(scoped_refptr<net::IOBuffer> io_buffer, 290 int result) { 291 DCHECK(!read_callback_.is_null()); 292 read_callback_.Run(result, io_buffer); 293 read_callback_.Reset(); 294 } 295 296 void TCPSocket::OnAccept(int result) { 297 DCHECK(!accept_callback_.is_null()); 298 if (result == net::OK && accept_socket_.get()) { 299 accept_callback_.Run( 300 result, static_cast<net::TCPClientSocket*>(accept_socket_.release())); 301 } else { 302 accept_callback_.Run(result, NULL); 303 } 304 accept_callback_.Reset(); 305 } 306 307 void TCPSocket::Release() { 308 // Release() is only invoked when the underlying sockets are taken (via 309 // ClientStream()) by TLSSocket. TLSSocket only supports CLIENT-mode 310 // sockets. 311 DCHECK(!server_socket_.release() && !accept_socket_.release() && 312 socket_mode_ == CLIENT) 313 << "Called in server mode."; 314 315 // Release() doesn't disconnect the underlying sockets, but it does 316 // disconnect them from this TCPSocket. 317 is_connected_ = false; 318 319 connect_callback_.Reset(); 320 read_callback_.Reset(); 321 accept_callback_.Reset(); 322 323 DCHECK(socket_.get()) << "Called on null client socket."; 324 ignore_result(socket_.release()); 325 } 326 327 net::TCPClientSocket* TCPSocket::ClientStream() { 328 if (socket_mode_ != CLIENT || GetSocketType() != TYPE_TCP) 329 return NULL; 330 return socket_.get(); 331 } 332 333 bool TCPSocket::HasPendingRead() const { 334 return !read_callback_.is_null(); 335 } 336 337 ResumableTCPSocket::ResumableTCPSocket(const std::string& owner_extension_id) 338 : TCPSocket(owner_extension_id), 339 persistent_(false), 340 buffer_size_(0), 341 paused_(false) {} 342 343 ResumableTCPSocket::ResumableTCPSocket(net::TCPClientSocket* tcp_client_socket, 344 const std::string& owner_extension_id, 345 bool is_connected) 346 : TCPSocket(tcp_client_socket, owner_extension_id, is_connected), 347 persistent_(false), 348 buffer_size_(0), 349 paused_(false) {} 350 351 bool ResumableTCPSocket::IsPersistent() const { return persistent(); } 352 353 ResumableTCPServerSocket::ResumableTCPServerSocket( 354 const std::string& owner_extension_id) 355 : TCPSocket(owner_extension_id), persistent_(false), paused_(false) {} 356 357 bool ResumableTCPServerSocket::IsPersistent() const { return persistent(); } 358 359 } // namespace extensions 360