Home | History | Annotate | Download | only in dns
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/dns/dns_socket_pool.h"
      6 
      7 #include "base/logging.h"
      8 #include "base/rand_util.h"
      9 #include "base/stl_util.h"
     10 #include "net/base/address_list.h"
     11 #include "net/base/ip_endpoint.h"
     12 #include "net/base/net_errors.h"
     13 #include "net/base/rand_callback.h"
     14 #include "net/socket/client_socket_factory.h"
     15 #include "net/socket/stream_socket.h"
     16 #include "net/udp/datagram_client_socket.h"
     17 
     18 namespace net {
     19 
     20 namespace {
     21 
     22 // When we initialize the SocketPool, we allocate kInitialPoolSize sockets.
     23 // When we allocate a socket, we ensure we have at least kAllocateMinSize
     24 // sockets to choose from.  When we free a socket, we retain it if we have
     25 // less than kRetainMaxSize sockets in the pool.
     26 
     27 // On Windows, we can't request specific (random) ports, since that will
     28 // trigger firewall prompts, so request default ones, but keep a pile of
     29 // them.  Everywhere else, request fresh, random ports each time.
     30 #if defined(OS_WIN)
     31 const DatagramSocket::BindType kBindType = DatagramSocket::DEFAULT_BIND;
     32 const unsigned kInitialPoolSize = 256;
     33 const unsigned kAllocateMinSize = 256;
     34 const unsigned kRetainMaxSize = 0;
     35 #else
     36 const DatagramSocket::BindType kBindType = DatagramSocket::RANDOM_BIND;
     37 const unsigned kInitialPoolSize = 0;
     38 const unsigned kAllocateMinSize = 1;
     39 const unsigned kRetainMaxSize = 0;
     40 #endif
     41 
     42 } // namespace
     43 
     44 DnsSocketPool::DnsSocketPool(ClientSocketFactory* socket_factory)
     45     : socket_factory_(socket_factory),
     46       net_log_(NULL),
     47       nameservers_(NULL),
     48       initialized_(false) {
     49 }
     50 
     51 void DnsSocketPool::InitializeInternal(
     52     const std::vector<IPEndPoint>* nameservers,
     53     NetLog* net_log) {
     54   DCHECK(nameservers);
     55   DCHECK(!initialized_);
     56 
     57   net_log_ = net_log;
     58   nameservers_ = nameservers;
     59   initialized_ = true;
     60 }
     61 
     62 scoped_ptr<StreamSocket> DnsSocketPool::CreateTCPSocket(
     63     unsigned server_index,
     64     const NetLog::Source& source) {
     65   DCHECK_LT(server_index, nameservers_->size());
     66 
     67   return scoped_ptr<StreamSocket>(
     68       socket_factory_->CreateTransportClientSocket(
     69           AddressList((*nameservers_)[server_index]), net_log_, source));
     70 }
     71 
     72 scoped_ptr<DatagramClientSocket> DnsSocketPool::CreateConnectedSocket(
     73     unsigned server_index) {
     74   DCHECK_LT(server_index, nameservers_->size());
     75 
     76   scoped_ptr<DatagramClientSocket> socket;
     77 
     78   NetLog::Source no_source;
     79   socket.reset(socket_factory_->CreateDatagramClientSocket(
     80       kBindType, base::Bind(&base::RandInt), net_log_, no_source));
     81 
     82   if (socket.get()) {
     83     int rv = socket->Connect((*nameservers_)[server_index]);
     84     if (rv != OK) {
     85       LOG(WARNING) << "Failed to connect socket: " << rv;
     86       socket.reset();
     87     }
     88   } else {
     89     LOG(WARNING) << "Failed to create socket.";
     90   }
     91 
     92   return socket.Pass();
     93 }
     94 
     95 class NullDnsSocketPool : public DnsSocketPool {
     96  public:
     97   NullDnsSocketPool(ClientSocketFactory* factory)
     98      : DnsSocketPool(factory) {
     99   }
    100 
    101   virtual void Initialize(
    102       const std::vector<IPEndPoint>* nameservers,
    103       NetLog* net_log) OVERRIDE {
    104     InitializeInternal(nameservers, net_log);
    105   }
    106 
    107   virtual scoped_ptr<DatagramClientSocket> AllocateSocket(
    108       unsigned server_index) OVERRIDE {
    109     return CreateConnectedSocket(server_index);
    110   }
    111 
    112   virtual void FreeSocket(
    113       unsigned server_index,
    114       scoped_ptr<DatagramClientSocket> socket) OVERRIDE {
    115   }
    116 
    117  private:
    118   DISALLOW_COPY_AND_ASSIGN(NullDnsSocketPool);
    119 };
    120 
    121 // static
    122 scoped_ptr<DnsSocketPool> DnsSocketPool::CreateNull(
    123     ClientSocketFactory* factory) {
    124   return scoped_ptr<DnsSocketPool>(new NullDnsSocketPool(factory));
    125 }
    126 
    127 class DefaultDnsSocketPool : public DnsSocketPool {
    128  public:
    129   DefaultDnsSocketPool(ClientSocketFactory* factory)
    130      : DnsSocketPool(factory) {
    131   };
    132 
    133   virtual ~DefaultDnsSocketPool();
    134 
    135   virtual void Initialize(
    136       const std::vector<IPEndPoint>* nameservers,
    137       NetLog* net_log) OVERRIDE;
    138 
    139   virtual scoped_ptr<DatagramClientSocket> AllocateSocket(
    140       unsigned server_index) OVERRIDE;
    141 
    142   virtual void FreeSocket(
    143       unsigned server_index,
    144       scoped_ptr<DatagramClientSocket> socket) OVERRIDE;
    145 
    146  private:
    147   void FillPool(unsigned server_index, unsigned size);
    148 
    149   typedef std::vector<DatagramClientSocket*> SocketVector;
    150 
    151   std::vector<SocketVector> pools_;
    152 
    153   DISALLOW_COPY_AND_ASSIGN(DefaultDnsSocketPool);
    154 };
    155 
    156 // static
    157 scoped_ptr<DnsSocketPool> DnsSocketPool::CreateDefault(
    158     ClientSocketFactory* factory) {
    159   return scoped_ptr<DnsSocketPool>(new DefaultDnsSocketPool(factory));
    160 }
    161 
    162 void DefaultDnsSocketPool::Initialize(
    163     const std::vector<IPEndPoint>* nameservers,
    164     NetLog* net_log) {
    165   InitializeInternal(nameservers, net_log);
    166 
    167   DCHECK(pools_.empty());
    168   const unsigned num_servers = nameservers->size();
    169   pools_.resize(num_servers);
    170   for (unsigned server_index = 0; server_index < num_servers; ++server_index)
    171     FillPool(server_index, kInitialPoolSize);
    172 }
    173 
    174 DefaultDnsSocketPool::~DefaultDnsSocketPool() {
    175   unsigned num_servers = pools_.size();
    176   for (unsigned server_index = 0; server_index < num_servers; ++server_index) {
    177     SocketVector& pool = pools_[server_index];
    178     STLDeleteElements(&pool);
    179   }
    180 }
    181 
    182 scoped_ptr<DatagramClientSocket> DefaultDnsSocketPool::AllocateSocket(
    183     unsigned server_index) {
    184   DCHECK_LT(server_index, pools_.size());
    185   SocketVector& pool = pools_[server_index];
    186 
    187   FillPool(server_index, kAllocateMinSize);
    188   if (pool.size() == 0) {
    189     LOG(WARNING) << "No DNS sockets available in pool " << server_index << "!";
    190     return scoped_ptr<DatagramClientSocket>();
    191   }
    192 
    193   if (pool.size() < kAllocateMinSize) {
    194     LOG(WARNING) << "Low DNS port entropy: wanted " << kAllocateMinSize
    195                  << " sockets to choose from, but only have " << pool.size()
    196                  << " in pool " << server_index << ".";
    197   }
    198 
    199   unsigned socket_index = base::RandInt(0, pool.size() - 1);
    200   DatagramClientSocket* socket = pool[socket_index];
    201   pool[socket_index] = pool.back();
    202   pool.pop_back();
    203 
    204   return scoped_ptr<DatagramClientSocket>(socket);
    205 }
    206 
    207 void DefaultDnsSocketPool::FreeSocket(
    208     unsigned server_index,
    209     scoped_ptr<DatagramClientSocket> socket) {
    210   DCHECK_LT(server_index, pools_.size());
    211 
    212   // In some builds, kRetainMaxSize will be 0 if we never reuse sockets.
    213   // In that case, don't compile this code to avoid a "tautological
    214   // comparison" warning from clang.
    215 #if kRetainMaxSize > 0
    216   SocketVector& pool = pools_[server_index];
    217   if (pool.size() < kRetainMaxSize)
    218     pool.push_back(socket.release());
    219 #endif
    220 }
    221 
    222 void DefaultDnsSocketPool::FillPool(unsigned server_index, unsigned size) {
    223   SocketVector& pool = pools_[server_index];
    224 
    225   for (unsigned pool_index = pool.size(); pool_index < size; ++pool_index) {
    226     DatagramClientSocket* socket =
    227         CreateConnectedSocket(server_index).release();
    228     if (!socket)
    229       break;
    230     pool.push_back(socket);
    231   }
    232 }
    233 
    234 } // namespace net
    235