Home | History | Annotate | Download | only in socket
      1 // Copyright 2014 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "extensions/browser/api/socket/tcp_socket.h"
      6 
      7 #include "base/lazy_instance.h"
      8 #include "base/logging.h"
      9 #include "base/macros.h"
     10 #include "extensions/browser/api/api_resource.h"
     11 #include "net/base/address_list.h"
     12 #include "net/base/ip_endpoint.h"
     13 #include "net/base/net_errors.h"
     14 #include "net/base/rand_callback.h"
     15 #include "net/socket/tcp_client_socket.h"
     16 
     17 namespace extensions {
     18 
     19 const char kTCPSocketTypeInvalidError[] =
     20     "Cannot call both connect and listen on the same socket.";
     21 const char kSocketListenError[] = "Could not listen on the specified port.";
     22 
     23 static base::LazyInstance<
     24     BrowserContextKeyedAPIFactory<ApiResourceManager<ResumableTCPSocket> > >
     25     g_factory = LAZY_INSTANCE_INITIALIZER;
     26 
     27 // static
     28 template <>
     29 BrowserContextKeyedAPIFactory<ApiResourceManager<ResumableTCPSocket> >*
     30 ApiResourceManager<ResumableTCPSocket>::GetFactoryInstance() {
     31   return g_factory.Pointer();
     32 }
     33 
     34 static base::LazyInstance<BrowserContextKeyedAPIFactory<
     35     ApiResourceManager<ResumableTCPServerSocket> > > g_server_factory =
     36     LAZY_INSTANCE_INITIALIZER;
     37 
     38 // static
     39 template <>
     40 BrowserContextKeyedAPIFactory<ApiResourceManager<ResumableTCPServerSocket> >*
     41 ApiResourceManager<ResumableTCPServerSocket>::GetFactoryInstance() {
     42   return g_server_factory.Pointer();
     43 }
     44 
     45 TCPSocket::TCPSocket(const std::string& owner_extension_id)
     46     : Socket(owner_extension_id), socket_mode_(UNKNOWN) {}
     47 
     48 TCPSocket::TCPSocket(net::TCPClientSocket* tcp_client_socket,
     49                      const std::string& owner_extension_id,
     50                      bool is_connected)
     51     : Socket(owner_extension_id),
     52       socket_(tcp_client_socket),
     53       socket_mode_(CLIENT) {
     54   this->is_connected_ = is_connected;
     55 }
     56 
     57 TCPSocket::TCPSocket(net::TCPServerSocket* tcp_server_socket,
     58                      const std::string& owner_extension_id)
     59     : Socket(owner_extension_id),
     60       server_socket_(tcp_server_socket),
     61       socket_mode_(SERVER) {}
     62 
     63 // static
     64 TCPSocket* TCPSocket::CreateSocketForTesting(
     65     net::TCPClientSocket* tcp_client_socket,
     66     const std::string& owner_extension_id,
     67     bool is_connected) {
     68   return new TCPSocket(tcp_client_socket, owner_extension_id, is_connected);
     69 }
     70 
     71 // static
     72 TCPSocket* TCPSocket::CreateServerSocketForTesting(
     73     net::TCPServerSocket* tcp_server_socket,
     74     const std::string& owner_extension_id) {
     75   return new TCPSocket(tcp_server_socket, owner_extension_id);
     76 }
     77 
     78 TCPSocket::~TCPSocket() { Disconnect(); }
     79 
     80 void TCPSocket::Connect(const std::string& address,
     81                         int port,
     82                         const CompletionCallback& callback) {
     83   DCHECK(!callback.is_null());
     84 
     85   if (socket_mode_ == SERVER || !connect_callback_.is_null()) {
     86     callback.Run(net::ERR_CONNECTION_FAILED);
     87     return;
     88   }
     89   DCHECK(!server_socket_.get());
     90   socket_mode_ = CLIENT;
     91   connect_callback_ = callback;
     92 
     93   int result = net::ERR_CONNECTION_FAILED;
     94   do {
     95     if (is_connected_)
     96       break;
     97 
     98     net::AddressList address_list;
     99     if (!StringAndPortToAddressList(address, port, &address_list)) {
    100       result = net::ERR_ADDRESS_INVALID;
    101       break;
    102     }
    103 
    104     socket_.reset(
    105         new net::TCPClientSocket(address_list, NULL, net::NetLog::Source()));
    106 
    107     connect_callback_ = callback;
    108     result = socket_->Connect(
    109         base::Bind(&TCPSocket::OnConnectComplete, base::Unretained(this)));
    110   } while (false);
    111 
    112   if (result != net::ERR_IO_PENDING)
    113     OnConnectComplete(result);
    114 }
    115 
    116 void TCPSocket::Disconnect() {
    117   is_connected_ = false;
    118   if (socket_.get())
    119     socket_->Disconnect();
    120   server_socket_.reset(NULL);
    121   connect_callback_.Reset();
    122   read_callback_.Reset();
    123   accept_callback_.Reset();
    124   accept_socket_.reset(NULL);
    125 }
    126 
    127 int TCPSocket::Bind(const std::string& address, int port) {
    128   return net::ERR_FAILED;
    129 }
    130 
    131 void TCPSocket::Read(int count, const ReadCompletionCallback& callback) {
    132   DCHECK(!callback.is_null());
    133 
    134   if (socket_mode_ != CLIENT) {
    135     callback.Run(net::ERR_FAILED, NULL);
    136     return;
    137   }
    138 
    139   if (!read_callback_.is_null()) {
    140     callback.Run(net::ERR_IO_PENDING, NULL);
    141     return;
    142   }
    143 
    144   if (count < 0) {
    145     callback.Run(net::ERR_INVALID_ARGUMENT, NULL);
    146     return;
    147   }
    148 
    149   if (!socket_.get() || !IsConnected()) {
    150     callback.Run(net::ERR_SOCKET_NOT_CONNECTED, NULL);
    151     return;
    152   }
    153 
    154   read_callback_ = callback;
    155   scoped_refptr<net::IOBuffer> io_buffer = new net::IOBuffer(count);
    156   int result = socket_->Read(
    157       io_buffer.get(),
    158       count,
    159       base::Bind(
    160           &TCPSocket::OnReadComplete, base::Unretained(this), io_buffer));
    161 
    162   if (result != net::ERR_IO_PENDING)
    163     OnReadComplete(io_buffer, result);
    164 }
    165 
    166 void TCPSocket::RecvFrom(int count,
    167                          const RecvFromCompletionCallback& callback) {
    168   callback.Run(net::ERR_FAILED, NULL, NULL, 0);
    169 }
    170 
    171 void TCPSocket::SendTo(scoped_refptr<net::IOBuffer> io_buffer,
    172                        int byte_count,
    173                        const std::string& address,
    174                        int port,
    175                        const CompletionCallback& callback) {
    176   callback.Run(net::ERR_FAILED);
    177 }
    178 
    179 bool TCPSocket::SetKeepAlive(bool enable, int delay) {
    180   if (!socket_.get())
    181     return false;
    182   return socket_->SetKeepAlive(enable, delay);
    183 }
    184 
    185 bool TCPSocket::SetNoDelay(bool no_delay) {
    186   if (!socket_.get())
    187     return false;
    188   return socket_->SetNoDelay(no_delay);
    189 }
    190 
    191 int TCPSocket::Listen(const std::string& address,
    192                       int port,
    193                       int backlog,
    194                       std::string* error_msg) {
    195   if (socket_mode_ == CLIENT) {
    196     *error_msg = kTCPSocketTypeInvalidError;
    197     return net::ERR_NOT_IMPLEMENTED;
    198   }
    199   DCHECK(!socket_.get());
    200   socket_mode_ = SERVER;
    201 
    202   if (!server_socket_.get()) {
    203     server_socket_.reset(new net::TCPServerSocket(NULL, net::NetLog::Source()));
    204   }
    205 
    206   int result = server_socket_->ListenWithAddressAndPort(address, port, backlog);
    207   if (result)
    208     *error_msg = kSocketListenError;
    209   return result;
    210 }
    211 
    212 void TCPSocket::Accept(const AcceptCompletionCallback& callback) {
    213   if (socket_mode_ != SERVER || !server_socket_.get()) {
    214     callback.Run(net::ERR_FAILED, NULL);
    215     return;
    216   }
    217 
    218   // Limits to only 1 blocked accept call.
    219   if (!accept_callback_.is_null()) {
    220     callback.Run(net::ERR_FAILED, NULL);
    221     return;
    222   }
    223 
    224   int result = server_socket_->Accept(
    225       &accept_socket_,
    226       base::Bind(&TCPSocket::OnAccept, base::Unretained(this)));
    227   if (result == net::ERR_IO_PENDING) {
    228     accept_callback_ = callback;
    229   } else if (result == net::OK) {
    230     accept_callback_ = callback;
    231     this->OnAccept(result);
    232   } else {
    233     callback.Run(result, NULL);
    234   }
    235 }
    236 
    237 bool TCPSocket::IsConnected() {
    238   RefreshConnectionStatus();
    239   return is_connected_;
    240 }
    241 
    242 bool TCPSocket::GetPeerAddress(net::IPEndPoint* address) {
    243   if (!socket_.get())
    244     return false;
    245   return !socket_->GetPeerAddress(address);
    246 }
    247 
    248 bool TCPSocket::GetLocalAddress(net::IPEndPoint* address) {
    249   if (socket_.get()) {
    250     return !socket_->GetLocalAddress(address);
    251   } else if (server_socket_.get()) {
    252     return !server_socket_->GetLocalAddress(address);
    253   } else {
    254     return false;
    255   }
    256 }
    257 
    258 Socket::SocketType TCPSocket::GetSocketType() const { return Socket::TYPE_TCP; }
    259 
    260 int TCPSocket::WriteImpl(net::IOBuffer* io_buffer,
    261                          int io_buffer_size,
    262                          const net::CompletionCallback& callback) {
    263   if (socket_mode_ != CLIENT)
    264     return net::ERR_FAILED;
    265   else if (!socket_.get() || !IsConnected())
    266     return net::ERR_SOCKET_NOT_CONNECTED;
    267   else
    268     return socket_->Write(io_buffer, io_buffer_size, callback);
    269 }
    270 
    271 void TCPSocket::RefreshConnectionStatus() {
    272   if (!is_connected_)
    273     return;
    274   if (server_socket_)
    275     return;
    276   if (!socket_->IsConnected()) {
    277     Disconnect();
    278   }
    279 }
    280 
    281 void TCPSocket::OnConnectComplete(int result) {
    282   DCHECK(!connect_callback_.is_null());
    283   DCHECK(!is_connected_);
    284   is_connected_ = result == net::OK;
    285   connect_callback_.Run(result);
    286   connect_callback_.Reset();
    287 }
    288 
    289 void TCPSocket::OnReadComplete(scoped_refptr<net::IOBuffer> io_buffer,
    290                                int result) {
    291   DCHECK(!read_callback_.is_null());
    292   read_callback_.Run(result, io_buffer);
    293   read_callback_.Reset();
    294 }
    295 
    296 void TCPSocket::OnAccept(int result) {
    297   DCHECK(!accept_callback_.is_null());
    298   if (result == net::OK && accept_socket_.get()) {
    299     accept_callback_.Run(
    300         result, static_cast<net::TCPClientSocket*>(accept_socket_.release()));
    301   } else {
    302     accept_callback_.Run(result, NULL);
    303   }
    304   accept_callback_.Reset();
    305 }
    306 
    307 void TCPSocket::Release() {
    308   // Release() is only invoked when the underlying sockets are taken (via
    309   // ClientStream()) by TLSSocket. TLSSocket only supports CLIENT-mode
    310   // sockets.
    311   DCHECK(!server_socket_.release() && !accept_socket_.release() &&
    312          socket_mode_ == CLIENT)
    313       << "Called in server mode.";
    314 
    315   // Release() doesn't disconnect the underlying sockets, but it does
    316   // disconnect them from this TCPSocket.
    317   is_connected_ = false;
    318 
    319   connect_callback_.Reset();
    320   read_callback_.Reset();
    321   accept_callback_.Reset();
    322 
    323   DCHECK(socket_.get()) << "Called on null client socket.";
    324   ignore_result(socket_.release());
    325 }
    326 
    327 net::TCPClientSocket* TCPSocket::ClientStream() {
    328   if (socket_mode_ != CLIENT || GetSocketType() != TYPE_TCP)
    329     return NULL;
    330   return socket_.get();
    331 }
    332 
    333 bool TCPSocket::HasPendingRead() const {
    334   return !read_callback_.is_null();
    335 }
    336 
    337 ResumableTCPSocket::ResumableTCPSocket(const std::string& owner_extension_id)
    338     : TCPSocket(owner_extension_id),
    339       persistent_(false),
    340       buffer_size_(0),
    341       paused_(false) {}
    342 
    343 ResumableTCPSocket::ResumableTCPSocket(net::TCPClientSocket* tcp_client_socket,
    344                                        const std::string& owner_extension_id,
    345                                        bool is_connected)
    346     : TCPSocket(tcp_client_socket, owner_extension_id, is_connected),
    347       persistent_(false),
    348       buffer_size_(0),
    349       paused_(false) {}
    350 
    351 bool ResumableTCPSocket::IsPersistent() const { return persistent(); }
    352 
    353 ResumableTCPServerSocket::ResumableTCPServerSocket(
    354     const std::string& owner_extension_id)
    355     : TCPSocket(owner_extension_id), persistent_(false), paused_(false) {}
    356 
    357 bool ResumableTCPServerSocket::IsPersistent() const { return persistent(); }
    358 
    359 }  // namespace extensions
    360