Home | History | Annotate | Download | only in socket
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "chrome/browser/extensions/api/socket/tcp_socket.h"
      6 
      7 #include "chrome/browser/extensions/api/api_resource.h"
      8 #include "net/base/address_list.h"
      9 #include "net/base/ip_endpoint.h"
     10 #include "net/base/net_errors.h"
     11 #include "net/base/rand_callback.h"
     12 #include "net/socket/tcp_client_socket.h"
     13 
     14 namespace extensions {
     15 
     16 const char kTCPSocketTypeInvalidError[] =
     17     "Cannot call both connect and listen on the same socket.";
     18 const char kSocketListenError[] = "Could not listen on the specified port.";
     19 
     20 static base::LazyInstance<ProfileKeyedAPIFactory<
     21       ApiResourceManager<ResumableTCPSocket> > >
     22           g_factory = LAZY_INSTANCE_INITIALIZER;
     23 
     24 // static
     25 template <>
     26 ProfileKeyedAPIFactory<ApiResourceManager<ResumableTCPSocket> >*
     27 ApiResourceManager<ResumableTCPSocket>::GetFactoryInstance() {
     28   return &g_factory.Get();
     29 }
     30 
     31 static base::LazyInstance<ProfileKeyedAPIFactory<
     32       ApiResourceManager<ResumableTCPServerSocket> > >
     33           g_server_factory = LAZY_INSTANCE_INITIALIZER;
     34 
     35 // static
     36 template <>
     37 ProfileKeyedAPIFactory<ApiResourceManager<ResumableTCPServerSocket> >*
     38 ApiResourceManager<ResumableTCPServerSocket>::GetFactoryInstance() {
     39   return &g_server_factory.Get();
     40 }
     41 
     42 TCPSocket::TCPSocket(const std::string& owner_extension_id)
     43     : Socket(owner_extension_id),
     44       socket_mode_(UNKNOWN) {
     45 }
     46 
     47 TCPSocket::TCPSocket(net::TCPClientSocket* tcp_client_socket,
     48                      const std::string& owner_extension_id,
     49                      bool is_connected)
     50     : Socket(owner_extension_id),
     51       socket_(tcp_client_socket),
     52       socket_mode_(CLIENT) {
     53   this->is_connected_ = is_connected;
     54 }
     55 
     56 TCPSocket::TCPSocket(net::TCPServerSocket* tcp_server_socket,
     57                      const std::string& owner_extension_id)
     58     : Socket(owner_extension_id),
     59       server_socket_(tcp_server_socket),
     60       socket_mode_(SERVER) {
     61 }
     62 
     63 // static
     64 TCPSocket* TCPSocket::CreateSocketForTesting(
     65     net::TCPClientSocket* tcp_client_socket,
     66     const std::string& owner_extension_id,
     67     bool is_connected) {
     68   return new TCPSocket(tcp_client_socket, owner_extension_id, is_connected);
     69 }
     70 
     71 // static
     72 TCPSocket* TCPSocket::CreateServerSocketForTesting(
     73     net::TCPServerSocket* tcp_server_socket,
     74     const std::string& owner_extension_id) {
     75   return new TCPSocket(tcp_server_socket, owner_extension_id);
     76 }
     77 
     78 TCPSocket::~TCPSocket() {
     79   Disconnect();
     80 }
     81 
     82 void TCPSocket::Connect(const std::string& address,
     83                         int port,
     84                         const CompletionCallback& callback) {
     85   DCHECK(!callback.is_null());
     86 
     87   if (socket_mode_ == SERVER || !connect_callback_.is_null()) {
     88     callback.Run(net::ERR_CONNECTION_FAILED);
     89     return;
     90   }
     91   DCHECK(!server_socket_.get());
     92   socket_mode_ = CLIENT;
     93   connect_callback_ = callback;
     94 
     95   int result = net::ERR_CONNECTION_FAILED;
     96   do {
     97     if (is_connected_)
     98       break;
     99 
    100     net::AddressList address_list;
    101     if (!StringAndPortToAddressList(address, port, &address_list)) {
    102       result = net::ERR_ADDRESS_INVALID;
    103       break;
    104     }
    105 
    106     socket_.reset(new net::TCPClientSocket(address_list, NULL,
    107                                            net::NetLog::Source()));
    108 
    109     connect_callback_ = callback;
    110     result = socket_->Connect(base::Bind(
    111         &TCPSocket::OnConnectComplete, base::Unretained(this)));
    112   } while (false);
    113 
    114   if (result != net::ERR_IO_PENDING)
    115     OnConnectComplete(result);
    116 }
    117 
    118 void TCPSocket::Disconnect() {
    119   is_connected_ = false;
    120   if (socket_.get())
    121     socket_->Disconnect();
    122   server_socket_.reset(NULL);
    123   connect_callback_.Reset();
    124   read_callback_.Reset();
    125   accept_callback_.Reset();
    126   accept_socket_.reset(NULL);
    127 }
    128 
    129 int TCPSocket::Bind(const std::string& address, int port) {
    130   return net::ERR_FAILED;
    131 }
    132 
    133 void TCPSocket::Read(int count,
    134                      const ReadCompletionCallback& callback) {
    135   DCHECK(!callback.is_null());
    136 
    137   if (socket_mode_ != CLIENT) {
    138     callback.Run(net::ERR_FAILED, NULL);
    139     return;
    140   }
    141 
    142   if (!read_callback_.is_null()) {
    143     callback.Run(net::ERR_IO_PENDING, NULL);
    144     return;
    145   }
    146 
    147   if (count < 0) {
    148     callback.Run(net::ERR_INVALID_ARGUMENT, NULL);
    149     return;
    150   }
    151 
    152   if (!socket_.get() || !IsConnected()) {
    153     callback.Run(net::ERR_SOCKET_NOT_CONNECTED, NULL);
    154     return;
    155   }
    156 
    157   read_callback_ = callback;
    158   scoped_refptr<net::IOBuffer> io_buffer = new net::IOBuffer(count);
    159   int result = socket_->Read(io_buffer.get(), count,
    160       base::Bind(&TCPSocket::OnReadComplete, base::Unretained(this),
    161           io_buffer));
    162 
    163   if (result != net::ERR_IO_PENDING)
    164     OnReadComplete(io_buffer, result);
    165 }
    166 
    167 void TCPSocket::RecvFrom(int count,
    168                          const RecvFromCompletionCallback& callback) {
    169   callback.Run(net::ERR_FAILED, NULL, NULL, 0);
    170 }
    171 
    172 void TCPSocket::SendTo(scoped_refptr<net::IOBuffer> io_buffer,
    173                        int byte_count,
    174                        const std::string& address,
    175                        int port,
    176                        const CompletionCallback& callback) {
    177   callback.Run(net::ERR_FAILED);
    178 }
    179 
    180 bool TCPSocket::SetKeepAlive(bool enable, int delay) {
    181   if (!socket_.get())
    182     return false;
    183   return socket_->SetKeepAlive(enable, delay);
    184 }
    185 
    186 bool TCPSocket::SetNoDelay(bool no_delay) {
    187   if (!socket_.get())
    188     return false;
    189   return socket_->SetNoDelay(no_delay);
    190 }
    191 
    192 int TCPSocket::Listen(const std::string& address, int port, int backlog,
    193                       std::string* error_msg) {
    194   if (socket_mode_ == CLIENT) {
    195     *error_msg = kTCPSocketTypeInvalidError;
    196     return net::ERR_NOT_IMPLEMENTED;
    197   }
    198   DCHECK(!socket_.get());
    199   socket_mode_ = SERVER;
    200 
    201   scoped_ptr<net::IPEndPoint> bind_address(new net::IPEndPoint());
    202   if (!StringAndPortToIPEndPoint(address, port, bind_address.get()))
    203     return net::ERR_INVALID_ARGUMENT;
    204 
    205   if (!server_socket_.get()) {
    206     server_socket_.reset(new net::TCPServerSocket(NULL,
    207                                                   net::NetLog::Source()));
    208   }
    209   int result = server_socket_->Listen(*bind_address, backlog);
    210   if (result)
    211     *error_msg = kSocketListenError;
    212   return result;
    213 }
    214 
    215 void TCPSocket::Accept(const AcceptCompletionCallback &callback) {
    216   if (socket_mode_ != SERVER || !server_socket_.get()) {
    217     callback.Run(net::ERR_FAILED, NULL);
    218     return;
    219   }
    220 
    221   // Limits to only 1 blocked accept call.
    222   if (!accept_callback_.is_null()) {
    223     callback.Run(net::ERR_FAILED, NULL);
    224     return;
    225   }
    226 
    227   int result = server_socket_->Accept(&accept_socket_, base::Bind(
    228       &TCPSocket::OnAccept, base::Unretained(this)));
    229   if (result == net::ERR_IO_PENDING) {
    230     accept_callback_ = callback;
    231   } else if (result == net::OK) {
    232     accept_callback_ = callback;
    233     this->OnAccept(result);
    234   } else {
    235     callback.Run(result, NULL);
    236   }
    237 }
    238 
    239 bool TCPSocket::IsConnected() {
    240   RefreshConnectionStatus();
    241   return is_connected_;
    242 }
    243 
    244 bool TCPSocket::GetPeerAddress(net::IPEndPoint* address) {
    245   if (!socket_.get())
    246     return false;
    247   return !socket_->GetPeerAddress(address);
    248 }
    249 
    250 bool TCPSocket::GetLocalAddress(net::IPEndPoint* address) {
    251   if (socket_.get()) {
    252     return !socket_->GetLocalAddress(address);
    253   } else if (server_socket_.get()) {
    254     return !server_socket_->GetLocalAddress(address);
    255   } else {
    256     return false;
    257   }
    258 }
    259 
    260 Socket::SocketType TCPSocket::GetSocketType() const {
    261   return Socket::TYPE_TCP;
    262 }
    263 
    264 int TCPSocket::WriteImpl(net::IOBuffer* io_buffer,
    265                          int io_buffer_size,
    266                          const net::CompletionCallback& callback) {
    267   if (socket_mode_ != CLIENT)
    268     return net::ERR_FAILED;
    269   else if (!socket_.get() || !IsConnected())
    270     return net::ERR_SOCKET_NOT_CONNECTED;
    271   else
    272     return socket_->Write(io_buffer, io_buffer_size, callback);
    273 }
    274 
    275 void TCPSocket::RefreshConnectionStatus() {
    276   if (!is_connected_) return;
    277   if (server_socket_) return;
    278   if (!socket_->IsConnected()) {
    279     Disconnect();
    280   }
    281 }
    282 
    283 void TCPSocket::OnConnectComplete(int result) {
    284   DCHECK(!connect_callback_.is_null());
    285   DCHECK(!is_connected_);
    286   is_connected_ = result == net::OK;
    287   connect_callback_.Run(result);
    288   connect_callback_.Reset();
    289 }
    290 
    291 void TCPSocket::OnReadComplete(scoped_refptr<net::IOBuffer> io_buffer,
    292                                int result) {
    293   DCHECK(!read_callback_.is_null());
    294   read_callback_.Run(result, io_buffer);
    295   read_callback_.Reset();
    296 }
    297 
    298 void TCPSocket::OnAccept(int result) {
    299   DCHECK(!accept_callback_.is_null());
    300   if (result == net::OK && accept_socket_.get()) {
    301     accept_callback_.Run(
    302         result, static_cast<net::TCPClientSocket*>(accept_socket_.release()));
    303   } else {
    304     accept_callback_.Run(result, NULL);
    305   }
    306   accept_callback_.Reset();
    307 }
    308 
    309 ResumableTCPSocket::ResumableTCPSocket(const std::string& owner_extension_id)
    310     : TCPSocket(owner_extension_id),
    311       persistent_(false),
    312       buffer_size_(0),
    313       paused_(false) {
    314 }
    315 
    316 ResumableTCPSocket::ResumableTCPSocket(net::TCPClientSocket* tcp_client_socket,
    317                                        const std::string& owner_extension_id,
    318                                        bool is_connected)
    319     : TCPSocket(tcp_client_socket, owner_extension_id, is_connected),
    320       persistent_(false),
    321       buffer_size_(0),
    322       paused_(false) {
    323 }
    324 
    325 bool ResumableTCPSocket::IsPersistent() const {
    326   return persistent();
    327 }
    328 
    329 ResumableTCPServerSocket::ResumableTCPServerSocket(
    330     const std::string& owner_extension_id)
    331     : TCPSocket(owner_extension_id),
    332       persistent_(false),
    333       paused_(false) {
    334 }
    335 
    336 bool ResumableTCPServerSocket::IsPersistent() const {
    337   return persistent();
    338 }
    339 
    340 }  // namespace extensions
    341