1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "net/dns/dns_socket_pool.h" 6 7 #include "base/logging.h" 8 #include "base/rand_util.h" 9 #include "base/stl_util.h" 10 #include "net/base/address_list.h" 11 #include "net/base/ip_endpoint.h" 12 #include "net/base/net_errors.h" 13 #include "net/base/rand_callback.h" 14 #include "net/socket/client_socket_factory.h" 15 #include "net/socket/stream_socket.h" 16 #include "net/udp/datagram_client_socket.h" 17 18 namespace net { 19 20 namespace { 21 22 // When we initialize the SocketPool, we allocate kInitialPoolSize sockets. 23 // When we allocate a socket, we ensure we have at least kAllocateMinSize 24 // sockets to choose from. Freed sockets are not retained. 25 26 // On Windows, we can't request specific (random) ports, since that will 27 // trigger firewall prompts, so request default ones, but keep a pile of 28 // them. Everywhere else, request fresh, random ports each time. 29 #if defined(OS_WIN) 30 const DatagramSocket::BindType kBindType = DatagramSocket::DEFAULT_BIND; 31 const unsigned kInitialPoolSize = 256; 32 const unsigned kAllocateMinSize = 256; 33 #else 34 const DatagramSocket::BindType kBindType = DatagramSocket::RANDOM_BIND; 35 const unsigned kInitialPoolSize = 0; 36 const unsigned kAllocateMinSize = 1; 37 #endif 38 39 } // namespace 40 41 DnsSocketPool::DnsSocketPool(ClientSocketFactory* socket_factory) 42 : socket_factory_(socket_factory), 43 net_log_(NULL), 44 nameservers_(NULL), 45 initialized_(false) { 46 } 47 48 void DnsSocketPool::InitializeInternal( 49 const std::vector<IPEndPoint>* nameservers, 50 NetLog* net_log) { 51 DCHECK(nameservers); 52 DCHECK(!initialized_); 53 54 net_log_ = net_log; 55 nameservers_ = nameservers; 56 initialized_ = true; 57 } 58 59 scoped_ptr<StreamSocket> DnsSocketPool::CreateTCPSocket( 60 unsigned server_index, 61 const NetLog::Source& source) { 62 DCHECK_LT(server_index, nameservers_->size()); 63 64 return scoped_ptr<StreamSocket>( 65 socket_factory_->CreateTransportClientSocket( 66 AddressList((*nameservers_)[server_index]), net_log_, source)); 67 } 68 69 scoped_ptr<DatagramClientSocket> DnsSocketPool::CreateConnectedSocket( 70 unsigned server_index) { 71 DCHECK_LT(server_index, nameservers_->size()); 72 73 scoped_ptr<DatagramClientSocket> socket; 74 75 NetLog::Source no_source; 76 socket = socket_factory_->CreateDatagramClientSocket( 77 kBindType, base::Bind(&base::RandInt), net_log_, no_source); 78 79 if (socket.get()) { 80 int rv = socket->Connect((*nameservers_)[server_index]); 81 if (rv != OK) { 82 VLOG(1) << "Failed to connect socket: " << rv; 83 socket.reset(); 84 } 85 } else { 86 LOG(WARNING) << "Failed to create socket."; 87 } 88 89 return socket.Pass(); 90 } 91 92 class NullDnsSocketPool : public DnsSocketPool { 93 public: 94 NullDnsSocketPool(ClientSocketFactory* factory) 95 : DnsSocketPool(factory) { 96 } 97 98 virtual void Initialize( 99 const std::vector<IPEndPoint>* nameservers, 100 NetLog* net_log) OVERRIDE { 101 InitializeInternal(nameservers, net_log); 102 } 103 104 virtual scoped_ptr<DatagramClientSocket> AllocateSocket( 105 unsigned server_index) OVERRIDE { 106 return CreateConnectedSocket(server_index); 107 } 108 109 virtual void FreeSocket( 110 unsigned server_index, 111 scoped_ptr<DatagramClientSocket> socket) OVERRIDE { 112 } 113 114 private: 115 DISALLOW_COPY_AND_ASSIGN(NullDnsSocketPool); 116 }; 117 118 // static 119 scoped_ptr<DnsSocketPool> DnsSocketPool::CreateNull( 120 ClientSocketFactory* factory) { 121 return scoped_ptr<DnsSocketPool>(new NullDnsSocketPool(factory)); 122 } 123 124 class DefaultDnsSocketPool : public DnsSocketPool { 125 public: 126 DefaultDnsSocketPool(ClientSocketFactory* factory) 127 : DnsSocketPool(factory) { 128 }; 129 130 virtual ~DefaultDnsSocketPool(); 131 132 virtual void Initialize( 133 const std::vector<IPEndPoint>* nameservers, 134 NetLog* net_log) OVERRIDE; 135 136 virtual scoped_ptr<DatagramClientSocket> AllocateSocket( 137 unsigned server_index) OVERRIDE; 138 139 virtual void FreeSocket( 140 unsigned server_index, 141 scoped_ptr<DatagramClientSocket> socket) OVERRIDE; 142 143 private: 144 void FillPool(unsigned server_index, unsigned size); 145 146 typedef std::vector<DatagramClientSocket*> SocketVector; 147 148 std::vector<SocketVector> pools_; 149 150 DISALLOW_COPY_AND_ASSIGN(DefaultDnsSocketPool); 151 }; 152 153 // static 154 scoped_ptr<DnsSocketPool> DnsSocketPool::CreateDefault( 155 ClientSocketFactory* factory) { 156 return scoped_ptr<DnsSocketPool>(new DefaultDnsSocketPool(factory)); 157 } 158 159 void DefaultDnsSocketPool::Initialize( 160 const std::vector<IPEndPoint>* nameservers, 161 NetLog* net_log) { 162 InitializeInternal(nameservers, net_log); 163 164 DCHECK(pools_.empty()); 165 const unsigned num_servers = nameservers->size(); 166 pools_.resize(num_servers); 167 for (unsigned server_index = 0; server_index < num_servers; ++server_index) 168 FillPool(server_index, kInitialPoolSize); 169 } 170 171 DefaultDnsSocketPool::~DefaultDnsSocketPool() { 172 unsigned num_servers = pools_.size(); 173 for (unsigned server_index = 0; server_index < num_servers; ++server_index) { 174 SocketVector& pool = pools_[server_index]; 175 STLDeleteElements(&pool); 176 } 177 } 178 179 scoped_ptr<DatagramClientSocket> DefaultDnsSocketPool::AllocateSocket( 180 unsigned server_index) { 181 DCHECK_LT(server_index, pools_.size()); 182 SocketVector& pool = pools_[server_index]; 183 184 FillPool(server_index, kAllocateMinSize); 185 if (pool.size() == 0) { 186 LOG(WARNING) << "No DNS sockets available in pool " << server_index << "!"; 187 return scoped_ptr<DatagramClientSocket>(); 188 } 189 190 if (pool.size() < kAllocateMinSize) { 191 LOG(WARNING) << "Low DNS port entropy: wanted " << kAllocateMinSize 192 << " sockets to choose from, but only have " << pool.size() 193 << " in pool " << server_index << "."; 194 } 195 196 unsigned socket_index = base::RandInt(0, pool.size() - 1); 197 DatagramClientSocket* socket = pool[socket_index]; 198 pool[socket_index] = pool.back(); 199 pool.pop_back(); 200 201 return scoped_ptr<DatagramClientSocket>(socket); 202 } 203 204 void DefaultDnsSocketPool::FreeSocket( 205 unsigned server_index, 206 scoped_ptr<DatagramClientSocket> socket) { 207 DCHECK_LT(server_index, pools_.size()); 208 } 209 210 void DefaultDnsSocketPool::FillPool(unsigned server_index, unsigned size) { 211 SocketVector& pool = pools_[server_index]; 212 213 for (unsigned pool_index = pool.size(); pool_index < size; ++pool_index) { 214 DatagramClientSocket* socket = 215 CreateConnectedSocket(server_index).release(); 216 if (!socket) 217 break; 218 pool.push_back(socket); 219 } 220 } 221 222 } // namespace net 223