Home | History | Annotate | Download | only in base
      1 /*
      2  * libjingle
      3  * Copyright 2012, Google Inc.
      4  *
      5  * Redistribution and use in source and binary forms, with or without
      6  * modification, are permitted provided that the following conditions are met:
      7  *
      8  *  1. Redistributions of source code must retain the above copyright notice,
      9  *     this list of conditions and the following disclaimer.
     10  *  2. Redistributions in binary form must reproduce the above copyright notice,
     11  *     this list of conditions and the following disclaimer in the documentation
     12  *     and/or other materials provided with the distribution.
     13  *  3. The name of the author may not be used to endorse or promote products
     14  *     derived from this software without specific prior written permission.
     15  *
     16  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED
     17  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
     18  * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
     19  * EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
     20  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
     21  * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
     22  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
     23  * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
     24  * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
     25  * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
     26  */
     27 
     28 #include "talk/p2p/base/turnserver.h"
     29 
     30 #include "talk/base/bytebuffer.h"
     31 #include "talk/base/helpers.h"
     32 #include "talk/base/logging.h"
     33 #include "talk/base/messagedigest.h"
     34 #include "talk/base/socketadapters.h"
     35 #include "talk/base/stringencode.h"
     36 #include "talk/base/thread.h"
     37 #include "talk/p2p/base/asyncstuntcpsocket.h"
     38 #include "talk/p2p/base/common.h"
     39 #include "talk/p2p/base/packetsocketfactory.h"
     40 #include "talk/p2p/base/stun.h"
     41 
     42 namespace cricket {
     43 
     44 // TODO(juberti): Move this all to a future turnmessage.h
     45 //static const int IPPROTO_UDP = 17;
     46 static const int kNonceTimeout = 60 * 60 * 1000;              // 60 minutes
     47 static const int kDefaultAllocationTimeout = 10 * 60 * 1000;  // 10 minutes
     48 static const int kPermissionTimeout = 5 * 60 * 1000;          //  5 minutes
     49 static const int kChannelTimeout = 10 * 60 * 1000;            // 10 minutes
     50 
     51 static const int kMinChannelNumber = 0x4000;
     52 static const int kMaxChannelNumber = 0x7FFF;
     53 
     54 static const size_t kNonceKeySize = 16;
     55 static const size_t kNonceSize = 40;
     56 
     57 static const size_t TURN_CHANNEL_HEADER_SIZE = 4U;
     58 
     59 // TODO(mallinath) - Move these to a common place.
     60 static const size_t kMaxPacketSize = 64 * 1024;
     61 
     62 inline bool IsTurnChannelData(uint16 msg_type) {
     63   // The first two bits of a channel data message are 0b01.
     64   return ((msg_type & 0xC000) == 0x4000);
     65 }
     66 
     67 // IDs used for posted messages.
     68 enum {
     69   MSG_TIMEOUT,
     70 };
     71 
     72 // Encapsulates a TURN allocation.
     73 // The object is created when an allocation request is received, and then
     74 // handles TURN messages (via HandleTurnMessage) and channel data messages
     75 // (via HandleChannelData) for this allocation when received by the server.
     76 // The object self-deletes and informs the server if its lifetime timer expires.
     77 class TurnServer::Allocation : public talk_base::MessageHandler,
     78                                public sigslot::has_slots<> {
     79  public:
     80   Allocation(TurnServer* server_,
     81              talk_base::Thread* thread, const Connection& conn,
     82              talk_base::AsyncPacketSocket* server_socket,
     83              const std::string& key);
     84   virtual ~Allocation();
     85 
     86   Connection* conn() { return &conn_; }
     87   const std::string& key() const { return key_; }
     88   const std::string& transaction_id() const { return transaction_id_; }
     89   const std::string& username() const { return username_; }
     90   const std::string& last_nonce() const { return last_nonce_; }
     91   void set_last_nonce(const std::string& nonce) { last_nonce_ = nonce; }
     92 
     93   std::string ToString() const;
     94 
     95   void HandleTurnMessage(const TurnMessage* msg);
     96   void HandleChannelData(const char* data, size_t size);
     97 
     98   sigslot::signal1<Allocation*> SignalDestroyed;
     99 
    100  private:
    101   typedef std::list<Permission*> PermissionList;
    102   typedef std::list<Channel*> ChannelList;
    103 
    104   void HandleAllocateRequest(const TurnMessage* msg);
    105   void HandleRefreshRequest(const TurnMessage* msg);
    106   void HandleSendIndication(const TurnMessage* msg);
    107   void HandleCreatePermissionRequest(const TurnMessage* msg);
    108   void HandleChannelBindRequest(const TurnMessage* msg);
    109 
    110   void OnExternalPacket(talk_base::AsyncPacketSocket* socket,
    111                         const char* data, size_t size,
    112                         const talk_base::SocketAddress& addr);
    113 
    114   static int ComputeLifetime(const TurnMessage* msg);
    115   bool HasPermission(const talk_base::IPAddress& addr);
    116   void AddPermission(const talk_base::IPAddress& addr);
    117   Permission* FindPermission(const talk_base::IPAddress& addr) const;
    118   Channel* FindChannel(int channel_id) const;
    119   Channel* FindChannel(const talk_base::SocketAddress& addr) const;
    120 
    121   void SendResponse(TurnMessage* msg);
    122   void SendBadRequestResponse(const TurnMessage* req);
    123   void SendErrorResponse(const TurnMessage* req, int code,
    124                          const std::string& reason);
    125   void SendExternal(const void* data, size_t size,
    126                     const talk_base::SocketAddress& peer);
    127 
    128   void OnPermissionDestroyed(Permission* perm);
    129   void OnChannelDestroyed(Channel* channel);
    130   virtual void OnMessage(talk_base::Message* msg);
    131 
    132   TurnServer* server_;
    133   talk_base::Thread* thread_;
    134   Connection conn_;
    135   talk_base::scoped_ptr<talk_base::AsyncPacketSocket> external_socket_;
    136   std::string key_;
    137   std::string transaction_id_;
    138   std::string username_;
    139   std::string last_nonce_;
    140   PermissionList perms_;
    141   ChannelList channels_;
    142 };
    143 
    144 // Encapsulates a TURN permission.
    145 // The object is created when a create permission request is received by an
    146 // allocation, and self-deletes when its lifetime timer expires.
    147 class TurnServer::Permission : public talk_base::MessageHandler {
    148  public:
    149   Permission(talk_base::Thread* thread, const talk_base::IPAddress& peer);
    150   ~Permission();
    151 
    152   const talk_base::IPAddress& peer() const { return peer_; }
    153   void Refresh();
    154 
    155   sigslot::signal1<Permission*> SignalDestroyed;
    156 
    157  private:
    158   virtual void OnMessage(talk_base::Message* msg);
    159 
    160   talk_base::Thread* thread_;
    161   talk_base::IPAddress peer_;
    162 };
    163 
    164 // Encapsulates a TURN channel binding.
    165 // The object is created when a channel bind request is received by an
    166 // allocation, and self-deletes when its lifetime timer expires.
    167 class TurnServer::Channel : public talk_base::MessageHandler {
    168  public:
    169   Channel(talk_base::Thread* thread, int id,
    170                      const talk_base::SocketAddress& peer);
    171   ~Channel();
    172 
    173   int id() const { return id_; }
    174   const talk_base::SocketAddress& peer() const { return peer_; }
    175   void Refresh();
    176 
    177   sigslot::signal1<Channel*> SignalDestroyed;
    178 
    179  private:
    180   virtual void OnMessage(talk_base::Message* msg);
    181 
    182   talk_base::Thread* thread_;
    183   int id_;
    184   talk_base::SocketAddress peer_;
    185 };
    186 
    187 static bool InitResponse(const StunMessage* req, StunMessage* resp) {
    188   int resp_type = (req) ? GetStunSuccessResponseType(req->type()) : -1;
    189   if (resp_type == -1)
    190     return false;
    191   resp->SetType(resp_type);
    192   resp->SetTransactionID(req->transaction_id());
    193   return true;
    194 }
    195 
    196 static bool InitErrorResponse(const StunMessage* req, int code,
    197                               const std::string& reason, StunMessage* resp) {
    198   int resp_type = (req) ? GetStunErrorResponseType(req->type()) : -1;
    199   if (resp_type == -1)
    200     return false;
    201   resp->SetType(resp_type);
    202   resp->SetTransactionID(req->transaction_id());
    203   VERIFY(resp->AddAttribute(new cricket::StunErrorCodeAttribute(
    204       STUN_ATTR_ERROR_CODE, code, reason)));
    205   return true;
    206 }
    207 
    208 TurnServer::TurnServer(talk_base::Thread* thread)
    209     : thread_(thread),
    210       nonce_key_(talk_base::CreateRandomString(kNonceKeySize)),
    211       auth_hook_(NULL),
    212       enable_otu_nonce_(false) {
    213 }
    214 
    215 TurnServer::~TurnServer() {
    216   for (AllocationMap::iterator it = allocations_.begin();
    217        it != allocations_.end(); ++it) {
    218     delete it->second;
    219   }
    220 
    221   for (InternalSocketMap::iterator it = server_sockets_.begin();
    222        it != server_sockets_.end(); ++it) {
    223     talk_base::AsyncPacketSocket* socket = it->first;
    224     delete socket;
    225   }
    226 
    227   for (ServerSocketMap::iterator it = server_listen_sockets_.begin();
    228        it != server_listen_sockets_.end(); ++it) {
    229     talk_base::AsyncSocket* socket = it->first;
    230     delete socket;
    231   }
    232 }
    233 
    234 void TurnServer::AddInternalSocket(talk_base::AsyncPacketSocket* socket,
    235                                    ProtocolType proto) {
    236   ASSERT(server_sockets_.end() == server_sockets_.find(socket));
    237   server_sockets_[socket] = proto;
    238   socket->SignalReadPacket.connect(this, &TurnServer::OnInternalPacket);
    239 }
    240 
    241 void TurnServer::AddInternalServerSocket(talk_base::AsyncSocket* socket,
    242                                          ProtocolType proto) {
    243   ASSERT(server_listen_sockets_.end() ==
    244       server_listen_sockets_.find(socket));
    245   server_listen_sockets_[socket] = proto;
    246   socket->SignalReadEvent.connect(this, &TurnServer::OnNewInternalConnection);
    247 }
    248 
    249 void TurnServer::SetExternalSocketFactory(
    250     talk_base::PacketSocketFactory* factory,
    251     const talk_base::SocketAddress& external_addr) {
    252   external_socket_factory_.reset(factory);
    253   external_addr_ = external_addr;
    254 }
    255 
    256 void TurnServer::OnNewInternalConnection(talk_base::AsyncSocket* socket) {
    257   ASSERT(server_listen_sockets_.find(socket) != server_listen_sockets_.end());
    258   AcceptConnection(socket);
    259 }
    260 
    261 void TurnServer::AcceptConnection(talk_base::AsyncSocket* server_socket) {
    262   // Check if someone is trying to connect to us.
    263   talk_base::SocketAddress accept_addr;
    264   talk_base::AsyncSocket* accepted_socket = server_socket->Accept(&accept_addr);
    265   if (accepted_socket != NULL) {
    266     ProtocolType proto = server_listen_sockets_[server_socket];
    267     cricket::AsyncStunTCPSocket* tcp_socket =
    268         new cricket::AsyncStunTCPSocket(accepted_socket, false);
    269 
    270     tcp_socket->SignalClose.connect(this, &TurnServer::OnInternalSocketClose);
    271     // Finally add the socket so it can start communicating with the client.
    272     AddInternalSocket(tcp_socket, proto);
    273   }
    274 }
    275 
    276 void TurnServer::OnInternalSocketClose(talk_base::AsyncPacketSocket* socket,
    277                                        int err) {
    278   DestroyInternalSocket(socket);
    279 }
    280 
    281 void TurnServer::OnInternalPacket(talk_base::AsyncPacketSocket* socket,
    282                                   const char* data, size_t size,
    283                                   const talk_base::SocketAddress& addr) {
    284   // Fail if the packet is too small to even contain a channel header.
    285   if (size < TURN_CHANNEL_HEADER_SIZE) {
    286    return;
    287   }
    288   InternalSocketMap::iterator iter = server_sockets_.find(socket);
    289   ASSERT(iter != server_sockets_.end());
    290   Connection conn(addr, iter->second, socket);
    291   uint16 msg_type = talk_base::GetBE16(data);
    292   if (!IsTurnChannelData(msg_type)) {
    293     // This is a STUN message.
    294     HandleStunMessage(&conn, data, size);
    295   } else {
    296     // This is a channel message; let the allocation handle it.
    297     Allocation* allocation = FindAllocation(&conn);
    298     if (allocation) {
    299       allocation->HandleChannelData(data, size);
    300     }
    301   }
    302 }
    303 
    304 void TurnServer::HandleStunMessage(Connection* conn, const char* data,
    305                                    size_t size) {
    306   TurnMessage msg;
    307   talk_base::ByteBuffer buf(data, size);
    308   if (!msg.Read(&buf) || (buf.Length() > 0)) {
    309     LOG(LS_WARNING) << "Received invalid STUN message";
    310     return;
    311   }
    312 
    313   // If it's a STUN binding request, handle that specially.
    314   if (msg.type() == STUN_BINDING_REQUEST) {
    315     HandleBindingRequest(conn, &msg);
    316     return;
    317   }
    318 
    319   // Look up the key that we'll use to validate the M-I. If we have an
    320   // existing allocation, the key will already be cached.
    321   Allocation* allocation = FindAllocation(conn);
    322   std::string key;
    323   if (!allocation) {
    324     GetKey(&msg, &key);
    325   } else {
    326     key = allocation->key();
    327   }
    328 
    329   // Ensure the message is authorized; only needed for requests.
    330   if (IsStunRequestType(msg.type())) {
    331     if (!CheckAuthorization(conn, &msg, data, size, key)) {
    332       return;
    333     }
    334   }
    335 
    336   if (!allocation && msg.type() == STUN_ALLOCATE_REQUEST) {
    337     // This is a new allocate request.
    338     HandleAllocateRequest(conn, &msg, key);
    339   } else if (allocation &&
    340              (msg.type() != STUN_ALLOCATE_REQUEST ||
    341               msg.transaction_id() == allocation->transaction_id())) {
    342     // This is a non-allocate request, or a retransmit of an allocate.
    343     // Check that the username matches the previous username used.
    344     if (IsStunRequestType(msg.type()) &&
    345         msg.GetByteString(STUN_ATTR_USERNAME)->GetString() !=
    346             allocation->username()) {
    347       SendErrorResponse(conn, &msg, STUN_ERROR_WRONG_CREDENTIALS,
    348                         STUN_ERROR_REASON_WRONG_CREDENTIALS);
    349       return;
    350     }
    351     allocation->HandleTurnMessage(&msg);
    352   } else {
    353     // Allocation mismatch.
    354     SendErrorResponse(conn, &msg, STUN_ERROR_ALLOCATION_MISMATCH,
    355                       STUN_ERROR_REASON_ALLOCATION_MISMATCH);
    356   }
    357 }
    358 
    359 bool TurnServer::GetKey(const StunMessage* msg, std::string* key) {
    360   const StunByteStringAttribute* username_attr =
    361       msg->GetByteString(STUN_ATTR_USERNAME);
    362   if (!username_attr) {
    363     return false;
    364   }
    365 
    366   std::string username = username_attr->GetString();
    367   return (auth_hook_ != NULL && auth_hook_->GetKey(username, realm_, key));
    368 }
    369 
    370 bool TurnServer::CheckAuthorization(Connection* conn,
    371                                     const StunMessage* msg,
    372                                     const char* data, size_t size,
    373                                     const std::string& key) {
    374   // RFC 5389, 10.2.2.
    375   ASSERT(IsStunRequestType(msg->type()));
    376   const StunByteStringAttribute* mi_attr =
    377       msg->GetByteString(STUN_ATTR_MESSAGE_INTEGRITY);
    378   const StunByteStringAttribute* username_attr =
    379       msg->GetByteString(STUN_ATTR_USERNAME);
    380   const StunByteStringAttribute* realm_attr =
    381       msg->GetByteString(STUN_ATTR_REALM);
    382   const StunByteStringAttribute* nonce_attr =
    383       msg->GetByteString(STUN_ATTR_NONCE);
    384 
    385   // Fail if no M-I.
    386   if (!mi_attr) {
    387     SendErrorResponseWithRealmAndNonce(conn, msg, STUN_ERROR_UNAUTHORIZED,
    388                                        STUN_ERROR_REASON_UNAUTHORIZED);
    389     return false;
    390   }
    391 
    392   // Fail if there is M-I but no username, nonce, or realm.
    393   if (!username_attr || !realm_attr || !nonce_attr) {
    394     SendErrorResponse(conn, msg, STUN_ERROR_BAD_REQUEST,
    395                       STUN_ERROR_REASON_BAD_REQUEST);
    396     return false;
    397   }
    398 
    399   // Fail if bad nonce.
    400   if (!ValidateNonce(nonce_attr->GetString())) {
    401     SendErrorResponseWithRealmAndNonce(conn, msg, STUN_ERROR_STALE_NONCE,
    402                                        STUN_ERROR_REASON_STALE_NONCE);
    403     return false;
    404   }
    405 
    406   // Fail if bad username or M-I.
    407   // We need |data| and |size| for the call to ValidateMessageIntegrity.
    408   if (key.empty() || !StunMessage::ValidateMessageIntegrity(data, size, key)) {
    409     SendErrorResponseWithRealmAndNonce(conn, msg, STUN_ERROR_UNAUTHORIZED,
    410                                        STUN_ERROR_REASON_UNAUTHORIZED);
    411     return false;
    412   }
    413 
    414   // Fail if one-time-use nonce feature is enabled.
    415   Allocation* allocation = FindAllocation(conn);
    416   if (enable_otu_nonce_ && allocation &&
    417       allocation->last_nonce() == nonce_attr->GetString()) {
    418     SendErrorResponseWithRealmAndNonce(conn, msg, STUN_ERROR_STALE_NONCE,
    419                                        STUN_ERROR_REASON_STALE_NONCE);
    420     return false;
    421   }
    422 
    423   if (allocation) {
    424     allocation->set_last_nonce(nonce_attr->GetString());
    425   }
    426   // Success.
    427   return true;
    428 }
    429 
    430 void TurnServer::HandleBindingRequest(Connection* conn,
    431                                       const StunMessage* req) {
    432   StunMessage response;
    433   InitResponse(req, &response);
    434 
    435   // Tell the user the address that we received their request from.
    436   StunAddressAttribute* mapped_addr_attr;
    437   mapped_addr_attr = new StunXorAddressAttribute(
    438       STUN_ATTR_XOR_MAPPED_ADDRESS, conn->src());
    439   VERIFY(response.AddAttribute(mapped_addr_attr));
    440 
    441   SendStun(conn, &response);
    442 }
    443 
    444 void TurnServer::HandleAllocateRequest(Connection* conn,
    445                                        const TurnMessage* msg,
    446                                        const std::string& key) {
    447   // Check the parameters in the request.
    448   const StunUInt32Attribute* transport_attr =
    449       msg->GetUInt32(STUN_ATTR_REQUESTED_TRANSPORT);
    450   if (!transport_attr) {
    451     SendErrorResponse(conn, msg, STUN_ERROR_BAD_REQUEST,
    452                       STUN_ERROR_REASON_BAD_REQUEST);
    453     return;
    454   }
    455 
    456   // Only UDP is supported right now.
    457   int proto = transport_attr->value() >> 24;
    458   if (proto != IPPROTO_UDP) {
    459     SendErrorResponse(conn, msg, STUN_ERROR_UNSUPPORTED_PROTOCOL,
    460                       STUN_ERROR_REASON_UNSUPPORTED_PROTOCOL);
    461     return;
    462   }
    463 
    464   // Create the allocation and let it send the success response.
    465   // If the actual socket allocation fails, send an internal error.
    466   Allocation* alloc = CreateAllocation(conn, proto, key);
    467   if (alloc) {
    468     alloc->HandleTurnMessage(msg);
    469   } else {
    470     SendErrorResponse(conn, msg, STUN_ERROR_SERVER_ERROR,
    471                       "Failed to allocate socket");
    472   }
    473 }
    474 
    475 std::string TurnServer::GenerateNonce() const {
    476   // Generate a nonce of the form hex(now + HMAC-MD5(nonce_key_, now))
    477   uint32 now = talk_base::Time();
    478   std::string input(reinterpret_cast<const char*>(&now), sizeof(now));
    479   std::string nonce = talk_base::hex_encode(input.c_str(), input.size());
    480   nonce += talk_base::ComputeHmac(talk_base::DIGEST_MD5, nonce_key_, input);
    481   ASSERT(nonce.size() == kNonceSize);
    482   return nonce;
    483 }
    484 
    485 bool TurnServer::ValidateNonce(const std::string& nonce) const {
    486   // Check the size.
    487   if (nonce.size() != kNonceSize) {
    488     return false;
    489   }
    490 
    491   // Decode the timestamp.
    492   uint32 then;
    493   char* p = reinterpret_cast<char*>(&then);
    494   size_t len = talk_base::hex_decode(p, sizeof(then),
    495       nonce.substr(0, sizeof(then) * 2));
    496   if (len != sizeof(then)) {
    497     return false;
    498   }
    499 
    500   // Verify the HMAC.
    501   if (nonce.substr(sizeof(then) * 2) != talk_base::ComputeHmac(
    502       talk_base::DIGEST_MD5, nonce_key_, std::string(p, sizeof(then)))) {
    503     return false;
    504   }
    505 
    506   // Validate the timestamp.
    507   return talk_base::TimeSince(then) < kNonceTimeout;
    508 }
    509 
    510 TurnServer::Allocation* TurnServer::FindAllocation(Connection* conn) {
    511   AllocationMap::const_iterator it = allocations_.find(*conn);
    512   return (it != allocations_.end()) ? it->second : NULL;
    513 }
    514 
    515 TurnServer::Allocation* TurnServer::CreateAllocation(Connection* conn,
    516                                                      int proto,
    517                                                      const std::string& key) {
    518   talk_base::AsyncPacketSocket* external_socket = (external_socket_factory_) ?
    519       external_socket_factory_->CreateUdpSocket(external_addr_, 0, 0) : NULL;
    520   if (!external_socket) {
    521     return NULL;
    522   }
    523 
    524   // The Allocation takes ownership of the socket.
    525   Allocation* allocation = new Allocation(this,
    526       thread_, *conn, external_socket, key);
    527   allocation->SignalDestroyed.connect(this, &TurnServer::OnAllocationDestroyed);
    528   allocations_[*conn] = allocation;
    529   return allocation;
    530 }
    531 
    532 void TurnServer::SendErrorResponse(Connection* conn,
    533                                    const StunMessage* req,
    534                                    int code, const std::string& reason) {
    535   TurnMessage resp;
    536   InitErrorResponse(req, code, reason, &resp);
    537   LOG(LS_INFO) << "Sending error response, type=" << resp.type()
    538                << ", code=" << code << ", reason=" << reason;
    539   SendStun(conn, &resp);
    540 }
    541 
    542 void TurnServer::SendErrorResponseWithRealmAndNonce(
    543     Connection* conn, const StunMessage* msg,
    544     int code, const std::string& reason) {
    545   TurnMessage resp;
    546   InitErrorResponse(msg, code, reason, &resp);
    547   VERIFY(resp.AddAttribute(new StunByteStringAttribute(
    548       STUN_ATTR_NONCE, GenerateNonce())));
    549   VERIFY(resp.AddAttribute(new StunByteStringAttribute(
    550       STUN_ATTR_REALM, realm_)));
    551   SendStun(conn, &resp);
    552 }
    553 
    554 void TurnServer::SendStun(Connection* conn, StunMessage* msg) {
    555   talk_base::ByteBuffer buf;
    556   // Add a SOFTWARE attribute if one is set.
    557   if (!software_.empty()) {
    558     VERIFY(msg->AddAttribute(
    559         new StunByteStringAttribute(STUN_ATTR_SOFTWARE, software_)));
    560   }
    561   msg->Write(&buf);
    562   Send(conn, buf);
    563 }
    564 
    565 void TurnServer::Send(Connection* conn,
    566                       const talk_base::ByteBuffer& buf) {
    567   conn->socket()->SendTo(buf.Data(), buf.Length(), conn->src());
    568 }
    569 
    570 void TurnServer::OnAllocationDestroyed(Allocation* allocation) {
    571   // Removing the internal socket if the connection is not udp.
    572   talk_base::AsyncPacketSocket* socket = allocation->conn()->socket();
    573   InternalSocketMap::iterator iter = server_sockets_.find(socket);
    574   ASSERT(iter != server_sockets_.end());
    575   // Skip if the socket serving this allocation is UDP, as this will be shared
    576   // by all allocations.
    577   if (iter->second != cricket::PROTO_UDP) {
    578     DestroyInternalSocket(socket);
    579   }
    580 
    581   AllocationMap::iterator it = allocations_.find(*(allocation->conn()));
    582   if (it != allocations_.end())
    583     allocations_.erase(it);
    584 }
    585 
    586 void TurnServer::DestroyInternalSocket(talk_base::AsyncPacketSocket* socket) {
    587   InternalSocketMap::iterator iter = server_sockets_.find(socket);
    588   if (iter != server_sockets_.end()) {
    589     talk_base::AsyncPacketSocket* socket = iter->first;
    590     delete socket;
    591     server_sockets_.erase(iter);
    592   }
    593 }
    594 
    595 TurnServer::Connection::Connection(const talk_base::SocketAddress& src,
    596                                    ProtocolType proto,
    597                                    talk_base::AsyncPacketSocket* socket)
    598     : src_(src),
    599       dst_(socket->GetRemoteAddress()),
    600       proto_(proto),
    601       socket_(socket) {
    602 }
    603 
    604 bool TurnServer::Connection::operator==(const Connection& c) const {
    605   return src_ == c.src_ && dst_ == c.dst_ && proto_ == c.proto_;
    606 }
    607 
    608 bool TurnServer::Connection::operator<(const Connection& c) const {
    609   return src_ < c.src_ || dst_ < c.dst_ || proto_ < c.proto_;
    610 }
    611 
    612 std::string TurnServer::Connection::ToString() const {
    613   const char* const kProtos[] = {
    614       "unknown", "udp", "tcp", "ssltcp"
    615   };
    616   std::ostringstream ost;
    617   ost << src_.ToString() << "-" << dst_.ToString() << ":"<< kProtos[proto_];
    618   return ost.str();
    619 }
    620 
    621 TurnServer::Allocation::Allocation(TurnServer* server,
    622                                    talk_base::Thread* thread,
    623                                    const Connection& conn,
    624                                    talk_base::AsyncPacketSocket* socket,
    625                                    const std::string& key)
    626     : server_(server),
    627       thread_(thread),
    628       conn_(conn),
    629       external_socket_(socket),
    630       key_(key) {
    631   external_socket_->SignalReadPacket.connect(
    632       this, &TurnServer::Allocation::OnExternalPacket);
    633 }
    634 
    635 TurnServer::Allocation::~Allocation() {
    636   for (ChannelList::iterator it = channels_.begin();
    637        it != channels_.end(); ++it) {
    638     delete *it;
    639   }
    640   for (PermissionList::iterator it = perms_.begin();
    641        it != perms_.end(); ++it) {
    642     delete *it;
    643   }
    644   thread_->Clear(this, MSG_TIMEOUT);
    645   LOG_J(LS_INFO, this) << "Allocation destroyed";
    646 }
    647 
    648 std::string TurnServer::Allocation::ToString() const {
    649   std::ostringstream ost;
    650   ost << "Alloc[" << conn_.ToString() << "]";
    651   return ost.str();
    652 }
    653 
    654 void TurnServer::Allocation::HandleTurnMessage(const TurnMessage* msg) {
    655   ASSERT(msg != NULL);
    656   switch (msg->type()) {
    657     case STUN_ALLOCATE_REQUEST:
    658       HandleAllocateRequest(msg);
    659       break;
    660     case TURN_REFRESH_REQUEST:
    661       HandleRefreshRequest(msg);
    662       break;
    663     case TURN_SEND_INDICATION:
    664       HandleSendIndication(msg);
    665       break;
    666     case TURN_CREATE_PERMISSION_REQUEST:
    667       HandleCreatePermissionRequest(msg);
    668       break;
    669     case TURN_CHANNEL_BIND_REQUEST:
    670       HandleChannelBindRequest(msg);
    671       break;
    672     default:
    673       // Not sure what to do with this, just eat it.
    674       LOG_J(LS_WARNING, this) << "Invalid TURN message type received: "
    675                               << msg->type();
    676   }
    677 }
    678 
    679 void TurnServer::Allocation::HandleAllocateRequest(const TurnMessage* msg) {
    680   // Copy the important info from the allocate request.
    681   transaction_id_ = msg->transaction_id();
    682   const StunByteStringAttribute* username_attr =
    683       msg->GetByteString(STUN_ATTR_USERNAME);
    684   ASSERT(username_attr != NULL);
    685   username_ = username_attr->GetString();
    686 
    687   // Figure out the lifetime and start the allocation timer.
    688   int lifetime_secs = ComputeLifetime(msg);
    689   thread_->PostDelayed(lifetime_secs * 1000, this, MSG_TIMEOUT);
    690 
    691   LOG_J(LS_INFO, this) << "Created allocation, lifetime=" << lifetime_secs;
    692 
    693   // We've already validated all the important bits; just send a response here.
    694   TurnMessage response;
    695   InitResponse(msg, &response);
    696 
    697   StunAddressAttribute* mapped_addr_attr =
    698       new StunXorAddressAttribute(STUN_ATTR_XOR_MAPPED_ADDRESS, conn_.src());
    699   StunAddressAttribute* relayed_addr_attr =
    700       new StunXorAddressAttribute(STUN_ATTR_XOR_RELAYED_ADDRESS,
    701           external_socket_->GetLocalAddress());
    702   StunUInt32Attribute* lifetime_attr =
    703       new StunUInt32Attribute(STUN_ATTR_LIFETIME, lifetime_secs);
    704   VERIFY(response.AddAttribute(mapped_addr_attr));
    705   VERIFY(response.AddAttribute(relayed_addr_attr));
    706   VERIFY(response.AddAttribute(lifetime_attr));
    707 
    708   SendResponse(&response);
    709 }
    710 
    711 void TurnServer::Allocation::HandleRefreshRequest(const TurnMessage* msg) {
    712   // Figure out the new lifetime.
    713   int lifetime_secs = ComputeLifetime(msg);
    714 
    715   // Reset the expiration timer.
    716   thread_->Clear(this, MSG_TIMEOUT);
    717   thread_->PostDelayed(lifetime_secs * 1000, this, MSG_TIMEOUT);
    718 
    719   LOG_J(LS_INFO, this) << "Refreshed allocation, lifetime=" << lifetime_secs;
    720 
    721   // Send a success response with a LIFETIME attribute.
    722   TurnMessage response;
    723   InitResponse(msg, &response);
    724 
    725   StunUInt32Attribute* lifetime_attr =
    726       new StunUInt32Attribute(STUN_ATTR_LIFETIME, lifetime_secs);
    727   VERIFY(response.AddAttribute(lifetime_attr));
    728 
    729   SendResponse(&response);
    730 }
    731 
    732 void TurnServer::Allocation::HandleSendIndication(const TurnMessage* msg) {
    733   // Check mandatory attributes.
    734   const StunByteStringAttribute* data_attr = msg->GetByteString(STUN_ATTR_DATA);
    735   const StunAddressAttribute* peer_attr =
    736       msg->GetAddress(STUN_ATTR_XOR_PEER_ADDRESS);
    737   if (!data_attr || !peer_attr) {
    738     LOG_J(LS_WARNING, this) << "Received invalid send indication";
    739     return;
    740   }
    741 
    742   // If a permission exists, send the data on to the peer.
    743   if (HasPermission(peer_attr->GetAddress().ipaddr())) {
    744     SendExternal(data_attr->bytes(), data_attr->length(),
    745                  peer_attr->GetAddress());
    746   } else {
    747     LOG_J(LS_WARNING, this) << "Received send indication without permission"
    748                             << "peer=" << peer_attr->GetAddress();
    749   }
    750 }
    751 
    752 void TurnServer::Allocation::HandleCreatePermissionRequest(
    753     const TurnMessage* msg) {
    754   // Check mandatory attributes.
    755   const StunAddressAttribute* peer_attr =
    756       msg->GetAddress(STUN_ATTR_XOR_PEER_ADDRESS);
    757   if (!peer_attr) {
    758     SendBadRequestResponse(msg);
    759     return;
    760   }
    761 
    762   // Add this permission.
    763   AddPermission(peer_attr->GetAddress().ipaddr());
    764 
    765   LOG_J(LS_INFO, this) << "Created permission, peer="
    766                        << peer_attr->GetAddress();
    767 
    768   // Send a success response.
    769   TurnMessage response;
    770   InitResponse(msg, &response);
    771   SendResponse(&response);
    772 }
    773 
    774 void TurnServer::Allocation::HandleChannelBindRequest(const TurnMessage* msg) {
    775   // Check mandatory attributes.
    776   const StunUInt32Attribute* channel_attr =
    777       msg->GetUInt32(STUN_ATTR_CHANNEL_NUMBER);
    778   const StunAddressAttribute* peer_attr =
    779       msg->GetAddress(STUN_ATTR_XOR_PEER_ADDRESS);
    780   if (!channel_attr || !peer_attr) {
    781     SendBadRequestResponse(msg);
    782     return;
    783   }
    784 
    785   // Check that channel id is valid.
    786   int channel_id = channel_attr->value() >> 16;
    787   if (channel_id < kMinChannelNumber || channel_id > kMaxChannelNumber) {
    788     SendBadRequestResponse(msg);
    789     return;
    790   }
    791 
    792   // Check that this channel id isn't bound to another transport address, and
    793   // that this transport address isn't bound to another channel id.
    794   Channel* channel1 = FindChannel(channel_id);
    795   Channel* channel2 = FindChannel(peer_attr->GetAddress());
    796   if (channel1 != channel2) {
    797     SendBadRequestResponse(msg);
    798     return;
    799   }
    800 
    801   // Add or refresh this channel.
    802   if (!channel1) {
    803     channel1 = new Channel(thread_, channel_id, peer_attr->GetAddress());
    804     channel1->SignalDestroyed.connect(this,
    805         &TurnServer::Allocation::OnChannelDestroyed);
    806     channels_.push_back(channel1);
    807   } else {
    808     channel1->Refresh();
    809   }
    810 
    811   // Channel binds also refresh permissions.
    812   AddPermission(peer_attr->GetAddress().ipaddr());
    813 
    814   LOG_J(LS_INFO, this) << "Bound channel, id=" << channel_id
    815                        << ", peer=" << peer_attr->GetAddress();
    816 
    817   // Send a success response.
    818   TurnMessage response;
    819   InitResponse(msg, &response);
    820   SendResponse(&response);
    821 }
    822 
    823 void TurnServer::Allocation::HandleChannelData(const char* data, size_t size) {
    824   // Extract the channel number from the data.
    825   uint16 channel_id = talk_base::GetBE16(data);
    826   Channel* channel = FindChannel(channel_id);
    827   if (channel) {
    828     // Send the data to the peer address.
    829     SendExternal(data + TURN_CHANNEL_HEADER_SIZE,
    830                  size - TURN_CHANNEL_HEADER_SIZE, channel->peer());
    831   } else {
    832     LOG_J(LS_WARNING, this) << "Received channel data for invalid channel, id="
    833                             << channel_id;
    834   }
    835 }
    836 
    837 void TurnServer::Allocation::OnExternalPacket(
    838     talk_base::AsyncPacketSocket* socket,
    839     const char* data, size_t size,
    840     const talk_base::SocketAddress& addr) {
    841   ASSERT(external_socket_.get() == socket);
    842   Channel* channel = FindChannel(addr);
    843   if (channel) {
    844     // There is a channel bound to this address. Send as a channel message.
    845     talk_base::ByteBuffer buf;
    846     buf.WriteUInt16(channel->id());
    847     buf.WriteUInt16(static_cast<uint16>(size));
    848     buf.WriteBytes(data, size);
    849     server_->Send(&conn_, buf);
    850   } else if (HasPermission(addr.ipaddr())) {
    851     // No channel, but a permission exists. Send as a data indication.
    852     TurnMessage msg;
    853     msg.SetType(TURN_DATA_INDICATION);
    854     msg.SetTransactionID(
    855         talk_base::CreateRandomString(kStunTransactionIdLength));
    856     VERIFY(msg.AddAttribute(new StunXorAddressAttribute(
    857         STUN_ATTR_XOR_PEER_ADDRESS, addr)));
    858     VERIFY(msg.AddAttribute(new StunByteStringAttribute(
    859         STUN_ATTR_DATA, data, size)));
    860     server_->SendStun(&conn_, &msg);
    861   } else {
    862     LOG_J(LS_WARNING, this) << "Received external packet without permission, "
    863                             << "peer=" << addr;
    864   }
    865 }
    866 
    867 int TurnServer::Allocation::ComputeLifetime(const TurnMessage* msg) {
    868   // Return the smaller of our default lifetime and the requested lifetime.
    869   uint32 lifetime = kDefaultAllocationTimeout / 1000;  // convert to seconds
    870   const StunUInt32Attribute* lifetime_attr = msg->GetUInt32(STUN_ATTR_LIFETIME);
    871   if (lifetime_attr && lifetime_attr->value() < lifetime) {
    872     lifetime = lifetime_attr->value();
    873   }
    874   return lifetime;
    875 }
    876 
    877 bool TurnServer::Allocation::HasPermission(const talk_base::IPAddress& addr) {
    878   return (FindPermission(addr) != NULL);
    879 }
    880 
    881 void TurnServer::Allocation::AddPermission(const talk_base::IPAddress& addr) {
    882   Permission* perm = FindPermission(addr);
    883   if (!perm) {
    884     perm = new Permission(thread_, addr);
    885     perm->SignalDestroyed.connect(
    886         this, &TurnServer::Allocation::OnPermissionDestroyed);
    887     perms_.push_back(perm);
    888   } else {
    889     perm->Refresh();
    890   }
    891 }
    892 
    893 TurnServer::Permission* TurnServer::Allocation::FindPermission(
    894     const talk_base::IPAddress& addr) const {
    895   for (PermissionList::const_iterator it = perms_.begin();
    896        it != perms_.end(); ++it) {
    897     if ((*it)->peer() == addr)
    898       return *it;
    899   }
    900   return NULL;
    901 }
    902 
    903 TurnServer::Channel* TurnServer::Allocation::FindChannel(int channel_id) const {
    904   for (ChannelList::const_iterator it = channels_.begin();
    905        it != channels_.end(); ++it) {
    906     if ((*it)->id() == channel_id)
    907       return *it;
    908   }
    909   return NULL;
    910 }
    911 
    912 TurnServer::Channel* TurnServer::Allocation::FindChannel(
    913     const talk_base::SocketAddress& addr) const {
    914   for (ChannelList::const_iterator it = channels_.begin();
    915        it != channels_.end(); ++it) {
    916     if ((*it)->peer() == addr)
    917       return *it;
    918   }
    919   return NULL;
    920 }
    921 
    922 void TurnServer::Allocation::SendResponse(TurnMessage* msg) {
    923   // Success responses always have M-I.
    924   msg->AddMessageIntegrity(key_);
    925   server_->SendStun(&conn_, msg);
    926 }
    927 
    928 void TurnServer::Allocation::SendBadRequestResponse(const TurnMessage* req) {
    929   SendErrorResponse(req, STUN_ERROR_BAD_REQUEST, STUN_ERROR_REASON_BAD_REQUEST);
    930 }
    931 
    932 void TurnServer::Allocation::SendErrorResponse(const TurnMessage* req, int code,
    933                                        const std::string& reason) {
    934   server_->SendErrorResponse(&conn_, req, code, reason);
    935 }
    936 
    937 void TurnServer::Allocation::SendExternal(const void* data, size_t size,
    938                                   const talk_base::SocketAddress& peer) {
    939   external_socket_->SendTo(data, size, peer);
    940 }
    941 
    942 void TurnServer::Allocation::OnMessage(talk_base::Message* msg) {
    943   ASSERT(msg->message_id == MSG_TIMEOUT);
    944   SignalDestroyed(this);
    945   delete this;
    946 }
    947 
    948 void TurnServer::Allocation::OnPermissionDestroyed(Permission* perm) {
    949   PermissionList::iterator it = std::find(perms_.begin(), perms_.end(), perm);
    950   ASSERT(it != perms_.end());
    951   perms_.erase(it);
    952 }
    953 
    954 void TurnServer::Allocation::OnChannelDestroyed(Channel* channel) {
    955   ChannelList::iterator it =
    956       std::find(channels_.begin(), channels_.end(), channel);
    957   ASSERT(it != channels_.end());
    958   channels_.erase(it);
    959 }
    960 
    961 TurnServer::Permission::Permission(talk_base::Thread* thread,
    962                                    const talk_base::IPAddress& peer)
    963     : thread_(thread), peer_(peer) {
    964   Refresh();
    965 }
    966 
    967 TurnServer::Permission::~Permission() {
    968   thread_->Clear(this, MSG_TIMEOUT);
    969 }
    970 
    971 void TurnServer::Permission::Refresh() {
    972   thread_->Clear(this, MSG_TIMEOUT);
    973   thread_->PostDelayed(kPermissionTimeout, this, MSG_TIMEOUT);
    974 }
    975 
    976 void TurnServer::Permission::OnMessage(talk_base::Message* msg) {
    977   ASSERT(msg->message_id == MSG_TIMEOUT);
    978   SignalDestroyed(this);
    979   delete this;
    980 }
    981 
    982 TurnServer::Channel::Channel(talk_base::Thread* thread, int id,
    983                              const talk_base::SocketAddress& peer)
    984     : thread_(thread), id_(id), peer_(peer) {
    985   Refresh();
    986 }
    987 
    988 TurnServer::Channel::~Channel() {
    989   thread_->Clear(this, MSG_TIMEOUT);
    990 }
    991 
    992 void TurnServer::Channel::Refresh() {
    993   thread_->Clear(this, MSG_TIMEOUT);
    994   thread_->PostDelayed(kChannelTimeout, this, MSG_TIMEOUT);
    995 }
    996 
    997 void TurnServer::Channel::OnMessage(talk_base::Message* msg) {
    998   ASSERT(msg->message_id == MSG_TIMEOUT);
    999   SignalDestroyed(this);
   1000   delete this;
   1001 }
   1002 
   1003 }  // namespace cricket
   1004