Home | History | Annotate | Download | only in base
      1 /*
      2  *  Copyright 2004 The WebRTC Project Authors. All rights reserved.
      3  *
      4  *  Use of this source code is governed by a BSD-style license
      5  *  that can be found in the LICENSE file in the root of the source
      6  *  tree. An additional intellectual property rights grant can be found
      7  *  in the file PATENTS.  All contributing project authors may
      8  *  be found in the AUTHORS file in the root of the source tree.
      9  */
     10 
     11 #include "webrtc/base/firewallsocketserver.h"
     12 
     13 #include <assert.h>
     14 
     15 #include <algorithm>
     16 
     17 #include "webrtc/base/asyncsocket.h"
     18 #include "webrtc/base/logging.h"
     19 
     20 namespace rtc {
     21 
     22 class FirewallSocket : public AsyncSocketAdapter {
     23  public:
     24   FirewallSocket(FirewallSocketServer* server, AsyncSocket* socket, int type)
     25     : AsyncSocketAdapter(socket), server_(server), type_(type) {
     26   }
     27 
     28   virtual int Connect(const SocketAddress& addr) {
     29     if (type_ == SOCK_STREAM) {
     30       if (!server_->Check(FP_TCP, GetLocalAddress(), addr)) {
     31         LOG(LS_VERBOSE) << "FirewallSocket outbound TCP connection from "
     32                         << GetLocalAddress().ToSensitiveString() << " to "
     33                         << addr.ToSensitiveString() << " denied";
     34         // TODO: Handle this asynchronously.
     35         SetError(EHOSTUNREACH);
     36         return SOCKET_ERROR;
     37       }
     38     }
     39     return AsyncSocketAdapter::Connect(addr);
     40   }
     41   virtual int Send(const void* pv, size_t cb) {
     42     return SendTo(pv, cb, GetRemoteAddress());
     43   }
     44   virtual int SendTo(const void* pv, size_t cb, const SocketAddress& addr) {
     45     if (type_ == SOCK_DGRAM) {
     46       if (!server_->Check(FP_UDP, GetLocalAddress(), addr)) {
     47         LOG(LS_VERBOSE) << "FirewallSocket outbound UDP packet from "
     48                         << GetLocalAddress().ToSensitiveString() << " to "
     49                         << addr.ToSensitiveString() << " dropped";
     50         return static_cast<int>(cb);
     51       }
     52     }
     53     return AsyncSocketAdapter::SendTo(pv, cb, addr);
     54   }
     55   virtual int Recv(void* pv, size_t cb) {
     56     SocketAddress addr;
     57     return RecvFrom(pv, cb, &addr);
     58   }
     59   virtual int RecvFrom(void* pv, size_t cb, SocketAddress* paddr) {
     60     if (type_ == SOCK_DGRAM) {
     61       while (true) {
     62         int res = AsyncSocketAdapter::RecvFrom(pv, cb, paddr);
     63         if (res <= 0)
     64           return res;
     65         if (server_->Check(FP_UDP, *paddr, GetLocalAddress()))
     66           return res;
     67         LOG(LS_VERBOSE) << "FirewallSocket inbound UDP packet from "
     68                         << paddr->ToSensitiveString() << " to "
     69                         << GetLocalAddress().ToSensitiveString() << " dropped";
     70       }
     71     }
     72     return AsyncSocketAdapter::RecvFrom(pv, cb, paddr);
     73   }
     74 
     75   virtual int Listen(int backlog) {
     76     if (!server_->tcp_listen_enabled()) {
     77       LOG(LS_VERBOSE) << "FirewallSocket listen attempt denied";
     78       return -1;
     79     }
     80 
     81     return AsyncSocketAdapter::Listen(backlog);
     82   }
     83   virtual AsyncSocket* Accept(SocketAddress* paddr) {
     84     SocketAddress addr;
     85     while (AsyncSocket* sock = AsyncSocketAdapter::Accept(&addr)) {
     86       if (server_->Check(FP_TCP, addr, GetLocalAddress())) {
     87         if (paddr)
     88           *paddr = addr;
     89         return sock;
     90       }
     91       sock->Close();
     92       delete sock;
     93       LOG(LS_VERBOSE) << "FirewallSocket inbound TCP connection from "
     94                       << addr.ToSensitiveString() << " to "
     95                       << GetLocalAddress().ToSensitiveString() << " denied";
     96     }
     97     return 0;
     98   }
     99 
    100  private:
    101   FirewallSocketServer* server_;
    102   int type_;
    103 };
    104 
    105 FirewallSocketServer::FirewallSocketServer(SocketServer* server,
    106                                            FirewallManager* manager,
    107                                            bool should_delete_server)
    108     : server_(server), manager_(manager),
    109       should_delete_server_(should_delete_server),
    110       udp_sockets_enabled_(true), tcp_sockets_enabled_(true),
    111       tcp_listen_enabled_(true) {
    112   if (manager_)
    113     manager_->AddServer(this);
    114 }
    115 
    116 FirewallSocketServer::~FirewallSocketServer() {
    117   if (manager_)
    118     manager_->RemoveServer(this);
    119 
    120   if (server_ && should_delete_server_) {
    121     delete server_;
    122     server_ = NULL;
    123   }
    124 }
    125 
    126 void FirewallSocketServer::AddRule(bool allow, FirewallProtocol p,
    127                                    FirewallDirection d,
    128                                    const SocketAddress& addr) {
    129   SocketAddress src, dst;
    130   if (d == FD_IN) {
    131     dst = addr;
    132   } else {
    133     src = addr;
    134   }
    135   AddRule(allow, p, src, dst);
    136 }
    137 
    138 
    139 void FirewallSocketServer::AddRule(bool allow, FirewallProtocol p,
    140                                    const SocketAddress& src,
    141                                    const SocketAddress& dst) {
    142   Rule r;
    143   r.allow = allow;
    144   r.p = p;
    145   r.src = src;
    146   r.dst = dst;
    147   CritScope scope(&crit_);
    148   rules_.push_back(r);
    149 }
    150 
    151 void FirewallSocketServer::ClearRules() {
    152   CritScope scope(&crit_);
    153   rules_.clear();
    154 }
    155 
    156 bool FirewallSocketServer::Check(FirewallProtocol p,
    157                                  const SocketAddress& src,
    158                                  const SocketAddress& dst) {
    159   CritScope scope(&crit_);
    160   for (size_t i = 0; i < rules_.size(); ++i) {
    161     const Rule& r = rules_[i];
    162     if ((r.p != p) && (r.p != FP_ANY))
    163       continue;
    164     if ((r.src.ipaddr() != src.ipaddr()) && !r.src.IsNil())
    165       continue;
    166     if ((r.src.port() != src.port()) && (r.src.port() != 0))
    167       continue;
    168     if ((r.dst.ipaddr() != dst.ipaddr()) && !r.dst.IsNil())
    169       continue;
    170     if ((r.dst.port() != dst.port()) && (r.dst.port() != 0))
    171       continue;
    172     return r.allow;
    173   }
    174   return true;
    175 }
    176 
    177 Socket* FirewallSocketServer::CreateSocket(int type) {
    178   return CreateSocket(AF_INET, type);
    179 }
    180 
    181 Socket* FirewallSocketServer::CreateSocket(int family, int type) {
    182   return WrapSocket(server_->CreateAsyncSocket(family, type), type);
    183 }
    184 
    185 AsyncSocket* FirewallSocketServer::CreateAsyncSocket(int type) {
    186   return CreateAsyncSocket(AF_INET, type);
    187 }
    188 
    189 AsyncSocket* FirewallSocketServer::CreateAsyncSocket(int family, int type) {
    190   return WrapSocket(server_->CreateAsyncSocket(family, type), type);
    191 }
    192 
    193 AsyncSocket* FirewallSocketServer::WrapSocket(AsyncSocket* sock, int type) {
    194   if (!sock ||
    195       (type == SOCK_STREAM && !tcp_sockets_enabled_) ||
    196       (type == SOCK_DGRAM && !udp_sockets_enabled_)) {
    197     LOG(LS_VERBOSE) << "FirewallSocketServer socket creation denied";
    198     delete sock;
    199     return NULL;
    200   }
    201   return new FirewallSocket(this, sock, type);
    202 }
    203 
    204 FirewallManager::FirewallManager() {
    205 }
    206 
    207 FirewallManager::~FirewallManager() {
    208   assert(servers_.empty());
    209 }
    210 
    211 void FirewallManager::AddServer(FirewallSocketServer* server) {
    212   CritScope scope(&crit_);
    213   servers_.push_back(server);
    214 }
    215 
    216 void FirewallManager::RemoveServer(FirewallSocketServer* server) {
    217   CritScope scope(&crit_);
    218   servers_.erase(std::remove(servers_.begin(), servers_.end(), server),
    219                  servers_.end());
    220 }
    221 
    222 void FirewallManager::AddRule(bool allow, FirewallProtocol p,
    223                               FirewallDirection d, const SocketAddress& addr) {
    224   CritScope scope(&crit_);
    225   for (std::vector<FirewallSocketServer*>::const_iterator it =
    226       servers_.begin(); it != servers_.end(); ++it) {
    227     (*it)->AddRule(allow, p, d, addr);
    228   }
    229 }
    230 
    231 void FirewallManager::ClearRules() {
    232   CritScope scope(&crit_);
    233   for (std::vector<FirewallSocketServer*>::const_iterator it =
    234       servers_.begin(); it != servers_.end(); ++it) {
    235     (*it)->ClearRules();
    236   }
    237 }
    238 
    239 }  // namespace rtc
    240