1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "net/dns/dns_socket_pool.h" 6 7 #include "base/logging.h" 8 #include "base/rand_util.h" 9 #include "base/stl_util.h" 10 #include "net/base/address_list.h" 11 #include "net/base/ip_endpoint.h" 12 #include "net/base/net_errors.h" 13 #include "net/base/rand_callback.h" 14 #include "net/socket/client_socket_factory.h" 15 #include "net/socket/stream_socket.h" 16 #include "net/udp/datagram_client_socket.h" 17 18 namespace net { 19 20 namespace { 21 22 // When we initialize the SocketPool, we allocate kInitialPoolSize sockets. 23 // When we allocate a socket, we ensure we have at least kAllocateMinSize 24 // sockets to choose from. When we free a socket, we retain it if we have 25 // less than kRetainMaxSize sockets in the pool. 26 27 // On Windows, we can't request specific (random) ports, since that will 28 // trigger firewall prompts, so request default ones, but keep a pile of 29 // them. Everywhere else, request fresh, random ports each time. 30 #if defined(OS_WIN) 31 const DatagramSocket::BindType kBindType = DatagramSocket::DEFAULT_BIND; 32 const unsigned kInitialPoolSize = 256; 33 const unsigned kAllocateMinSize = 256; 34 const unsigned kRetainMaxSize = 0; 35 #else 36 const DatagramSocket::BindType kBindType = DatagramSocket::RANDOM_BIND; 37 const unsigned kInitialPoolSize = 0; 38 const unsigned kAllocateMinSize = 1; 39 const unsigned kRetainMaxSize = 0; 40 #endif 41 42 } // namespace 43 44 DnsSocketPool::DnsSocketPool(ClientSocketFactory* socket_factory) 45 : socket_factory_(socket_factory), 46 net_log_(NULL), 47 nameservers_(NULL), 48 initialized_(false) { 49 } 50 51 void DnsSocketPool::InitializeInternal( 52 const std::vector<IPEndPoint>* nameservers, 53 NetLog* net_log) { 54 DCHECK(nameservers); 55 DCHECK(!initialized_); 56 57 net_log_ = net_log; 58 nameservers_ = nameservers; 59 initialized_ = true; 60 } 61 62 scoped_ptr<StreamSocket> DnsSocketPool::CreateTCPSocket( 63 unsigned server_index, 64 const NetLog::Source& source) { 65 DCHECK_LT(server_index, nameservers_->size()); 66 67 return scoped_ptr<StreamSocket>( 68 socket_factory_->CreateTransportClientSocket( 69 AddressList((*nameservers_)[server_index]), net_log_, source)); 70 } 71 72 scoped_ptr<DatagramClientSocket> DnsSocketPool::CreateConnectedSocket( 73 unsigned server_index) { 74 DCHECK_LT(server_index, nameservers_->size()); 75 76 scoped_ptr<DatagramClientSocket> socket; 77 78 NetLog::Source no_source; 79 socket.reset(socket_factory_->CreateDatagramClientSocket( 80 kBindType, base::Bind(&base::RandInt), net_log_, no_source)); 81 82 if (socket.get()) { 83 int rv = socket->Connect((*nameservers_)[server_index]); 84 if (rv != OK) { 85 LOG(WARNING) << "Failed to connect socket: " << rv; 86 socket.reset(); 87 } 88 } else { 89 LOG(WARNING) << "Failed to create socket."; 90 } 91 92 return socket.Pass(); 93 } 94 95 class NullDnsSocketPool : public DnsSocketPool { 96 public: 97 NullDnsSocketPool(ClientSocketFactory* factory) 98 : DnsSocketPool(factory) { 99 } 100 101 virtual void Initialize( 102 const std::vector<IPEndPoint>* nameservers, 103 NetLog* net_log) OVERRIDE { 104 InitializeInternal(nameservers, net_log); 105 } 106 107 virtual scoped_ptr<DatagramClientSocket> AllocateSocket( 108 unsigned server_index) OVERRIDE { 109 return CreateConnectedSocket(server_index); 110 } 111 112 virtual void FreeSocket( 113 unsigned server_index, 114 scoped_ptr<DatagramClientSocket> socket) OVERRIDE { 115 } 116 117 private: 118 DISALLOW_COPY_AND_ASSIGN(NullDnsSocketPool); 119 }; 120 121 // static 122 scoped_ptr<DnsSocketPool> DnsSocketPool::CreateNull( 123 ClientSocketFactory* factory) { 124 return scoped_ptr<DnsSocketPool>(new NullDnsSocketPool(factory)); 125 } 126 127 class DefaultDnsSocketPool : public DnsSocketPool { 128 public: 129 DefaultDnsSocketPool(ClientSocketFactory* factory) 130 : DnsSocketPool(factory) { 131 }; 132 133 virtual ~DefaultDnsSocketPool(); 134 135 virtual void Initialize( 136 const std::vector<IPEndPoint>* nameservers, 137 NetLog* net_log) OVERRIDE; 138 139 virtual scoped_ptr<DatagramClientSocket> AllocateSocket( 140 unsigned server_index) OVERRIDE; 141 142 virtual void FreeSocket( 143 unsigned server_index, 144 scoped_ptr<DatagramClientSocket> socket) OVERRIDE; 145 146 private: 147 void FillPool(unsigned server_index, unsigned size); 148 149 typedef std::vector<DatagramClientSocket*> SocketVector; 150 151 std::vector<SocketVector> pools_; 152 153 DISALLOW_COPY_AND_ASSIGN(DefaultDnsSocketPool); 154 }; 155 156 // static 157 scoped_ptr<DnsSocketPool> DnsSocketPool::CreateDefault( 158 ClientSocketFactory* factory) { 159 return scoped_ptr<DnsSocketPool>(new DefaultDnsSocketPool(factory)); 160 } 161 162 void DefaultDnsSocketPool::Initialize( 163 const std::vector<IPEndPoint>* nameservers, 164 NetLog* net_log) { 165 InitializeInternal(nameservers, net_log); 166 167 DCHECK(pools_.empty()); 168 const unsigned num_servers = nameservers->size(); 169 pools_.resize(num_servers); 170 for (unsigned server_index = 0; server_index < num_servers; ++server_index) 171 FillPool(server_index, kInitialPoolSize); 172 } 173 174 DefaultDnsSocketPool::~DefaultDnsSocketPool() { 175 unsigned num_servers = pools_.size(); 176 for (unsigned server_index = 0; server_index < num_servers; ++server_index) { 177 SocketVector& pool = pools_[server_index]; 178 STLDeleteElements(&pool); 179 } 180 } 181 182 scoped_ptr<DatagramClientSocket> DefaultDnsSocketPool::AllocateSocket( 183 unsigned server_index) { 184 DCHECK_LT(server_index, pools_.size()); 185 SocketVector& pool = pools_[server_index]; 186 187 FillPool(server_index, kAllocateMinSize); 188 if (pool.size() == 0) { 189 LOG(WARNING) << "No DNS sockets available in pool " << server_index << "!"; 190 return scoped_ptr<DatagramClientSocket>(); 191 } 192 193 if (pool.size() < kAllocateMinSize) { 194 LOG(WARNING) << "Low DNS port entropy: wanted " << kAllocateMinSize 195 << " sockets to choose from, but only have " << pool.size() 196 << " in pool " << server_index << "."; 197 } 198 199 unsigned socket_index = base::RandInt(0, pool.size() - 1); 200 DatagramClientSocket* socket = pool[socket_index]; 201 pool[socket_index] = pool.back(); 202 pool.pop_back(); 203 204 return scoped_ptr<DatagramClientSocket>(socket); 205 } 206 207 void DefaultDnsSocketPool::FreeSocket( 208 unsigned server_index, 209 scoped_ptr<DatagramClientSocket> socket) { 210 DCHECK_LT(server_index, pools_.size()); 211 212 // In some builds, kRetainMaxSize will be 0 if we never reuse sockets. 213 // In that case, don't compile this code to avoid a "tautological 214 // comparison" warning from clang. 215 #if kRetainMaxSize > 0 216 SocketVector& pool = pools_[server_index]; 217 if (pool.size() < kRetainMaxSize) 218 pool.push_back(socket.release()); 219 #endif 220 } 221 222 void DefaultDnsSocketPool::FillPool(unsigned server_index, unsigned size) { 223 SocketVector& pool = pools_[server_index]; 224 225 for (unsigned pool_index = pool.size(); pool_index < size; ++pool_index) { 226 DatagramClientSocket* socket = 227 CreateConnectedSocket(server_index).release(); 228 if (!socket) 229 break; 230 pool.push_back(socket); 231 } 232 } 233 234 } // namespace net 235