Home | History | Annotate | Download | only in socket
      1 // Copyright 2014 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #ifndef EXTENSIONS_BROWSER_API_SOCKET_SOCKET_API_H_
      6 #define EXTENSIONS_BROWSER_API_SOCKET_SOCKET_API_H_
      7 
      8 #include <string>
      9 
     10 #include "base/gtest_prod_util.h"
     11 #include "base/memory/ref_counted.h"
     12 #include "extensions/browser/api/api_resource_manager.h"
     13 #include "extensions/browser/api/async_api_function.h"
     14 #include "extensions/browser/extension_function.h"
     15 #include "extensions/common/api/socket.h"
     16 #include "net/base/address_list.h"
     17 #include "net/dns/host_resolver.h"
     18 #include "net/socket/tcp_client_socket.h"
     19 
     20 namespace content {
     21 class BrowserContext;
     22 class ResourceContext;
     23 }
     24 
     25 namespace net {
     26 class IOBuffer;
     27 class URLRequestContextGetter;
     28 class SSLClientSocket;
     29 }
     30 
     31 namespace extensions {
     32 class TLSSocket;
     33 class Socket;
     34 
     35 // A simple interface to ApiResourceManager<Socket> or derived class. The goal
     36 // of this interface is to allow Socket API functions to use distinct instances
     37 // of ApiResourceManager<> depending on the type of socket (old version in
     38 // "socket" namespace vs new version in "socket.xxx" namespaces).
     39 class SocketResourceManagerInterface {
     40  public:
     41   virtual ~SocketResourceManagerInterface() {}
     42 
     43   virtual bool SetBrowserContext(content::BrowserContext* context) = 0;
     44   virtual int Add(Socket* socket) = 0;
     45   virtual Socket* Get(const std::string& extension_id, int api_resource_id) = 0;
     46   virtual void Remove(const std::string& extension_id, int api_resource_id) = 0;
     47   virtual void Replace(const std::string& extension_id,
     48                        int api_resource_id,
     49                        Socket* socket) = 0;
     50   virtual base::hash_set<int>* GetResourceIds(
     51       const std::string& extension_id) = 0;
     52 };
     53 
     54 // Implementation of SocketResourceManagerInterface using an
     55 // ApiResourceManager<T> instance (where T derives from Socket).
     56 template <typename T>
     57 class SocketResourceManager : public SocketResourceManagerInterface {
     58  public:
     59   SocketResourceManager() : manager_(NULL) {}
     60 
     61   virtual bool SetBrowserContext(content::BrowserContext* context) OVERRIDE {
     62     manager_ = ApiResourceManager<T>::Get(context);
     63     DCHECK(manager_)
     64         << "There is no socket manager. "
     65            "If this assertion is failing during a test, then it is likely that "
     66            "TestExtensionSystem is failing to provide an instance of "
     67            "ApiResourceManager<Socket>.";
     68     return manager_ != NULL;
     69   }
     70 
     71   virtual int Add(Socket* socket) OVERRIDE {
     72     // Note: Cast needed here, because "T" may be a subclass of "Socket".
     73     return manager_->Add(static_cast<T*>(socket));
     74   }
     75 
     76   virtual Socket* Get(const std::string& extension_id,
     77                       int api_resource_id) OVERRIDE {
     78     return manager_->Get(extension_id, api_resource_id);
     79   }
     80 
     81   virtual void Replace(const std::string& extension_id,
     82                        int api_resource_id,
     83                        Socket* socket) OVERRIDE {
     84     manager_->Replace(extension_id, api_resource_id, static_cast<T*>(socket));
     85   }
     86 
     87   virtual void Remove(const std::string& extension_id,
     88                       int api_resource_id) OVERRIDE {
     89     manager_->Remove(extension_id, api_resource_id);
     90   }
     91 
     92   virtual base::hash_set<int>* GetResourceIds(const std::string& extension_id)
     93       OVERRIDE {
     94     return manager_->GetResourceIds(extension_id);
     95   }
     96 
     97  private:
     98   ApiResourceManager<T>* manager_;
     99 };
    100 
    101 class SocketAsyncApiFunction : public AsyncApiFunction {
    102  public:
    103   SocketAsyncApiFunction();
    104 
    105  protected:
    106   virtual ~SocketAsyncApiFunction();
    107 
    108   // AsyncApiFunction:
    109   virtual bool PrePrepare() OVERRIDE;
    110   virtual bool Respond() OVERRIDE;
    111 
    112   virtual scoped_ptr<SocketResourceManagerInterface>
    113       CreateSocketResourceManager();
    114 
    115   int AddSocket(Socket* socket);
    116   Socket* GetSocket(int api_resource_id);
    117   void ReplaceSocket(int api_resource_id, Socket* socket);
    118   void RemoveSocket(int api_resource_id);
    119   base::hash_set<int>* GetSocketIds();
    120 
    121  private:
    122   scoped_ptr<SocketResourceManagerInterface> manager_;
    123 };
    124 
    125 class SocketExtensionWithDnsLookupFunction : public SocketAsyncApiFunction {
    126  protected:
    127   SocketExtensionWithDnsLookupFunction();
    128   virtual ~SocketExtensionWithDnsLookupFunction();
    129 
    130   // AsyncApiFunction:
    131   virtual bool PrePrepare() OVERRIDE;
    132 
    133   void StartDnsLookup(const std::string& hostname);
    134   virtual void AfterDnsLookup(int lookup_result) = 0;
    135 
    136   std::string resolved_address_;
    137 
    138  private:
    139   void OnDnsLookup(int resolve_result);
    140 
    141   // Weak pointer to the resource context.
    142   content::ResourceContext* resource_context_;
    143 
    144   scoped_ptr<net::HostResolver::RequestHandle> request_handle_;
    145   scoped_ptr<net::AddressList> addresses_;
    146 };
    147 
    148 class SocketCreateFunction : public SocketAsyncApiFunction {
    149  public:
    150   DECLARE_EXTENSION_FUNCTION("socket.create", SOCKET_CREATE)
    151 
    152   SocketCreateFunction();
    153 
    154  protected:
    155   virtual ~SocketCreateFunction();
    156 
    157   // AsyncApiFunction:
    158   virtual bool Prepare() OVERRIDE;
    159   virtual void Work() OVERRIDE;
    160 
    161  private:
    162   FRIEND_TEST_ALL_PREFIXES(SocketUnitTest, Create);
    163   enum SocketType { kSocketTypeInvalid = -1, kSocketTypeTCP, kSocketTypeUDP };
    164 
    165   scoped_ptr<core_api::socket::Create::Params> params_;
    166   SocketType socket_type_;
    167 };
    168 
    169 class SocketDestroyFunction : public SocketAsyncApiFunction {
    170  public:
    171   DECLARE_EXTENSION_FUNCTION("socket.destroy", SOCKET_DESTROY)
    172 
    173  protected:
    174   virtual ~SocketDestroyFunction() {}
    175 
    176   // AsyncApiFunction:
    177   virtual bool Prepare() OVERRIDE;
    178   virtual void Work() OVERRIDE;
    179 
    180  private:
    181   int socket_id_;
    182 };
    183 
    184 class SocketConnectFunction : public SocketExtensionWithDnsLookupFunction {
    185  public:
    186   DECLARE_EXTENSION_FUNCTION("socket.connect", SOCKET_CONNECT)
    187 
    188   SocketConnectFunction();
    189 
    190  protected:
    191   virtual ~SocketConnectFunction();
    192 
    193   // AsyncApiFunction:
    194   virtual bool Prepare() OVERRIDE;
    195   virtual void AsyncWorkStart() OVERRIDE;
    196 
    197   // SocketExtensionWithDnsLookupFunction:
    198   virtual void AfterDnsLookup(int lookup_result) OVERRIDE;
    199 
    200  private:
    201   void StartConnect();
    202   void OnConnect(int result);
    203 
    204   int socket_id_;
    205   std::string hostname_;
    206   int port_;
    207   Socket* socket_;
    208 };
    209 
    210 class SocketDisconnectFunction : public SocketAsyncApiFunction {
    211  public:
    212   DECLARE_EXTENSION_FUNCTION("socket.disconnect", SOCKET_DISCONNECT)
    213 
    214  protected:
    215   virtual ~SocketDisconnectFunction() {}
    216 
    217   // AsyncApiFunction:
    218   virtual bool Prepare() OVERRIDE;
    219   virtual void Work() OVERRIDE;
    220 
    221  private:
    222   int socket_id_;
    223 };
    224 
    225 class SocketBindFunction : public SocketAsyncApiFunction {
    226  public:
    227   DECLARE_EXTENSION_FUNCTION("socket.bind", SOCKET_BIND)
    228 
    229  protected:
    230   virtual ~SocketBindFunction() {}
    231 
    232   // AsyncApiFunction:
    233   virtual bool Prepare() OVERRIDE;
    234   virtual void Work() OVERRIDE;
    235 
    236  private:
    237   int socket_id_;
    238   std::string address_;
    239   int port_;
    240 };
    241 
    242 class SocketListenFunction : public SocketAsyncApiFunction {
    243  public:
    244   DECLARE_EXTENSION_FUNCTION("socket.listen", SOCKET_LISTEN)
    245 
    246   SocketListenFunction();
    247 
    248  protected:
    249   virtual ~SocketListenFunction();
    250 
    251   // AsyncApiFunction:
    252   virtual bool Prepare() OVERRIDE;
    253   virtual void Work() OVERRIDE;
    254 
    255  private:
    256   scoped_ptr<core_api::socket::Listen::Params> params_;
    257 };
    258 
    259 class SocketAcceptFunction : public SocketAsyncApiFunction {
    260  public:
    261   DECLARE_EXTENSION_FUNCTION("socket.accept", SOCKET_ACCEPT)
    262 
    263   SocketAcceptFunction();
    264 
    265  protected:
    266   virtual ~SocketAcceptFunction();
    267 
    268   // AsyncApiFunction:
    269   virtual bool Prepare() OVERRIDE;
    270   virtual void AsyncWorkStart() OVERRIDE;
    271 
    272  private:
    273   void OnAccept(int result_code, net::TCPClientSocket* socket);
    274 
    275   scoped_ptr<core_api::socket::Accept::Params> params_;
    276 };
    277 
    278 class SocketReadFunction : public SocketAsyncApiFunction {
    279  public:
    280   DECLARE_EXTENSION_FUNCTION("socket.read", SOCKET_READ)
    281 
    282   SocketReadFunction();
    283 
    284  protected:
    285   virtual ~SocketReadFunction();
    286 
    287   // AsyncApiFunction:
    288   virtual bool Prepare() OVERRIDE;
    289   virtual void AsyncWorkStart() OVERRIDE;
    290   void OnCompleted(int result, scoped_refptr<net::IOBuffer> io_buffer);
    291 
    292  private:
    293   scoped_ptr<core_api::socket::Read::Params> params_;
    294 };
    295 
    296 class SocketWriteFunction : public SocketAsyncApiFunction {
    297  public:
    298   DECLARE_EXTENSION_FUNCTION("socket.write", SOCKET_WRITE)
    299 
    300   SocketWriteFunction();
    301 
    302  protected:
    303   virtual ~SocketWriteFunction();
    304 
    305   // AsyncApiFunction:
    306   virtual bool Prepare() OVERRIDE;
    307   virtual void AsyncWorkStart() OVERRIDE;
    308   void OnCompleted(int result);
    309 
    310  private:
    311   int socket_id_;
    312   scoped_refptr<net::IOBuffer> io_buffer_;
    313   size_t io_buffer_size_;
    314 };
    315 
    316 class SocketRecvFromFunction : public SocketAsyncApiFunction {
    317  public:
    318   DECLARE_EXTENSION_FUNCTION("socket.recvFrom", SOCKET_RECVFROM)
    319 
    320   SocketRecvFromFunction();
    321 
    322  protected:
    323   virtual ~SocketRecvFromFunction();
    324 
    325   // AsyncApiFunction
    326   virtual bool Prepare() OVERRIDE;
    327   virtual void AsyncWorkStart() OVERRIDE;
    328   void OnCompleted(int result,
    329                    scoped_refptr<net::IOBuffer> io_buffer,
    330                    const std::string& address,
    331                    int port);
    332 
    333  private:
    334   scoped_ptr<core_api::socket::RecvFrom::Params> params_;
    335 };
    336 
    337 class SocketSendToFunction : public SocketExtensionWithDnsLookupFunction {
    338  public:
    339   DECLARE_EXTENSION_FUNCTION("socket.sendTo", SOCKET_SENDTO)
    340 
    341   SocketSendToFunction();
    342 
    343  protected:
    344   virtual ~SocketSendToFunction();
    345 
    346   // AsyncApiFunction:
    347   virtual bool Prepare() OVERRIDE;
    348   virtual void AsyncWorkStart() OVERRIDE;
    349   void OnCompleted(int result);
    350 
    351   // SocketExtensionWithDnsLookupFunction:
    352   virtual void AfterDnsLookup(int lookup_result) OVERRIDE;
    353 
    354  private:
    355   void StartSendTo();
    356 
    357   int socket_id_;
    358   scoped_refptr<net::IOBuffer> io_buffer_;
    359   size_t io_buffer_size_;
    360   std::string hostname_;
    361   int port_;
    362   Socket* socket_;
    363 };
    364 
    365 class SocketSetKeepAliveFunction : public SocketAsyncApiFunction {
    366  public:
    367   DECLARE_EXTENSION_FUNCTION("socket.setKeepAlive", SOCKET_SETKEEPALIVE)
    368 
    369   SocketSetKeepAliveFunction();
    370 
    371  protected:
    372   virtual ~SocketSetKeepAliveFunction();
    373 
    374   // AsyncApiFunction:
    375   virtual bool Prepare() OVERRIDE;
    376   virtual void Work() OVERRIDE;
    377 
    378  private:
    379   scoped_ptr<core_api::socket::SetKeepAlive::Params> params_;
    380 };
    381 
    382 class SocketSetNoDelayFunction : public SocketAsyncApiFunction {
    383  public:
    384   DECLARE_EXTENSION_FUNCTION("socket.setNoDelay", SOCKET_SETNODELAY)
    385 
    386   SocketSetNoDelayFunction();
    387 
    388  protected:
    389   virtual ~SocketSetNoDelayFunction();
    390 
    391   // AsyncApiFunction:
    392   virtual bool Prepare() OVERRIDE;
    393   virtual void Work() OVERRIDE;
    394 
    395  private:
    396   scoped_ptr<core_api::socket::SetNoDelay::Params> params_;
    397 };
    398 
    399 class SocketGetInfoFunction : public SocketAsyncApiFunction {
    400  public:
    401   DECLARE_EXTENSION_FUNCTION("socket.getInfo", SOCKET_GETINFO)
    402 
    403   SocketGetInfoFunction();
    404 
    405  protected:
    406   virtual ~SocketGetInfoFunction();
    407 
    408   // AsyncApiFunction:
    409   virtual bool Prepare() OVERRIDE;
    410   virtual void Work() OVERRIDE;
    411 
    412  private:
    413   scoped_ptr<core_api::socket::GetInfo::Params> params_;
    414 };
    415 
    416 class SocketGetNetworkListFunction : public AsyncExtensionFunction {
    417  public:
    418   DECLARE_EXTENSION_FUNCTION("socket.getNetworkList", SOCKET_GETNETWORKLIST)
    419 
    420  protected:
    421   virtual ~SocketGetNetworkListFunction() {}
    422   virtual bool RunAsync() OVERRIDE;
    423 
    424  private:
    425   void GetNetworkListOnFileThread();
    426   void HandleGetNetworkListError();
    427   void SendResponseOnUIThread(const net::NetworkInterfaceList& interface_list);
    428 };
    429 
    430 class SocketJoinGroupFunction : public SocketAsyncApiFunction {
    431  public:
    432   DECLARE_EXTENSION_FUNCTION("socket.joinGroup", SOCKET_MULTICAST_JOIN_GROUP)
    433 
    434   SocketJoinGroupFunction();
    435 
    436  protected:
    437   virtual ~SocketJoinGroupFunction();
    438 
    439   // AsyncApiFunction
    440   virtual bool Prepare() OVERRIDE;
    441   virtual void Work() OVERRIDE;
    442 
    443  private:
    444   scoped_ptr<core_api::socket::JoinGroup::Params> params_;
    445 };
    446 
    447 class SocketLeaveGroupFunction : public SocketAsyncApiFunction {
    448  public:
    449   DECLARE_EXTENSION_FUNCTION("socket.leaveGroup", SOCKET_MULTICAST_LEAVE_GROUP)
    450 
    451   SocketLeaveGroupFunction();
    452 
    453  protected:
    454   virtual ~SocketLeaveGroupFunction();
    455 
    456   // AsyncApiFunction
    457   virtual bool Prepare() OVERRIDE;
    458   virtual void Work() OVERRIDE;
    459 
    460  private:
    461   scoped_ptr<core_api::socket::LeaveGroup::Params> params_;
    462 };
    463 
    464 class SocketSetMulticastTimeToLiveFunction : public SocketAsyncApiFunction {
    465  public:
    466   DECLARE_EXTENSION_FUNCTION("socket.setMulticastTimeToLive",
    467                              SOCKET_MULTICAST_SET_TIME_TO_LIVE)
    468 
    469   SocketSetMulticastTimeToLiveFunction();
    470 
    471  protected:
    472   virtual ~SocketSetMulticastTimeToLiveFunction();
    473 
    474   // AsyncApiFunction
    475   virtual bool Prepare() OVERRIDE;
    476   virtual void Work() OVERRIDE;
    477 
    478  private:
    479   scoped_ptr<core_api::socket::SetMulticastTimeToLive::Params> params_;
    480 };
    481 
    482 class SocketSetMulticastLoopbackModeFunction : public SocketAsyncApiFunction {
    483  public:
    484   DECLARE_EXTENSION_FUNCTION("socket.setMulticastLoopbackMode",
    485                              SOCKET_MULTICAST_SET_LOOPBACK_MODE)
    486 
    487   SocketSetMulticastLoopbackModeFunction();
    488 
    489  protected:
    490   virtual ~SocketSetMulticastLoopbackModeFunction();
    491 
    492   // AsyncApiFunction
    493   virtual bool Prepare() OVERRIDE;
    494   virtual void Work() OVERRIDE;
    495 
    496  private:
    497   scoped_ptr<core_api::socket::SetMulticastLoopbackMode::Params> params_;
    498 };
    499 
    500 class SocketGetJoinedGroupsFunction : public SocketAsyncApiFunction {
    501  public:
    502   DECLARE_EXTENSION_FUNCTION("socket.getJoinedGroups",
    503                              SOCKET_MULTICAST_GET_JOINED_GROUPS)
    504 
    505   SocketGetJoinedGroupsFunction();
    506 
    507  protected:
    508   virtual ~SocketGetJoinedGroupsFunction();
    509 
    510   // AsyncApiFunction
    511   virtual bool Prepare() OVERRIDE;
    512   virtual void Work() OVERRIDE;
    513 
    514  private:
    515   scoped_ptr<core_api::socket::GetJoinedGroups::Params> params_;
    516 };
    517 
    518 class SocketSecureFunction : public SocketAsyncApiFunction {
    519  public:
    520   DECLARE_EXTENSION_FUNCTION("socket.secure", SOCKET_SECURE);
    521   SocketSecureFunction();
    522 
    523  protected:
    524   virtual ~SocketSecureFunction();
    525 
    526   // AsyncApiFunction
    527   virtual bool Prepare() OVERRIDE;
    528   virtual void AsyncWorkStart() OVERRIDE;
    529 
    530  private:
    531   // Callback from TLSSocket::UpgradeSocketToTLS().
    532   void TlsConnectDone(scoped_ptr<TLSSocket> socket, int result);
    533 
    534   scoped_ptr<core_api::socket::Secure::Params> params_;
    535   scoped_refptr<net::URLRequestContextGetter> url_request_getter_;
    536 
    537   DISALLOW_COPY_AND_ASSIGN(SocketSecureFunction);
    538 };
    539 
    540 }  // namespace extensions
    541 
    542 #endif  // EXTENSIONS_BROWSER_API_SOCKET_SOCKET_API_H_
    543