1 /* 2 * libjingle 3 * Copyright 2004--2005, Google Inc. 4 * 5 * Redistribution and use in source and binary forms, with or without 6 * modification, are permitted provided that the following conditions are met: 7 * 8 * 1. Redistributions of source code must retain the above copyright notice, 9 * this list of conditions and the following disclaimer. 10 * 2. Redistributions in binary form must reproduce the above copyright notice, 11 * this list of conditions and the following disclaimer in the documentation 12 * and/or other materials provided with the distribution. 13 * 3. The name of the author may not be used to endorse or promote products 14 * derived from this software without specific prior written permission. 15 * 16 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED 17 * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF 18 * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO 19 * EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 20 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 22 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 23 * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 24 * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 25 * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 */ 27 28 #include "talk/base/natsocketfactory.h" 29 30 #include "talk/base/logging.h" 31 #include "talk/base/natserver.h" 32 #include "talk/base/virtualsocketserver.h" 33 34 namespace talk_base { 35 36 // Packs the given socketaddress into the buffer in buf, in the quasi-STUN 37 // format that the natserver uses. 38 // Returns 0 if an invalid address is passed. 39 size_t PackAddressForNAT(char* buf, size_t buf_size, 40 const SocketAddress& remote_addr) { 41 const IPAddress& ip = remote_addr.ipaddr(); 42 int family = ip.family(); 43 buf[0] = 0; 44 buf[1] = family; 45 // Writes the port. 46 *(reinterpret_cast<uint16*>(&buf[2])) = HostToNetwork16(remote_addr.port()); 47 if (family == AF_INET) { 48 ASSERT(buf_size >= kNATEncodedIPv4AddressSize); 49 in_addr v4addr = ip.ipv4_address(); 50 std::memcpy(&buf[4], &v4addr, kNATEncodedIPv4AddressSize - 4); 51 return kNATEncodedIPv4AddressSize; 52 } else if (family == AF_INET6) { 53 ASSERT(buf_size >= kNATEncodedIPv6AddressSize); 54 in6_addr v6addr = ip.ipv6_address(); 55 std::memcpy(&buf[4], &v6addr, kNATEncodedIPv6AddressSize - 4); 56 return kNATEncodedIPv6AddressSize; 57 } 58 return 0U; 59 } 60 61 // Decodes the remote address from a packet that has been encoded with the nat's 62 // quasi-STUN format. Returns the length of the address (i.e., the offset into 63 // data where the original packet starts). 64 size_t UnpackAddressFromNAT(const char* buf, size_t buf_size, 65 SocketAddress* remote_addr) { 66 ASSERT(buf_size >= 8); 67 ASSERT(buf[0] == 0); 68 int family = buf[1]; 69 uint16 port = NetworkToHost16(*(reinterpret_cast<const uint16*>(&buf[2]))); 70 if (family == AF_INET) { 71 const in_addr* v4addr = reinterpret_cast<const in_addr*>(&buf[4]); 72 *remote_addr = SocketAddress(IPAddress(*v4addr), port); 73 return kNATEncodedIPv4AddressSize; 74 } else if (family == AF_INET6) { 75 ASSERT(buf_size >= 20); 76 const in6_addr* v6addr = reinterpret_cast<const in6_addr*>(&buf[4]); 77 *remote_addr = SocketAddress(IPAddress(*v6addr), port); 78 return kNATEncodedIPv6AddressSize; 79 } 80 return 0U; 81 } 82 83 84 // NATSocket 85 class NATSocket : public AsyncSocket, public sigslot::has_slots<> { 86 public: 87 explicit NATSocket(NATInternalSocketFactory* sf, int family, int type) 88 : sf_(sf), family_(family), type_(type), async_(true), connected_(false), 89 socket_(NULL), buf_(NULL), size_(0) { 90 } 91 92 virtual ~NATSocket() { 93 delete socket_; 94 delete[] buf_; 95 } 96 97 virtual SocketAddress GetLocalAddress() const { 98 return (socket_) ? socket_->GetLocalAddress() : SocketAddress(); 99 } 100 101 virtual SocketAddress GetRemoteAddress() const { 102 return remote_addr_; // will be NIL if not connected 103 } 104 105 virtual int Bind(const SocketAddress& addr) { 106 if (socket_) { // already bound, bubble up error 107 return -1; 108 } 109 110 int result; 111 socket_ = sf_->CreateInternalSocket(family_, type_, addr, &server_addr_); 112 result = (socket_) ? socket_->Bind(addr) : -1; 113 if (result >= 0) { 114 socket_->SignalConnectEvent.connect(this, &NATSocket::OnConnectEvent); 115 socket_->SignalReadEvent.connect(this, &NATSocket::OnReadEvent); 116 socket_->SignalWriteEvent.connect(this, &NATSocket::OnWriteEvent); 117 socket_->SignalCloseEvent.connect(this, &NATSocket::OnCloseEvent); 118 } else { 119 server_addr_.Clear(); 120 delete socket_; 121 socket_ = NULL; 122 } 123 124 return result; 125 } 126 127 virtual int Connect(const SocketAddress& addr) { 128 if (!socket_) { // socket must be bound, for now 129 return -1; 130 } 131 132 int result = 0; 133 if (type_ == SOCK_STREAM) { 134 result = socket_->Connect(server_addr_.IsNil() ? addr : server_addr_); 135 } else { 136 connected_ = true; 137 } 138 139 if (result >= 0) { 140 remote_addr_ = addr; 141 } 142 143 return result; 144 } 145 146 virtual int Send(const void* data, size_t size) { 147 ASSERT(connected_); 148 return SendTo(data, size, remote_addr_); 149 } 150 151 virtual int SendTo(const void* data, size_t size, const SocketAddress& addr) { 152 ASSERT(!connected_ || addr == remote_addr_); 153 if (server_addr_.IsNil() || type_ == SOCK_STREAM) { 154 return socket_->SendTo(data, size, addr); 155 } 156 // This array will be too large for IPv4 packets, but only by 12 bytes. 157 scoped_array<char> buf(new char[size + kNATEncodedIPv6AddressSize]); 158 size_t addrlength = PackAddressForNAT(buf.get(), 159 size + kNATEncodedIPv6AddressSize, 160 addr); 161 size_t encoded_size = size + addrlength; 162 std::memcpy(buf.get() + addrlength, data, size); 163 int result = socket_->SendTo(buf.get(), encoded_size, server_addr_); 164 if (result >= 0) { 165 ASSERT(result == static_cast<int>(encoded_size)); 166 result = result - static_cast<int>(addrlength); 167 } 168 return result; 169 } 170 171 virtual int Recv(void* data, size_t size) { 172 SocketAddress addr; 173 return RecvFrom(data, size, &addr); 174 } 175 176 virtual int RecvFrom(void* data, size_t size, SocketAddress *out_addr) { 177 if (server_addr_.IsNil() || type_ == SOCK_STREAM) { 178 return socket_->RecvFrom(data, size, out_addr); 179 } 180 // Make sure we have enough room to read the requested amount plus the 181 // largest possible header address. 182 SocketAddress remote_addr; 183 Grow(size + kNATEncodedIPv6AddressSize); 184 185 // Read the packet from the socket. 186 int result = socket_->RecvFrom(buf_, size_, &remote_addr); 187 if (result >= 0) { 188 ASSERT(remote_addr == server_addr_); 189 190 // TODO: we need better framing so we know how many bytes we can 191 // return before we need to read the next address. For UDP, this will be 192 // fine as long as the reader always reads everything in the packet. 193 ASSERT((size_t)result < size_); 194 195 // Decode the wire packet into the actual results. 196 SocketAddress real_remote_addr; 197 size_t addrlength = 198 UnpackAddressFromNAT(buf_, result, &real_remote_addr); 199 std::memcpy(data, buf_ + addrlength, result - addrlength); 200 201 // Make sure this packet should be delivered before returning it. 202 if (!connected_ || (real_remote_addr == remote_addr_)) { 203 if (out_addr) 204 *out_addr = real_remote_addr; 205 result = result - static_cast<int>(addrlength); 206 } else { 207 LOG(LS_ERROR) << "Dropping packet from unknown remote address: " 208 << real_remote_addr.ToString(); 209 result = 0; // Tell the caller we didn't read anything 210 } 211 } 212 213 return result; 214 } 215 216 virtual int Close() { 217 int result = 0; 218 if (socket_) { 219 result = socket_->Close(); 220 if (result >= 0) { 221 connected_ = false; 222 remote_addr_ = SocketAddress(); 223 delete socket_; 224 socket_ = NULL; 225 } 226 } 227 return result; 228 } 229 230 virtual int Listen(int backlog) { 231 return socket_->Listen(backlog); 232 } 233 virtual AsyncSocket* Accept(SocketAddress *paddr) { 234 return socket_->Accept(paddr); 235 } 236 virtual int GetError() const { 237 return socket_->GetError(); 238 } 239 virtual void SetError(int error) { 240 socket_->SetError(error); 241 } 242 virtual ConnState GetState() const { 243 return connected_ ? CS_CONNECTED : CS_CLOSED; 244 } 245 virtual int EstimateMTU(uint16* mtu) { 246 return socket_->EstimateMTU(mtu); 247 } 248 virtual int GetOption(Option opt, int* value) { 249 return socket_->GetOption(opt, value); 250 } 251 virtual int SetOption(Option opt, int value) { 252 return socket_->SetOption(opt, value); 253 } 254 255 void OnConnectEvent(AsyncSocket* socket) { 256 // If we're NATed, we need to send a request with the real addr to use. 257 ASSERT(socket == socket_); 258 if (server_addr_.IsNil()) { 259 connected_ = true; 260 SignalConnectEvent(this); 261 } else { 262 SendConnectRequest(); 263 } 264 } 265 void OnReadEvent(AsyncSocket* socket) { 266 // If we're NATed, we need to process the connect reply. 267 ASSERT(socket == socket_); 268 if (type_ == SOCK_STREAM && !server_addr_.IsNil() && !connected_) { 269 HandleConnectReply(); 270 } else { 271 SignalReadEvent(this); 272 } 273 } 274 void OnWriteEvent(AsyncSocket* socket) { 275 ASSERT(socket == socket_); 276 SignalWriteEvent(this); 277 } 278 void OnCloseEvent(AsyncSocket* socket, int error) { 279 ASSERT(socket == socket_); 280 SignalCloseEvent(this, error); 281 } 282 283 private: 284 // Makes sure the buffer is at least the given size. 285 void Grow(size_t new_size) { 286 if (size_ < new_size) { 287 delete[] buf_; 288 size_ = new_size; 289 buf_ = new char[size_]; 290 } 291 } 292 293 // Sends the destination address to the server to tell it to connect. 294 void SendConnectRequest() { 295 char buf[256]; 296 size_t length = PackAddressForNAT(buf, ARRAY_SIZE(buf), remote_addr_); 297 socket_->Send(buf, length); 298 } 299 300 // Handles the byte sent back from the server and fires the appropriate event. 301 void HandleConnectReply() { 302 char code; 303 socket_->Recv(&code, sizeof(code)); 304 if (code == 0) { 305 SignalConnectEvent(this); 306 } else { 307 Close(); 308 SignalCloseEvent(this, code); 309 } 310 } 311 312 NATInternalSocketFactory* sf_; 313 int family_; 314 int type_; 315 bool async_; 316 bool connected_; 317 SocketAddress remote_addr_; 318 SocketAddress server_addr_; // address of the NAT server 319 AsyncSocket* socket_; 320 char* buf_; 321 size_t size_; 322 }; 323 324 // NATSocketFactory 325 NATSocketFactory::NATSocketFactory(SocketFactory* factory, 326 const SocketAddress& nat_addr) 327 : factory_(factory), nat_addr_(nat_addr) { 328 } 329 330 Socket* NATSocketFactory::CreateSocket(int type) { 331 return CreateSocket(AF_INET, type); 332 } 333 334 Socket* NATSocketFactory::CreateSocket(int family, int type) { 335 return new NATSocket(this, family, type); 336 } 337 338 AsyncSocket* NATSocketFactory::CreateAsyncSocket(int type) { 339 return CreateAsyncSocket(AF_INET, type); 340 } 341 342 AsyncSocket* NATSocketFactory::CreateAsyncSocket(int family, int type) { 343 return new NATSocket(this, family, type); 344 } 345 346 AsyncSocket* NATSocketFactory::CreateInternalSocket(int family, int type, 347 const SocketAddress& local_addr, SocketAddress* nat_addr) { 348 *nat_addr = nat_addr_; 349 return factory_->CreateAsyncSocket(family, type); 350 } 351 352 // NATSocketServer 353 NATSocketServer::NATSocketServer(SocketServer* server) 354 : server_(server), msg_queue_(NULL) { 355 } 356 357 NATSocketServer::Translator* NATSocketServer::GetTranslator( 358 const SocketAddress& ext_ip) { 359 return nats_.Get(ext_ip); 360 } 361 362 NATSocketServer::Translator* NATSocketServer::AddTranslator( 363 const SocketAddress& ext_ip, const SocketAddress& int_ip, NATType type) { 364 // Fail if a translator already exists with this extternal address. 365 if (nats_.Get(ext_ip)) 366 return NULL; 367 368 return nats_.Add(ext_ip, new Translator(this, type, int_ip, server_, ext_ip)); 369 } 370 371 void NATSocketServer::RemoveTranslator( 372 const SocketAddress& ext_ip) { 373 nats_.Remove(ext_ip); 374 } 375 376 Socket* NATSocketServer::CreateSocket(int type) { 377 return CreateSocket(AF_INET, type); 378 } 379 380 Socket* NATSocketServer::CreateSocket(int family, int type) { 381 return new NATSocket(this, family, type); 382 } 383 384 AsyncSocket* NATSocketServer::CreateAsyncSocket(int type) { 385 return CreateAsyncSocket(AF_INET, type); 386 } 387 388 AsyncSocket* NATSocketServer::CreateAsyncSocket(int family, int type) { 389 return new NATSocket(this, family, type); 390 } 391 392 AsyncSocket* NATSocketServer::CreateInternalSocket(int family, int type, 393 const SocketAddress& local_addr, SocketAddress* nat_addr) { 394 AsyncSocket* socket = NULL; 395 Translator* nat = nats_.FindClient(local_addr); 396 if (nat) { 397 socket = nat->internal_factory()->CreateAsyncSocket(family, type); 398 *nat_addr = (type == SOCK_STREAM) ? 399 nat->internal_tcp_address() : nat->internal_address(); 400 } else { 401 socket = server_->CreateAsyncSocket(family, type); 402 } 403 return socket; 404 } 405 406 // NATSocketServer::Translator 407 NATSocketServer::Translator::Translator( 408 NATSocketServer* server, NATType type, const SocketAddress& int_ip, 409 SocketFactory* ext_factory, const SocketAddress& ext_ip) 410 : server_(server) { 411 // Create a new private network, and a NATServer running on the private 412 // network that bridges to the external network. Also tell the private 413 // network to use the same message queue as us. 414 VirtualSocketServer* internal_server = new VirtualSocketServer(server_); 415 internal_server->SetMessageQueue(server_->queue()); 416 internal_factory_.reset(internal_server); 417 nat_server_.reset(new NATServer(type, internal_server, int_ip, 418 ext_factory, ext_ip)); 419 } 420 421 422 NATSocketServer::Translator* NATSocketServer::Translator::GetTranslator( 423 const SocketAddress& ext_ip) { 424 return nats_.Get(ext_ip); 425 } 426 427 NATSocketServer::Translator* NATSocketServer::Translator::AddTranslator( 428 const SocketAddress& ext_ip, const SocketAddress& int_ip, NATType type) { 429 // Fail if a translator already exists with this extternal address. 430 if (nats_.Get(ext_ip)) 431 return NULL; 432 433 AddClient(ext_ip); 434 return nats_.Add(ext_ip, 435 new Translator(server_, type, int_ip, server_, ext_ip)); 436 } 437 void NATSocketServer::Translator::RemoveTranslator( 438 const SocketAddress& ext_ip) { 439 nats_.Remove(ext_ip); 440 RemoveClient(ext_ip); 441 } 442 443 bool NATSocketServer::Translator::AddClient( 444 const SocketAddress& int_ip) { 445 // Fail if a client already exists with this internal address. 446 if (clients_.find(int_ip) != clients_.end()) 447 return false; 448 449 clients_.insert(int_ip); 450 return true; 451 } 452 453 void NATSocketServer::Translator::RemoveClient( 454 const SocketAddress& int_ip) { 455 std::set<SocketAddress>::iterator it = clients_.find(int_ip); 456 if (it != clients_.end()) { 457 clients_.erase(it); 458 } 459 } 460 461 NATSocketServer::Translator* NATSocketServer::Translator::FindClient( 462 const SocketAddress& int_ip) { 463 // See if we have the requested IP, or any of our children do. 464 return (clients_.find(int_ip) != clients_.end()) ? 465 this : nats_.FindClient(int_ip); 466 } 467 468 // NATSocketServer::TranslatorMap 469 NATSocketServer::TranslatorMap::~TranslatorMap() { 470 for (TranslatorMap::iterator it = begin(); it != end(); ++it) { 471 delete it->second; 472 } 473 } 474 475 NATSocketServer::Translator* NATSocketServer::TranslatorMap::Get( 476 const SocketAddress& ext_ip) { 477 TranslatorMap::iterator it = find(ext_ip); 478 return (it != end()) ? it->second : NULL; 479 } 480 481 NATSocketServer::Translator* NATSocketServer::TranslatorMap::Add( 482 const SocketAddress& ext_ip, Translator* nat) { 483 (*this)[ext_ip] = nat; 484 return nat; 485 } 486 487 void NATSocketServer::TranslatorMap::Remove( 488 const SocketAddress& ext_ip) { 489 TranslatorMap::iterator it = find(ext_ip); 490 if (it != end()) { 491 delete it->second; 492 erase(it); 493 } 494 } 495 496 NATSocketServer::Translator* NATSocketServer::TranslatorMap::FindClient( 497 const SocketAddress& int_ip) { 498 Translator* nat = NULL; 499 for (TranslatorMap::iterator it = begin(); it != end() && !nat; ++it) { 500 nat = it->second->FindClient(int_ip); 501 } 502 return nat; 503 } 504 505 } // namespace talk_base 506