Home | History | Annotate | Download | only in socket
      1 // Copyright 2014 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "extensions/browser/api/socket/tcp_socket.h"
      6 
      7 #include "extensions/browser/api/api_resource.h"
      8 #include "net/base/address_list.h"
      9 #include "net/base/ip_endpoint.h"
     10 #include "net/base/net_errors.h"
     11 #include "net/base/rand_callback.h"
     12 #include "net/socket/tcp_client_socket.h"
     13 
     14 namespace extensions {
     15 
     16 const char kTCPSocketTypeInvalidError[] =
     17     "Cannot call both connect and listen on the same socket.";
     18 const char kSocketListenError[] = "Could not listen on the specified port.";
     19 
     20 static base::LazyInstance<
     21     BrowserContextKeyedAPIFactory<ApiResourceManager<ResumableTCPSocket> > >
     22     g_factory = LAZY_INSTANCE_INITIALIZER;
     23 
     24 // static
     25 template <>
     26 BrowserContextKeyedAPIFactory<ApiResourceManager<ResumableTCPSocket> >*
     27 ApiResourceManager<ResumableTCPSocket>::GetFactoryInstance() {
     28   return g_factory.Pointer();
     29 }
     30 
     31 static base::LazyInstance<BrowserContextKeyedAPIFactory<
     32     ApiResourceManager<ResumableTCPServerSocket> > > g_server_factory =
     33     LAZY_INSTANCE_INITIALIZER;
     34 
     35 // static
     36 template <>
     37 BrowserContextKeyedAPIFactory<ApiResourceManager<ResumableTCPServerSocket> >*
     38 ApiResourceManager<ResumableTCPServerSocket>::GetFactoryInstance() {
     39   return g_server_factory.Pointer();
     40 }
     41 
     42 TCPSocket::TCPSocket(const std::string& owner_extension_id)
     43     : Socket(owner_extension_id), socket_mode_(UNKNOWN) {}
     44 
     45 TCPSocket::TCPSocket(net::TCPClientSocket* tcp_client_socket,
     46                      const std::string& owner_extension_id,
     47                      bool is_connected)
     48     : Socket(owner_extension_id),
     49       socket_(tcp_client_socket),
     50       socket_mode_(CLIENT) {
     51   this->is_connected_ = is_connected;
     52 }
     53 
     54 TCPSocket::TCPSocket(net::TCPServerSocket* tcp_server_socket,
     55                      const std::string& owner_extension_id)
     56     : Socket(owner_extension_id),
     57       server_socket_(tcp_server_socket),
     58       socket_mode_(SERVER) {}
     59 
     60 // static
     61 TCPSocket* TCPSocket::CreateSocketForTesting(
     62     net::TCPClientSocket* tcp_client_socket,
     63     const std::string& owner_extension_id,
     64     bool is_connected) {
     65   return new TCPSocket(tcp_client_socket, owner_extension_id, is_connected);
     66 }
     67 
     68 // static
     69 TCPSocket* TCPSocket::CreateServerSocketForTesting(
     70     net::TCPServerSocket* tcp_server_socket,
     71     const std::string& owner_extension_id) {
     72   return new TCPSocket(tcp_server_socket, owner_extension_id);
     73 }
     74 
     75 TCPSocket::~TCPSocket() { Disconnect(); }
     76 
     77 void TCPSocket::Connect(const std::string& address,
     78                         int port,
     79                         const CompletionCallback& callback) {
     80   DCHECK(!callback.is_null());
     81 
     82   if (socket_mode_ == SERVER || !connect_callback_.is_null()) {
     83     callback.Run(net::ERR_CONNECTION_FAILED);
     84     return;
     85   }
     86   DCHECK(!server_socket_.get());
     87   socket_mode_ = CLIENT;
     88   connect_callback_ = callback;
     89 
     90   int result = net::ERR_CONNECTION_FAILED;
     91   do {
     92     if (is_connected_)
     93       break;
     94 
     95     net::AddressList address_list;
     96     if (!StringAndPortToAddressList(address, port, &address_list)) {
     97       result = net::ERR_ADDRESS_INVALID;
     98       break;
     99     }
    100 
    101     socket_.reset(
    102         new net::TCPClientSocket(address_list, NULL, net::NetLog::Source()));
    103 
    104     connect_callback_ = callback;
    105     result = socket_->Connect(
    106         base::Bind(&TCPSocket::OnConnectComplete, base::Unretained(this)));
    107   } while (false);
    108 
    109   if (result != net::ERR_IO_PENDING)
    110     OnConnectComplete(result);
    111 }
    112 
    113 void TCPSocket::Disconnect() {
    114   is_connected_ = false;
    115   if (socket_.get())
    116     socket_->Disconnect();
    117   server_socket_.reset(NULL);
    118   connect_callback_.Reset();
    119   read_callback_.Reset();
    120   accept_callback_.Reset();
    121   accept_socket_.reset(NULL);
    122 }
    123 
    124 int TCPSocket::Bind(const std::string& address, int port) {
    125   return net::ERR_FAILED;
    126 }
    127 
    128 void TCPSocket::Read(int count, const ReadCompletionCallback& callback) {
    129   DCHECK(!callback.is_null());
    130 
    131   if (socket_mode_ != CLIENT) {
    132     callback.Run(net::ERR_FAILED, NULL);
    133     return;
    134   }
    135 
    136   if (!read_callback_.is_null()) {
    137     callback.Run(net::ERR_IO_PENDING, NULL);
    138     return;
    139   }
    140 
    141   if (count < 0) {
    142     callback.Run(net::ERR_INVALID_ARGUMENT, NULL);
    143     return;
    144   }
    145 
    146   if (!socket_.get() || !IsConnected()) {
    147     callback.Run(net::ERR_SOCKET_NOT_CONNECTED, NULL);
    148     return;
    149   }
    150 
    151   read_callback_ = callback;
    152   scoped_refptr<net::IOBuffer> io_buffer = new net::IOBuffer(count);
    153   int result = socket_->Read(
    154       io_buffer.get(),
    155       count,
    156       base::Bind(
    157           &TCPSocket::OnReadComplete, base::Unretained(this), io_buffer));
    158 
    159   if (result != net::ERR_IO_PENDING)
    160     OnReadComplete(io_buffer, result);
    161 }
    162 
    163 void TCPSocket::RecvFrom(int count,
    164                          const RecvFromCompletionCallback& callback) {
    165   callback.Run(net::ERR_FAILED, NULL, NULL, 0);
    166 }
    167 
    168 void TCPSocket::SendTo(scoped_refptr<net::IOBuffer> io_buffer,
    169                        int byte_count,
    170                        const std::string& address,
    171                        int port,
    172                        const CompletionCallback& callback) {
    173   callback.Run(net::ERR_FAILED);
    174 }
    175 
    176 bool TCPSocket::SetKeepAlive(bool enable, int delay) {
    177   if (!socket_.get())
    178     return false;
    179   return socket_->SetKeepAlive(enable, delay);
    180 }
    181 
    182 bool TCPSocket::SetNoDelay(bool no_delay) {
    183   if (!socket_.get())
    184     return false;
    185   return socket_->SetNoDelay(no_delay);
    186 }
    187 
    188 int TCPSocket::Listen(const std::string& address,
    189                       int port,
    190                       int backlog,
    191                       std::string* error_msg) {
    192   if (socket_mode_ == CLIENT) {
    193     *error_msg = kTCPSocketTypeInvalidError;
    194     return net::ERR_NOT_IMPLEMENTED;
    195   }
    196   DCHECK(!socket_.get());
    197   socket_mode_ = SERVER;
    198 
    199   scoped_ptr<net::IPEndPoint> bind_address(new net::IPEndPoint());
    200   if (!StringAndPortToIPEndPoint(address, port, bind_address.get()))
    201     return net::ERR_INVALID_ARGUMENT;
    202 
    203   if (!server_socket_.get()) {
    204     server_socket_.reset(new net::TCPServerSocket(NULL, net::NetLog::Source()));
    205   }
    206   int result = server_socket_->Listen(*bind_address, backlog);
    207   if (result)
    208     *error_msg = kSocketListenError;
    209   return result;
    210 }
    211 
    212 void TCPSocket::Accept(const AcceptCompletionCallback& callback) {
    213   if (socket_mode_ != SERVER || !server_socket_.get()) {
    214     callback.Run(net::ERR_FAILED, NULL);
    215     return;
    216   }
    217 
    218   // Limits to only 1 blocked accept call.
    219   if (!accept_callback_.is_null()) {
    220     callback.Run(net::ERR_FAILED, NULL);
    221     return;
    222   }
    223 
    224   int result = server_socket_->Accept(
    225       &accept_socket_,
    226       base::Bind(&TCPSocket::OnAccept, base::Unretained(this)));
    227   if (result == net::ERR_IO_PENDING) {
    228     accept_callback_ = callback;
    229   } else if (result == net::OK) {
    230     accept_callback_ = callback;
    231     this->OnAccept(result);
    232   } else {
    233     callback.Run(result, NULL);
    234   }
    235 }
    236 
    237 bool TCPSocket::IsConnected() {
    238   RefreshConnectionStatus();
    239   return is_connected_;
    240 }
    241 
    242 bool TCPSocket::GetPeerAddress(net::IPEndPoint* address) {
    243   if (!socket_.get())
    244     return false;
    245   return !socket_->GetPeerAddress(address);
    246 }
    247 
    248 bool TCPSocket::GetLocalAddress(net::IPEndPoint* address) {
    249   if (socket_.get()) {
    250     return !socket_->GetLocalAddress(address);
    251   } else if (server_socket_.get()) {
    252     return !server_socket_->GetLocalAddress(address);
    253   } else {
    254     return false;
    255   }
    256 }
    257 
    258 Socket::SocketType TCPSocket::GetSocketType() const { return Socket::TYPE_TCP; }
    259 
    260 int TCPSocket::WriteImpl(net::IOBuffer* io_buffer,
    261                          int io_buffer_size,
    262                          const net::CompletionCallback& callback) {
    263   if (socket_mode_ != CLIENT)
    264     return net::ERR_FAILED;
    265   else if (!socket_.get() || !IsConnected())
    266     return net::ERR_SOCKET_NOT_CONNECTED;
    267   else
    268     return socket_->Write(io_buffer, io_buffer_size, callback);
    269 }
    270 
    271 void TCPSocket::RefreshConnectionStatus() {
    272   if (!is_connected_)
    273     return;
    274   if (server_socket_)
    275     return;
    276   if (!socket_->IsConnected()) {
    277     Disconnect();
    278   }
    279 }
    280 
    281 void TCPSocket::OnConnectComplete(int result) {
    282   DCHECK(!connect_callback_.is_null());
    283   DCHECK(!is_connected_);
    284   is_connected_ = result == net::OK;
    285   connect_callback_.Run(result);
    286   connect_callback_.Reset();
    287 }
    288 
    289 void TCPSocket::OnReadComplete(scoped_refptr<net::IOBuffer> io_buffer,
    290                                int result) {
    291   DCHECK(!read_callback_.is_null());
    292   read_callback_.Run(result, io_buffer);
    293   read_callback_.Reset();
    294 }
    295 
    296 void TCPSocket::OnAccept(int result) {
    297   DCHECK(!accept_callback_.is_null());
    298   if (result == net::OK && accept_socket_.get()) {
    299     accept_callback_.Run(
    300         result, static_cast<net::TCPClientSocket*>(accept_socket_.release()));
    301   } else {
    302     accept_callback_.Run(result, NULL);
    303   }
    304   accept_callback_.Reset();
    305 }
    306 
    307 ResumableTCPSocket::ResumableTCPSocket(const std::string& owner_extension_id)
    308     : TCPSocket(owner_extension_id),
    309       persistent_(false),
    310       buffer_size_(0),
    311       paused_(false) {}
    312 
    313 ResumableTCPSocket::ResumableTCPSocket(net::TCPClientSocket* tcp_client_socket,
    314                                        const std::string& owner_extension_id,
    315                                        bool is_connected)
    316     : TCPSocket(tcp_client_socket, owner_extension_id, is_connected),
    317       persistent_(false),
    318       buffer_size_(0),
    319       paused_(false) {}
    320 
    321 bool ResumableTCPSocket::IsPersistent() const { return persistent(); }
    322 
    323 ResumableTCPServerSocket::ResumableTCPServerSocket(
    324     const std::string& owner_extension_id)
    325     : TCPSocket(owner_extension_id), persistent_(false), paused_(false) {}
    326 
    327 bool ResumableTCPServerSocket::IsPersistent() const { return persistent(); }
    328 
    329 }  // namespace extensions
    330