1 /* 2 * Copyright 2004 The WebRTC Project Authors. All rights reserved. 3 * 4 * Use of this source code is governed by a BSD-style license 5 * that can be found in the LICENSE file in the root of the source 6 * tree. An additional intellectual property rights grant can be found 7 * in the file PATENTS. All contributing project authors may 8 * be found in the AUTHORS file in the root of the source tree. 9 */ 10 11 #include "webrtc/base/natsocketfactory.h" 12 13 #include "webrtc/base/logging.h" 14 #include "webrtc/base/natserver.h" 15 #include "webrtc/base/virtualsocketserver.h" 16 17 namespace rtc { 18 19 // Packs the given socketaddress into the buffer in buf, in the quasi-STUN 20 // format that the natserver uses. 21 // Returns 0 if an invalid address is passed. 22 size_t PackAddressForNAT(char* buf, size_t buf_size, 23 const SocketAddress& remote_addr) { 24 const IPAddress& ip = remote_addr.ipaddr(); 25 int family = ip.family(); 26 buf[0] = 0; 27 buf[1] = family; 28 // Writes the port. 29 *(reinterpret_cast<uint16*>(&buf[2])) = HostToNetwork16(remote_addr.port()); 30 if (family == AF_INET) { 31 ASSERT(buf_size >= kNATEncodedIPv4AddressSize); 32 in_addr v4addr = ip.ipv4_address(); 33 memcpy(&buf[4], &v4addr, kNATEncodedIPv4AddressSize - 4); 34 return kNATEncodedIPv4AddressSize; 35 } else if (family == AF_INET6) { 36 ASSERT(buf_size >= kNATEncodedIPv6AddressSize); 37 in6_addr v6addr = ip.ipv6_address(); 38 memcpy(&buf[4], &v6addr, kNATEncodedIPv6AddressSize - 4); 39 return kNATEncodedIPv6AddressSize; 40 } 41 return 0U; 42 } 43 44 // Decodes the remote address from a packet that has been encoded with the nat's 45 // quasi-STUN format. Returns the length of the address (i.e., the offset into 46 // data where the original packet starts). 47 size_t UnpackAddressFromNAT(const char* buf, size_t buf_size, 48 SocketAddress* remote_addr) { 49 ASSERT(buf_size >= 8); 50 ASSERT(buf[0] == 0); 51 int family = buf[1]; 52 uint16 port = NetworkToHost16(*(reinterpret_cast<const uint16*>(&buf[2]))); 53 if (family == AF_INET) { 54 const in_addr* v4addr = reinterpret_cast<const in_addr*>(&buf[4]); 55 *remote_addr = SocketAddress(IPAddress(*v4addr), port); 56 return kNATEncodedIPv4AddressSize; 57 } else if (family == AF_INET6) { 58 ASSERT(buf_size >= 20); 59 const in6_addr* v6addr = reinterpret_cast<const in6_addr*>(&buf[4]); 60 *remote_addr = SocketAddress(IPAddress(*v6addr), port); 61 return kNATEncodedIPv6AddressSize; 62 } 63 return 0U; 64 } 65 66 67 // NATSocket 68 class NATSocket : public AsyncSocket, public sigslot::has_slots<> { 69 public: 70 explicit NATSocket(NATInternalSocketFactory* sf, int family, int type) 71 : sf_(sf), family_(family), type_(type), connected_(false), 72 socket_(NULL), buf_(NULL), size_(0) { 73 } 74 75 virtual ~NATSocket() { 76 delete socket_; 77 delete[] buf_; 78 } 79 80 virtual SocketAddress GetLocalAddress() const { 81 return (socket_) ? socket_->GetLocalAddress() : SocketAddress(); 82 } 83 84 virtual SocketAddress GetRemoteAddress() const { 85 return remote_addr_; // will be NIL if not connected 86 } 87 88 virtual int Bind(const SocketAddress& addr) { 89 if (socket_) { // already bound, bubble up error 90 return -1; 91 } 92 93 int result; 94 socket_ = sf_->CreateInternalSocket(family_, type_, addr, &server_addr_); 95 result = (socket_) ? socket_->Bind(addr) : -1; 96 if (result >= 0) { 97 socket_->SignalConnectEvent.connect(this, &NATSocket::OnConnectEvent); 98 socket_->SignalReadEvent.connect(this, &NATSocket::OnReadEvent); 99 socket_->SignalWriteEvent.connect(this, &NATSocket::OnWriteEvent); 100 socket_->SignalCloseEvent.connect(this, &NATSocket::OnCloseEvent); 101 } else { 102 server_addr_.Clear(); 103 delete socket_; 104 socket_ = NULL; 105 } 106 107 return result; 108 } 109 110 virtual int Connect(const SocketAddress& addr) { 111 if (!socket_) { // socket must be bound, for now 112 return -1; 113 } 114 115 int result = 0; 116 if (type_ == SOCK_STREAM) { 117 result = socket_->Connect(server_addr_.IsNil() ? addr : server_addr_); 118 } else { 119 connected_ = true; 120 } 121 122 if (result >= 0) { 123 remote_addr_ = addr; 124 } 125 126 return result; 127 } 128 129 virtual int Send(const void* data, size_t size) { 130 ASSERT(connected_); 131 return SendTo(data, size, remote_addr_); 132 } 133 134 virtual int SendTo(const void* data, size_t size, const SocketAddress& addr) { 135 ASSERT(!connected_ || addr == remote_addr_); 136 if (server_addr_.IsNil() || type_ == SOCK_STREAM) { 137 return socket_->SendTo(data, size, addr); 138 } 139 // This array will be too large for IPv4 packets, but only by 12 bytes. 140 scoped_ptr<char[]> buf(new char[size + kNATEncodedIPv6AddressSize]); 141 size_t addrlength = PackAddressForNAT(buf.get(), 142 size + kNATEncodedIPv6AddressSize, 143 addr); 144 size_t encoded_size = size + addrlength; 145 memcpy(buf.get() + addrlength, data, size); 146 int result = socket_->SendTo(buf.get(), encoded_size, server_addr_); 147 if (result >= 0) { 148 ASSERT(result == static_cast<int>(encoded_size)); 149 result = result - static_cast<int>(addrlength); 150 } 151 return result; 152 } 153 154 virtual int Recv(void* data, size_t size) { 155 SocketAddress addr; 156 return RecvFrom(data, size, &addr); 157 } 158 159 virtual int RecvFrom(void* data, size_t size, SocketAddress *out_addr) { 160 if (server_addr_.IsNil() || type_ == SOCK_STREAM) { 161 return socket_->RecvFrom(data, size, out_addr); 162 } 163 // Make sure we have enough room to read the requested amount plus the 164 // largest possible header address. 165 SocketAddress remote_addr; 166 Grow(size + kNATEncodedIPv6AddressSize); 167 168 // Read the packet from the socket. 169 int result = socket_->RecvFrom(buf_, size_, &remote_addr); 170 if (result >= 0) { 171 ASSERT(remote_addr == server_addr_); 172 173 // TODO: we need better framing so we know how many bytes we can 174 // return before we need to read the next address. For UDP, this will be 175 // fine as long as the reader always reads everything in the packet. 176 ASSERT((size_t)result < size_); 177 178 // Decode the wire packet into the actual results. 179 SocketAddress real_remote_addr; 180 size_t addrlength = 181 UnpackAddressFromNAT(buf_, result, &real_remote_addr); 182 memcpy(data, buf_ + addrlength, result - addrlength); 183 184 // Make sure this packet should be delivered before returning it. 185 if (!connected_ || (real_remote_addr == remote_addr_)) { 186 if (out_addr) 187 *out_addr = real_remote_addr; 188 result = result - static_cast<int>(addrlength); 189 } else { 190 LOG(LS_ERROR) << "Dropping packet from unknown remote address: " 191 << real_remote_addr.ToString(); 192 result = 0; // Tell the caller we didn't read anything 193 } 194 } 195 196 return result; 197 } 198 199 virtual int Close() { 200 int result = 0; 201 if (socket_) { 202 result = socket_->Close(); 203 if (result >= 0) { 204 connected_ = false; 205 remote_addr_ = SocketAddress(); 206 delete socket_; 207 socket_ = NULL; 208 } 209 } 210 return result; 211 } 212 213 virtual int Listen(int backlog) { 214 return socket_->Listen(backlog); 215 } 216 virtual AsyncSocket* Accept(SocketAddress *paddr) { 217 return socket_->Accept(paddr); 218 } 219 virtual int GetError() const { 220 return socket_->GetError(); 221 } 222 virtual void SetError(int error) { 223 socket_->SetError(error); 224 } 225 virtual ConnState GetState() const { 226 return connected_ ? CS_CONNECTED : CS_CLOSED; 227 } 228 virtual int EstimateMTU(uint16* mtu) { 229 return socket_->EstimateMTU(mtu); 230 } 231 virtual int GetOption(Option opt, int* value) { 232 return socket_->GetOption(opt, value); 233 } 234 virtual int SetOption(Option opt, int value) { 235 return socket_->SetOption(opt, value); 236 } 237 238 void OnConnectEvent(AsyncSocket* socket) { 239 // If we're NATed, we need to send a request with the real addr to use. 240 ASSERT(socket == socket_); 241 if (server_addr_.IsNil()) { 242 connected_ = true; 243 SignalConnectEvent(this); 244 } else { 245 SendConnectRequest(); 246 } 247 } 248 void OnReadEvent(AsyncSocket* socket) { 249 // If we're NATed, we need to process the connect reply. 250 ASSERT(socket == socket_); 251 if (type_ == SOCK_STREAM && !server_addr_.IsNil() && !connected_) { 252 HandleConnectReply(); 253 } else { 254 SignalReadEvent(this); 255 } 256 } 257 void OnWriteEvent(AsyncSocket* socket) { 258 ASSERT(socket == socket_); 259 SignalWriteEvent(this); 260 } 261 void OnCloseEvent(AsyncSocket* socket, int error) { 262 ASSERT(socket == socket_); 263 SignalCloseEvent(this, error); 264 } 265 266 private: 267 // Makes sure the buffer is at least the given size. 268 void Grow(size_t new_size) { 269 if (size_ < new_size) { 270 delete[] buf_; 271 size_ = new_size; 272 buf_ = new char[size_]; 273 } 274 } 275 276 // Sends the destination address to the server to tell it to connect. 277 void SendConnectRequest() { 278 char buf[256]; 279 size_t length = PackAddressForNAT(buf, ARRAY_SIZE(buf), remote_addr_); 280 socket_->Send(buf, length); 281 } 282 283 // Handles the byte sent back from the server and fires the appropriate event. 284 void HandleConnectReply() { 285 char code; 286 socket_->Recv(&code, sizeof(code)); 287 if (code == 0) { 288 SignalConnectEvent(this); 289 } else { 290 Close(); 291 SignalCloseEvent(this, code); 292 } 293 } 294 295 NATInternalSocketFactory* sf_; 296 int family_; 297 int type_; 298 bool connected_; 299 SocketAddress remote_addr_; 300 SocketAddress server_addr_; // address of the NAT server 301 AsyncSocket* socket_; 302 char* buf_; 303 size_t size_; 304 }; 305 306 // NATSocketFactory 307 NATSocketFactory::NATSocketFactory(SocketFactory* factory, 308 const SocketAddress& nat_addr) 309 : factory_(factory), nat_addr_(nat_addr) { 310 } 311 312 Socket* NATSocketFactory::CreateSocket(int type) { 313 return CreateSocket(AF_INET, type); 314 } 315 316 Socket* NATSocketFactory::CreateSocket(int family, int type) { 317 return new NATSocket(this, family, type); 318 } 319 320 AsyncSocket* NATSocketFactory::CreateAsyncSocket(int type) { 321 return CreateAsyncSocket(AF_INET, type); 322 } 323 324 AsyncSocket* NATSocketFactory::CreateAsyncSocket(int family, int type) { 325 return new NATSocket(this, family, type); 326 } 327 328 AsyncSocket* NATSocketFactory::CreateInternalSocket(int family, int type, 329 const SocketAddress& local_addr, SocketAddress* nat_addr) { 330 *nat_addr = nat_addr_; 331 return factory_->CreateAsyncSocket(family, type); 332 } 333 334 // NATSocketServer 335 NATSocketServer::NATSocketServer(SocketServer* server) 336 : server_(server), msg_queue_(NULL) { 337 } 338 339 NATSocketServer::Translator* NATSocketServer::GetTranslator( 340 const SocketAddress& ext_ip) { 341 return nats_.Get(ext_ip); 342 } 343 344 NATSocketServer::Translator* NATSocketServer::AddTranslator( 345 const SocketAddress& ext_ip, const SocketAddress& int_ip, NATType type) { 346 // Fail if a translator already exists with this extternal address. 347 if (nats_.Get(ext_ip)) 348 return NULL; 349 350 return nats_.Add(ext_ip, new Translator(this, type, int_ip, server_, ext_ip)); 351 } 352 353 void NATSocketServer::RemoveTranslator( 354 const SocketAddress& ext_ip) { 355 nats_.Remove(ext_ip); 356 } 357 358 Socket* NATSocketServer::CreateSocket(int type) { 359 return CreateSocket(AF_INET, type); 360 } 361 362 Socket* NATSocketServer::CreateSocket(int family, int type) { 363 return new NATSocket(this, family, type); 364 } 365 366 AsyncSocket* NATSocketServer::CreateAsyncSocket(int type) { 367 return CreateAsyncSocket(AF_INET, type); 368 } 369 370 AsyncSocket* NATSocketServer::CreateAsyncSocket(int family, int type) { 371 return new NATSocket(this, family, type); 372 } 373 374 AsyncSocket* NATSocketServer::CreateInternalSocket(int family, int type, 375 const SocketAddress& local_addr, SocketAddress* nat_addr) { 376 AsyncSocket* socket = NULL; 377 Translator* nat = nats_.FindClient(local_addr); 378 if (nat) { 379 socket = nat->internal_factory()->CreateAsyncSocket(family, type); 380 *nat_addr = (type == SOCK_STREAM) ? 381 nat->internal_tcp_address() : nat->internal_address(); 382 } else { 383 socket = server_->CreateAsyncSocket(family, type); 384 } 385 return socket; 386 } 387 388 // NATSocketServer::Translator 389 NATSocketServer::Translator::Translator( 390 NATSocketServer* server, NATType type, const SocketAddress& int_ip, 391 SocketFactory* ext_factory, const SocketAddress& ext_ip) 392 : server_(server) { 393 // Create a new private network, and a NATServer running on the private 394 // network that bridges to the external network. Also tell the private 395 // network to use the same message queue as us. 396 VirtualSocketServer* internal_server = new VirtualSocketServer(server_); 397 internal_server->SetMessageQueue(server_->queue()); 398 internal_factory_.reset(internal_server); 399 nat_server_.reset(new NATServer(type, internal_server, int_ip, 400 ext_factory, ext_ip)); 401 } 402 403 404 NATSocketServer::Translator* NATSocketServer::Translator::GetTranslator( 405 const SocketAddress& ext_ip) { 406 return nats_.Get(ext_ip); 407 } 408 409 NATSocketServer::Translator* NATSocketServer::Translator::AddTranslator( 410 const SocketAddress& ext_ip, const SocketAddress& int_ip, NATType type) { 411 // Fail if a translator already exists with this extternal address. 412 if (nats_.Get(ext_ip)) 413 return NULL; 414 415 AddClient(ext_ip); 416 return nats_.Add(ext_ip, 417 new Translator(server_, type, int_ip, server_, ext_ip)); 418 } 419 void NATSocketServer::Translator::RemoveTranslator( 420 const SocketAddress& ext_ip) { 421 nats_.Remove(ext_ip); 422 RemoveClient(ext_ip); 423 } 424 425 bool NATSocketServer::Translator::AddClient( 426 const SocketAddress& int_ip) { 427 // Fail if a client already exists with this internal address. 428 if (clients_.find(int_ip) != clients_.end()) 429 return false; 430 431 clients_.insert(int_ip); 432 return true; 433 } 434 435 void NATSocketServer::Translator::RemoveClient( 436 const SocketAddress& int_ip) { 437 std::set<SocketAddress>::iterator it = clients_.find(int_ip); 438 if (it != clients_.end()) { 439 clients_.erase(it); 440 } 441 } 442 443 NATSocketServer::Translator* NATSocketServer::Translator::FindClient( 444 const SocketAddress& int_ip) { 445 // See if we have the requested IP, or any of our children do. 446 return (clients_.find(int_ip) != clients_.end()) ? 447 this : nats_.FindClient(int_ip); 448 } 449 450 // NATSocketServer::TranslatorMap 451 NATSocketServer::TranslatorMap::~TranslatorMap() { 452 for (TranslatorMap::iterator it = begin(); it != end(); ++it) { 453 delete it->second; 454 } 455 } 456 457 NATSocketServer::Translator* NATSocketServer::TranslatorMap::Get( 458 const SocketAddress& ext_ip) { 459 TranslatorMap::iterator it = find(ext_ip); 460 return (it != end()) ? it->second : NULL; 461 } 462 463 NATSocketServer::Translator* NATSocketServer::TranslatorMap::Add( 464 const SocketAddress& ext_ip, Translator* nat) { 465 (*this)[ext_ip] = nat; 466 return nat; 467 } 468 469 void NATSocketServer::TranslatorMap::Remove( 470 const SocketAddress& ext_ip) { 471 TranslatorMap::iterator it = find(ext_ip); 472 if (it != end()) { 473 delete it->second; 474 erase(it); 475 } 476 } 477 478 NATSocketServer::Translator* NATSocketServer::TranslatorMap::FindClient( 479 const SocketAddress& int_ip) { 480 Translator* nat = NULL; 481 for (TranslatorMap::iterator it = begin(); it != end() && !nat; ++it) { 482 nat = it->second->FindClient(int_ip); 483 } 484 return nat; 485 } 486 487 } // namespace rtc 488