1 /* 2 * libjingle 3 * Copyright 2004--2005, Google Inc. 4 * 5 * Redistribution and use in source and binary forms, with or without 6 * modification, are permitted provided that the following conditions are met: 7 * 8 * 1. Redistributions of source code must retain the above copyright notice, 9 * this list of conditions and the following disclaimer. 10 * 2. Redistributions in binary form must reproduce the above copyright notice, 11 * this list of conditions and the following disclaimer in the documentation 12 * and/or other materials provided with the distribution. 13 * 3. The name of the author may not be used to endorse or promote products 14 * derived from this software without specific prior written permission. 15 * 16 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED 17 * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF 18 * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO 19 * EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 20 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 22 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 23 * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 24 * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 25 * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 */ 27 28 #include "talk/base/natsocketfactory.h" 29 30 #include "talk/base/logging.h" 31 #include "talk/base/natserver.h" 32 #include "talk/base/virtualsocketserver.h" 33 34 namespace talk_base { 35 36 // Packs the given socketaddress into the buffer in buf, in the quasi-STUN 37 // format that the natserver uses. 38 // Returns 0 if an invalid address is passed. 39 size_t PackAddressForNAT(char* buf, size_t buf_size, 40 const SocketAddress& remote_addr) { 41 const IPAddress& ip = remote_addr.ipaddr(); 42 int family = ip.family(); 43 buf[0] = 0; 44 buf[1] = family; 45 // Writes the port. 46 *(reinterpret_cast<uint16*>(&buf[2])) = HostToNetwork16(remote_addr.port()); 47 if (family == AF_INET) { 48 ASSERT(buf_size >= kNATEncodedIPv4AddressSize); 49 in_addr v4addr = ip.ipv4_address(); 50 std::memcpy(&buf[4], &v4addr, kNATEncodedIPv4AddressSize - 4); 51 return kNATEncodedIPv4AddressSize; 52 } else if (family == AF_INET6) { 53 ASSERT(buf_size >= kNATEncodedIPv6AddressSize); 54 in6_addr v6addr = ip.ipv6_address(); 55 std::memcpy(&buf[4], &v6addr, kNATEncodedIPv6AddressSize - 4); 56 return kNATEncodedIPv6AddressSize; 57 } 58 return 0U; 59 } 60 61 // Decodes the remote address from a packet that has been encoded with the nat's 62 // quasi-STUN format. Returns the length of the address (i.e., the offset into 63 // data where the original packet starts). 64 size_t UnpackAddressFromNAT(const char* buf, size_t buf_size, 65 SocketAddress* remote_addr) { 66 ASSERT(buf_size >= 8); 67 ASSERT(buf[0] == 0); 68 int family = buf[1]; 69 uint16 port = NetworkToHost16(*(reinterpret_cast<const uint16*>(&buf[2]))); 70 if (family == AF_INET) { 71 const in_addr* v4addr = reinterpret_cast<const in_addr*>(&buf[4]); 72 *remote_addr = SocketAddress(IPAddress(*v4addr), port); 73 return kNATEncodedIPv4AddressSize; 74 } else if (family == AF_INET6) { 75 ASSERT(buf_size >= 20); 76 const in6_addr* v6addr = reinterpret_cast<const in6_addr*>(&buf[4]); 77 *remote_addr = SocketAddress(IPAddress(*v6addr), port); 78 return kNATEncodedIPv6AddressSize; 79 } 80 return 0U; 81 } 82 83 84 // NATSocket 85 class NATSocket : public AsyncSocket, public sigslot::has_slots<> { 86 public: 87 explicit NATSocket(NATInternalSocketFactory* sf, int family, int type) 88 : sf_(sf), family_(family), type_(type), connected_(false), 89 socket_(NULL), buf_(NULL), size_(0) { 90 } 91 92 virtual ~NATSocket() { 93 delete socket_; 94 delete[] buf_; 95 } 96 97 virtual SocketAddress GetLocalAddress() const { 98 return (socket_) ? socket_->GetLocalAddress() : SocketAddress(); 99 } 100 101 virtual SocketAddress GetRemoteAddress() const { 102 return remote_addr_; // will be NIL if not connected 103 } 104 105 virtual int Bind(const SocketAddress& addr) { 106 if (socket_) { // already bound, bubble up error 107 return -1; 108 } 109 110 int result; 111 socket_ = sf_->CreateInternalSocket(family_, type_, addr, &server_addr_); 112 result = (socket_) ? socket_->Bind(addr) : -1; 113 if (result >= 0) { 114 socket_->SignalConnectEvent.connect(this, &NATSocket::OnConnectEvent); 115 socket_->SignalReadEvent.connect(this, &NATSocket::OnReadEvent); 116 socket_->SignalWriteEvent.connect(this, &NATSocket::OnWriteEvent); 117 socket_->SignalCloseEvent.connect(this, &NATSocket::OnCloseEvent); 118 } else { 119 server_addr_.Clear(); 120 delete socket_; 121 socket_ = NULL; 122 } 123 124 return result; 125 } 126 127 virtual int Connect(const SocketAddress& addr) { 128 if (!socket_) { // socket must be bound, for now 129 return -1; 130 } 131 132 int result = 0; 133 if (type_ == SOCK_STREAM) { 134 result = socket_->Connect(server_addr_.IsNil() ? addr : server_addr_); 135 } else { 136 connected_ = true; 137 } 138 139 if (result >= 0) { 140 remote_addr_ = addr; 141 } 142 143 return result; 144 } 145 146 virtual int Send(const void* data, size_t size) { 147 ASSERT(connected_); 148 return SendTo(data, size, remote_addr_); 149 } 150 151 virtual int SendTo(const void* data, size_t size, const SocketAddress& addr) { 152 ASSERT(!connected_ || addr == remote_addr_); 153 if (server_addr_.IsNil() || type_ == SOCK_STREAM) { 154 return socket_->SendTo(data, size, addr); 155 } 156 // This array will be too large for IPv4 packets, but only by 12 bytes. 157 scoped_ptr<char[]> buf(new char[size + kNATEncodedIPv6AddressSize]); 158 size_t addrlength = PackAddressForNAT(buf.get(), 159 size + kNATEncodedIPv6AddressSize, 160 addr); 161 size_t encoded_size = size + addrlength; 162 std::memcpy(buf.get() + addrlength, data, size); 163 int result = socket_->SendTo(buf.get(), encoded_size, server_addr_); 164 if (result >= 0) { 165 ASSERT(result == static_cast<int>(encoded_size)); 166 result = result - static_cast<int>(addrlength); 167 } 168 return result; 169 } 170 171 virtual int Recv(void* data, size_t size) { 172 SocketAddress addr; 173 return RecvFrom(data, size, &addr); 174 } 175 176 virtual int RecvFrom(void* data, size_t size, SocketAddress *out_addr) { 177 if (server_addr_.IsNil() || type_ == SOCK_STREAM) { 178 return socket_->RecvFrom(data, size, out_addr); 179 } 180 // Make sure we have enough room to read the requested amount plus the 181 // largest possible header address. 182 SocketAddress remote_addr; 183 Grow(size + kNATEncodedIPv6AddressSize); 184 185 // Read the packet from the socket. 186 int result = socket_->RecvFrom(buf_, size_, &remote_addr); 187 if (result >= 0) { 188 ASSERT(remote_addr == server_addr_); 189 190 // TODO: we need better framing so we know how many bytes we can 191 // return before we need to read the next address. For UDP, this will be 192 // fine as long as the reader always reads everything in the packet. 193 ASSERT((size_t)result < size_); 194 195 // Decode the wire packet into the actual results. 196 SocketAddress real_remote_addr; 197 size_t addrlength = 198 UnpackAddressFromNAT(buf_, result, &real_remote_addr); 199 std::memcpy(data, buf_ + addrlength, result - addrlength); 200 201 // Make sure this packet should be delivered before returning it. 202 if (!connected_ || (real_remote_addr == remote_addr_)) { 203 if (out_addr) 204 *out_addr = real_remote_addr; 205 result = result - static_cast<int>(addrlength); 206 } else { 207 LOG(LS_ERROR) << "Dropping packet from unknown remote address: " 208 << real_remote_addr.ToString(); 209 result = 0; // Tell the caller we didn't read anything 210 } 211 } 212 213 return result; 214 } 215 216 virtual int Close() { 217 int result = 0; 218 if (socket_) { 219 result = socket_->Close(); 220 if (result >= 0) { 221 connected_ = false; 222 remote_addr_ = SocketAddress(); 223 delete socket_; 224 socket_ = NULL; 225 } 226 } 227 return result; 228 } 229 230 virtual int Listen(int backlog) { 231 return socket_->Listen(backlog); 232 } 233 virtual AsyncSocket* Accept(SocketAddress *paddr) { 234 return socket_->Accept(paddr); 235 } 236 virtual int GetError() const { 237 return socket_->GetError(); 238 } 239 virtual void SetError(int error) { 240 socket_->SetError(error); 241 } 242 virtual ConnState GetState() const { 243 return connected_ ? CS_CONNECTED : CS_CLOSED; 244 } 245 virtual int EstimateMTU(uint16* mtu) { 246 return socket_->EstimateMTU(mtu); 247 } 248 virtual int GetOption(Option opt, int* value) { 249 return socket_->GetOption(opt, value); 250 } 251 virtual int SetOption(Option opt, int value) { 252 return socket_->SetOption(opt, value); 253 } 254 255 void OnConnectEvent(AsyncSocket* socket) { 256 // If we're NATed, we need to send a request with the real addr to use. 257 ASSERT(socket == socket_); 258 if (server_addr_.IsNil()) { 259 connected_ = true; 260 SignalConnectEvent(this); 261 } else { 262 SendConnectRequest(); 263 } 264 } 265 void OnReadEvent(AsyncSocket* socket) { 266 // If we're NATed, we need to process the connect reply. 267 ASSERT(socket == socket_); 268 if (type_ == SOCK_STREAM && !server_addr_.IsNil() && !connected_) { 269 HandleConnectReply(); 270 } else { 271 SignalReadEvent(this); 272 } 273 } 274 void OnWriteEvent(AsyncSocket* socket) { 275 ASSERT(socket == socket_); 276 SignalWriteEvent(this); 277 } 278 void OnCloseEvent(AsyncSocket* socket, int error) { 279 ASSERT(socket == socket_); 280 SignalCloseEvent(this, error); 281 } 282 283 private: 284 // Makes sure the buffer is at least the given size. 285 void Grow(size_t new_size) { 286 if (size_ < new_size) { 287 delete[] buf_; 288 size_ = new_size; 289 buf_ = new char[size_]; 290 } 291 } 292 293 // Sends the destination address to the server to tell it to connect. 294 void SendConnectRequest() { 295 char buf[256]; 296 size_t length = PackAddressForNAT(buf, ARRAY_SIZE(buf), remote_addr_); 297 socket_->Send(buf, length); 298 } 299 300 // Handles the byte sent back from the server and fires the appropriate event. 301 void HandleConnectReply() { 302 char code; 303 socket_->Recv(&code, sizeof(code)); 304 if (code == 0) { 305 SignalConnectEvent(this); 306 } else { 307 Close(); 308 SignalCloseEvent(this, code); 309 } 310 } 311 312 NATInternalSocketFactory* sf_; 313 int family_; 314 int type_; 315 bool connected_; 316 SocketAddress remote_addr_; 317 SocketAddress server_addr_; // address of the NAT server 318 AsyncSocket* socket_; 319 char* buf_; 320 size_t size_; 321 }; 322 323 // NATSocketFactory 324 NATSocketFactory::NATSocketFactory(SocketFactory* factory, 325 const SocketAddress& nat_addr) 326 : factory_(factory), nat_addr_(nat_addr) { 327 } 328 329 Socket* NATSocketFactory::CreateSocket(int type) { 330 return CreateSocket(AF_INET, type); 331 } 332 333 Socket* NATSocketFactory::CreateSocket(int family, int type) { 334 return new NATSocket(this, family, type); 335 } 336 337 AsyncSocket* NATSocketFactory::CreateAsyncSocket(int type) { 338 return CreateAsyncSocket(AF_INET, type); 339 } 340 341 AsyncSocket* NATSocketFactory::CreateAsyncSocket(int family, int type) { 342 return new NATSocket(this, family, type); 343 } 344 345 AsyncSocket* NATSocketFactory::CreateInternalSocket(int family, int type, 346 const SocketAddress& local_addr, SocketAddress* nat_addr) { 347 *nat_addr = nat_addr_; 348 return factory_->CreateAsyncSocket(family, type); 349 } 350 351 // NATSocketServer 352 NATSocketServer::NATSocketServer(SocketServer* server) 353 : server_(server), msg_queue_(NULL) { 354 } 355 356 NATSocketServer::Translator* NATSocketServer::GetTranslator( 357 const SocketAddress& ext_ip) { 358 return nats_.Get(ext_ip); 359 } 360 361 NATSocketServer::Translator* NATSocketServer::AddTranslator( 362 const SocketAddress& ext_ip, const SocketAddress& int_ip, NATType type) { 363 // Fail if a translator already exists with this extternal address. 364 if (nats_.Get(ext_ip)) 365 return NULL; 366 367 return nats_.Add(ext_ip, new Translator(this, type, int_ip, server_, ext_ip)); 368 } 369 370 void NATSocketServer::RemoveTranslator( 371 const SocketAddress& ext_ip) { 372 nats_.Remove(ext_ip); 373 } 374 375 Socket* NATSocketServer::CreateSocket(int type) { 376 return CreateSocket(AF_INET, type); 377 } 378 379 Socket* NATSocketServer::CreateSocket(int family, int type) { 380 return new NATSocket(this, family, type); 381 } 382 383 AsyncSocket* NATSocketServer::CreateAsyncSocket(int type) { 384 return CreateAsyncSocket(AF_INET, type); 385 } 386 387 AsyncSocket* NATSocketServer::CreateAsyncSocket(int family, int type) { 388 return new NATSocket(this, family, type); 389 } 390 391 AsyncSocket* NATSocketServer::CreateInternalSocket(int family, int type, 392 const SocketAddress& local_addr, SocketAddress* nat_addr) { 393 AsyncSocket* socket = NULL; 394 Translator* nat = nats_.FindClient(local_addr); 395 if (nat) { 396 socket = nat->internal_factory()->CreateAsyncSocket(family, type); 397 *nat_addr = (type == SOCK_STREAM) ? 398 nat->internal_tcp_address() : nat->internal_address(); 399 } else { 400 socket = server_->CreateAsyncSocket(family, type); 401 } 402 return socket; 403 } 404 405 // NATSocketServer::Translator 406 NATSocketServer::Translator::Translator( 407 NATSocketServer* server, NATType type, const SocketAddress& int_ip, 408 SocketFactory* ext_factory, const SocketAddress& ext_ip) 409 : server_(server) { 410 // Create a new private network, and a NATServer running on the private 411 // network that bridges to the external network. Also tell the private 412 // network to use the same message queue as us. 413 VirtualSocketServer* internal_server = new VirtualSocketServer(server_); 414 internal_server->SetMessageQueue(server_->queue()); 415 internal_factory_.reset(internal_server); 416 nat_server_.reset(new NATServer(type, internal_server, int_ip, 417 ext_factory, ext_ip)); 418 } 419 420 421 NATSocketServer::Translator* NATSocketServer::Translator::GetTranslator( 422 const SocketAddress& ext_ip) { 423 return nats_.Get(ext_ip); 424 } 425 426 NATSocketServer::Translator* NATSocketServer::Translator::AddTranslator( 427 const SocketAddress& ext_ip, const SocketAddress& int_ip, NATType type) { 428 // Fail if a translator already exists with this extternal address. 429 if (nats_.Get(ext_ip)) 430 return NULL; 431 432 AddClient(ext_ip); 433 return nats_.Add(ext_ip, 434 new Translator(server_, type, int_ip, server_, ext_ip)); 435 } 436 void NATSocketServer::Translator::RemoveTranslator( 437 const SocketAddress& ext_ip) { 438 nats_.Remove(ext_ip); 439 RemoveClient(ext_ip); 440 } 441 442 bool NATSocketServer::Translator::AddClient( 443 const SocketAddress& int_ip) { 444 // Fail if a client already exists with this internal address. 445 if (clients_.find(int_ip) != clients_.end()) 446 return false; 447 448 clients_.insert(int_ip); 449 return true; 450 } 451 452 void NATSocketServer::Translator::RemoveClient( 453 const SocketAddress& int_ip) { 454 std::set<SocketAddress>::iterator it = clients_.find(int_ip); 455 if (it != clients_.end()) { 456 clients_.erase(it); 457 } 458 } 459 460 NATSocketServer::Translator* NATSocketServer::Translator::FindClient( 461 const SocketAddress& int_ip) { 462 // See if we have the requested IP, or any of our children do. 463 return (clients_.find(int_ip) != clients_.end()) ? 464 this : nats_.FindClient(int_ip); 465 } 466 467 // NATSocketServer::TranslatorMap 468 NATSocketServer::TranslatorMap::~TranslatorMap() { 469 for (TranslatorMap::iterator it = begin(); it != end(); ++it) { 470 delete it->second; 471 } 472 } 473 474 NATSocketServer::Translator* NATSocketServer::TranslatorMap::Get( 475 const SocketAddress& ext_ip) { 476 TranslatorMap::iterator it = find(ext_ip); 477 return (it != end()) ? it->second : NULL; 478 } 479 480 NATSocketServer::Translator* NATSocketServer::TranslatorMap::Add( 481 const SocketAddress& ext_ip, Translator* nat) { 482 (*this)[ext_ip] = nat; 483 return nat; 484 } 485 486 void NATSocketServer::TranslatorMap::Remove( 487 const SocketAddress& ext_ip) { 488 TranslatorMap::iterator it = find(ext_ip); 489 if (it != end()) { 490 delete it->second; 491 erase(it); 492 } 493 } 494 495 NATSocketServer::Translator* NATSocketServer::TranslatorMap::FindClient( 496 const SocketAddress& int_ip) { 497 Translator* nat = NULL; 498 for (TranslatorMap::iterator it = begin(); it != end() && !nat; ++it) { 499 nat = it->second->FindClient(int_ip); 500 } 501 return nat; 502 } 503 504 } // namespace talk_base 505