Home | History | Annotate | Download | only in dns
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/dns/dns_socket_pool.h"
      6 
      7 #include "base/logging.h"
      8 #include "base/rand_util.h"
      9 #include "base/stl_util.h"
     10 #include "net/base/address_list.h"
     11 #include "net/base/ip_endpoint.h"
     12 #include "net/base/net_errors.h"
     13 #include "net/base/rand_callback.h"
     14 #include "net/socket/client_socket_factory.h"
     15 #include "net/socket/stream_socket.h"
     16 #include "net/udp/datagram_client_socket.h"
     17 
     18 namespace net {
     19 
     20 namespace {
     21 
     22 // When we initialize the SocketPool, we allocate kInitialPoolSize sockets.
     23 // When we allocate a socket, we ensure we have at least kAllocateMinSize
     24 // sockets to choose from.  Freed sockets are not retained.
     25 
     26 // On Windows, we can't request specific (random) ports, since that will
     27 // trigger firewall prompts, so request default ones, but keep a pile of
     28 // them.  Everywhere else, request fresh, random ports each time.
     29 #if defined(OS_WIN)
     30 const DatagramSocket::BindType kBindType = DatagramSocket::DEFAULT_BIND;
     31 const unsigned kInitialPoolSize = 256;
     32 const unsigned kAllocateMinSize = 256;
     33 #else
     34 const DatagramSocket::BindType kBindType = DatagramSocket::RANDOM_BIND;
     35 const unsigned kInitialPoolSize = 0;
     36 const unsigned kAllocateMinSize = 1;
     37 #endif
     38 
     39 } // namespace
     40 
     41 DnsSocketPool::DnsSocketPool(ClientSocketFactory* socket_factory)
     42     : socket_factory_(socket_factory),
     43       net_log_(NULL),
     44       nameservers_(NULL),
     45       initialized_(false) {
     46 }
     47 
     48 void DnsSocketPool::InitializeInternal(
     49     const std::vector<IPEndPoint>* nameservers,
     50     NetLog* net_log) {
     51   DCHECK(nameservers);
     52   DCHECK(!initialized_);
     53 
     54   net_log_ = net_log;
     55   nameservers_ = nameservers;
     56   initialized_ = true;
     57 }
     58 
     59 scoped_ptr<StreamSocket> DnsSocketPool::CreateTCPSocket(
     60     unsigned server_index,
     61     const NetLog::Source& source) {
     62   DCHECK_LT(server_index, nameservers_->size());
     63 
     64   return scoped_ptr<StreamSocket>(
     65       socket_factory_->CreateTransportClientSocket(
     66           AddressList((*nameservers_)[server_index]), net_log_, source));
     67 }
     68 
     69 scoped_ptr<DatagramClientSocket> DnsSocketPool::CreateConnectedSocket(
     70     unsigned server_index) {
     71   DCHECK_LT(server_index, nameservers_->size());
     72 
     73   scoped_ptr<DatagramClientSocket> socket;
     74 
     75   NetLog::Source no_source;
     76   socket = socket_factory_->CreateDatagramClientSocket(
     77       kBindType, base::Bind(&base::RandInt), net_log_, no_source);
     78 
     79   if (socket.get()) {
     80     int rv = socket->Connect((*nameservers_)[server_index]);
     81     if (rv != OK) {
     82       VLOG(1) << "Failed to connect socket: " << rv;
     83       socket.reset();
     84     }
     85   } else {
     86     LOG(WARNING) << "Failed to create socket.";
     87   }
     88 
     89   return socket.Pass();
     90 }
     91 
     92 class NullDnsSocketPool : public DnsSocketPool {
     93  public:
     94   NullDnsSocketPool(ClientSocketFactory* factory)
     95      : DnsSocketPool(factory) {
     96   }
     97 
     98   virtual void Initialize(
     99       const std::vector<IPEndPoint>* nameservers,
    100       NetLog* net_log) OVERRIDE {
    101     InitializeInternal(nameservers, net_log);
    102   }
    103 
    104   virtual scoped_ptr<DatagramClientSocket> AllocateSocket(
    105       unsigned server_index) OVERRIDE {
    106     return CreateConnectedSocket(server_index);
    107   }
    108 
    109   virtual void FreeSocket(
    110       unsigned server_index,
    111       scoped_ptr<DatagramClientSocket> socket) OVERRIDE {
    112   }
    113 
    114  private:
    115   DISALLOW_COPY_AND_ASSIGN(NullDnsSocketPool);
    116 };
    117 
    118 // static
    119 scoped_ptr<DnsSocketPool> DnsSocketPool::CreateNull(
    120     ClientSocketFactory* factory) {
    121   return scoped_ptr<DnsSocketPool>(new NullDnsSocketPool(factory));
    122 }
    123 
    124 class DefaultDnsSocketPool : public DnsSocketPool {
    125  public:
    126   DefaultDnsSocketPool(ClientSocketFactory* factory)
    127      : DnsSocketPool(factory) {
    128   };
    129 
    130   virtual ~DefaultDnsSocketPool();
    131 
    132   virtual void Initialize(
    133       const std::vector<IPEndPoint>* nameservers,
    134       NetLog* net_log) OVERRIDE;
    135 
    136   virtual scoped_ptr<DatagramClientSocket> AllocateSocket(
    137       unsigned server_index) OVERRIDE;
    138 
    139   virtual void FreeSocket(
    140       unsigned server_index,
    141       scoped_ptr<DatagramClientSocket> socket) OVERRIDE;
    142 
    143  private:
    144   void FillPool(unsigned server_index, unsigned size);
    145 
    146   typedef std::vector<DatagramClientSocket*> SocketVector;
    147 
    148   std::vector<SocketVector> pools_;
    149 
    150   DISALLOW_COPY_AND_ASSIGN(DefaultDnsSocketPool);
    151 };
    152 
    153 // static
    154 scoped_ptr<DnsSocketPool> DnsSocketPool::CreateDefault(
    155     ClientSocketFactory* factory) {
    156   return scoped_ptr<DnsSocketPool>(new DefaultDnsSocketPool(factory));
    157 }
    158 
    159 void DefaultDnsSocketPool::Initialize(
    160     const std::vector<IPEndPoint>* nameservers,
    161     NetLog* net_log) {
    162   InitializeInternal(nameservers, net_log);
    163 
    164   DCHECK(pools_.empty());
    165   const unsigned num_servers = nameservers->size();
    166   pools_.resize(num_servers);
    167   for (unsigned server_index = 0; server_index < num_servers; ++server_index)
    168     FillPool(server_index, kInitialPoolSize);
    169 }
    170 
    171 DefaultDnsSocketPool::~DefaultDnsSocketPool() {
    172   unsigned num_servers = pools_.size();
    173   for (unsigned server_index = 0; server_index < num_servers; ++server_index) {
    174     SocketVector& pool = pools_[server_index];
    175     STLDeleteElements(&pool);
    176   }
    177 }
    178 
    179 scoped_ptr<DatagramClientSocket> DefaultDnsSocketPool::AllocateSocket(
    180     unsigned server_index) {
    181   DCHECK_LT(server_index, pools_.size());
    182   SocketVector& pool = pools_[server_index];
    183 
    184   FillPool(server_index, kAllocateMinSize);
    185   if (pool.size() == 0) {
    186     LOG(WARNING) << "No DNS sockets available in pool " << server_index << "!";
    187     return scoped_ptr<DatagramClientSocket>();
    188   }
    189 
    190   if (pool.size() < kAllocateMinSize) {
    191     LOG(WARNING) << "Low DNS port entropy: wanted " << kAllocateMinSize
    192                  << " sockets to choose from, but only have " << pool.size()
    193                  << " in pool " << server_index << ".";
    194   }
    195 
    196   unsigned socket_index = base::RandInt(0, pool.size() - 1);
    197   DatagramClientSocket* socket = pool[socket_index];
    198   pool[socket_index] = pool.back();
    199   pool.pop_back();
    200 
    201   return scoped_ptr<DatagramClientSocket>(socket);
    202 }
    203 
    204 void DefaultDnsSocketPool::FreeSocket(
    205     unsigned server_index,
    206     scoped_ptr<DatagramClientSocket> socket) {
    207   DCHECK_LT(server_index, pools_.size());
    208 }
    209 
    210 void DefaultDnsSocketPool::FillPool(unsigned server_index, unsigned size) {
    211   SocketVector& pool = pools_[server_index];
    212 
    213   for (unsigned pool_index = pool.size(); pool_index < size; ++pool_index) {
    214     DatagramClientSocket* socket =
    215         CreateConnectedSocket(server_index).release();
    216     if (!socket)
    217       break;
    218     pool.push_back(socket);
    219   }
    220 }
    221 
    222 } // namespace net
    223