Home | History | Annotate | Download | only in base
      1 /*
      2  * libjingle
      3  * Copyright 2009, Google, Inc.
      4  *
      5  * Redistribution and use in source and binary forms, with or without
      6  * modification, are permitted provided that the following conditions are met:
      7  *
      8  *  1. Redistributions of source code must retain the above copyright notice,
      9  *     this list of conditions and the following disclaimer.
     10  *  2. Redistributions in binary form must reproduce the above copyright notice,
     11  *     this list of conditions and the following disclaimer in the documentation
     12  *     and/or other materials provided with the distribution.
     13  *  3. The name of the author may not be used to endorse or promote products
     14  *     derived from this software without specific prior written permission.
     15  *
     16  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED
     17  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
     18  * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
     19  * EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
     20  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
     21  * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
     22  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
     23  * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
     24  * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
     25  * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
     26  */
     27 
     28 #ifndef TALK_P2P_BASE_FAKESESSION_H_
     29 #define TALK_P2P_BASE_FAKESESSION_H_
     30 
     31 #include <map>
     32 #include <string>
     33 #include <vector>
     34 
     35 #include "talk/base/buffer.h"
     36 #include "talk/base/sigslot.h"
     37 #include "talk/base/sslfingerprint.h"
     38 #include "talk/base/messagequeue.h"
     39 #include "talk/p2p/base/session.h"
     40 #include "talk/p2p/base/transport.h"
     41 #include "talk/p2p/base/transportchannel.h"
     42 #include "talk/p2p/base/transportchannelimpl.h"
     43 
     44 namespace cricket {
     45 
     46 class FakeTransport;
     47 
     48 struct PacketMessageData : public talk_base::MessageData {
     49   PacketMessageData(const char* data, size_t len) : packet(data, len) {
     50   }
     51   talk_base::Buffer packet;
     52 };
     53 
     54 // Fake transport channel class, which can be passed to anything that needs a
     55 // transport channel. Can be informed of another FakeTransportChannel via
     56 // SetDestination.
     57 class FakeTransportChannel : public TransportChannelImpl,
     58                              public talk_base::MessageHandler {
     59  public:
     60   explicit FakeTransportChannel(Transport* transport,
     61                                 const std::string& content_name,
     62                                 int component)
     63       : TransportChannelImpl(content_name, component),
     64         transport_(transport),
     65         dest_(NULL),
     66         state_(STATE_INIT),
     67         async_(false),
     68         identity_(NULL),
     69         do_dtls_(false),
     70         role_(ICEROLE_UNKNOWN),
     71         tiebreaker_(0),
     72         ice_proto_(ICEPROTO_HYBRID),
     73         remote_ice_mode_(ICEMODE_FULL),
     74         dtls_fingerprint_("", NULL, 0) {
     75   }
     76   ~FakeTransportChannel() {
     77     Reset();
     78   }
     79 
     80   uint64 IceTiebreaker() const { return tiebreaker_; }
     81   TransportProtocol protocol() const { return ice_proto_; }
     82   IceMode remote_ice_mode() const { return remote_ice_mode_; }
     83   const std::string& ice_ufrag() const { return ice_ufrag_; }
     84   const std::string& ice_pwd() const { return ice_pwd_; }
     85   const std::string& remote_ice_ufrag() const { return remote_ice_ufrag_; }
     86   const std::string& remote_ice_pwd() const { return remote_ice_pwd_; }
     87   const talk_base::SSLFingerprint& dtls_fingerprint() const {
     88     return dtls_fingerprint_;
     89   }
     90 
     91   void SetAsync(bool async) {
     92     async_ = async;
     93   }
     94 
     95   virtual Transport* GetTransport() {
     96     return transport_;
     97   }
     98 
     99   virtual void SetIceRole(IceRole role) { role_ = role; }
    100   virtual IceRole GetIceRole() const { return role_; }
    101   virtual void SetIceTiebreaker(uint64 tiebreaker) { tiebreaker_ = tiebreaker; }
    102   virtual void SetIceProtocolType(IceProtocolType type) { ice_proto_ = type; }
    103   virtual void SetIceCredentials(const std::string& ice_ufrag,
    104                                  const std::string& ice_pwd) {
    105     ice_ufrag_ = ice_ufrag;
    106     ice_pwd_ = ice_pwd;
    107   }
    108   virtual void SetRemoteIceCredentials(const std::string& ice_ufrag,
    109                                        const std::string& ice_pwd) {
    110     remote_ice_ufrag_ = ice_ufrag;
    111     remote_ice_pwd_ = ice_pwd;
    112   }
    113 
    114   virtual void SetRemoteIceMode(IceMode mode) { remote_ice_mode_ = mode; }
    115   virtual bool SetRemoteFingerprint(const std::string& alg, const uint8* digest,
    116                                     size_t digest_len) {
    117     dtls_fingerprint_ = talk_base::SSLFingerprint(alg, digest, digest_len);
    118     return true;
    119   }
    120 
    121   virtual void Connect() {
    122     if (state_ == STATE_INIT) {
    123       state_ = STATE_CONNECTING;
    124     }
    125   }
    126   virtual void Reset() {
    127     if (state_ != STATE_INIT) {
    128       state_ = STATE_INIT;
    129       if (dest_) {
    130         dest_->state_ = STATE_INIT;
    131         dest_->dest_ = NULL;
    132         dest_ = NULL;
    133       }
    134     }
    135   }
    136 
    137   void SetWritable(bool writable) {
    138     set_writable(writable);
    139   }
    140 
    141   void SetDestination(FakeTransportChannel* dest) {
    142     if (state_ == STATE_CONNECTING && dest) {
    143       // This simulates the delivery of candidates.
    144       dest_ = dest;
    145       dest_->dest_ = this;
    146       if (identity_ && dest_->identity_) {
    147         do_dtls_ = true;
    148         dest_->do_dtls_ = true;
    149         NegotiateSrtpCiphers();
    150       }
    151       state_ = STATE_CONNECTED;
    152       dest_->state_ = STATE_CONNECTED;
    153       set_writable(true);
    154       dest_->set_writable(true);
    155     } else if (state_ == STATE_CONNECTED && !dest) {
    156       // Simulates loss of connectivity, by asymmetrically forgetting dest_.
    157       dest_ = NULL;
    158       state_ = STATE_CONNECTING;
    159       set_writable(false);
    160     }
    161   }
    162 
    163   virtual int SendPacket(const char* data, size_t len, int flags) {
    164     if (state_ != STATE_CONNECTED) {
    165       return -1;
    166     }
    167 
    168     if (flags != PF_SRTP_BYPASS && flags != 0) {
    169       return -1;
    170     }
    171 
    172     PacketMessageData* packet = new PacketMessageData(data, len);
    173     if (async_) {
    174       talk_base::Thread::Current()->Post(this, 0, packet);
    175     } else {
    176       talk_base::Thread::Current()->Send(this, 0, packet);
    177     }
    178     return static_cast<int>(len);
    179   }
    180   virtual int SetOption(talk_base::Socket::Option opt, int value) {
    181     return true;
    182   }
    183   virtual int GetError() {
    184     return 0;
    185   }
    186 
    187   virtual void OnSignalingReady() {
    188   }
    189   virtual void OnCandidate(const Candidate& candidate) {
    190   }
    191 
    192   virtual void OnMessage(talk_base::Message* msg) {
    193     PacketMessageData* data = static_cast<PacketMessageData*>(
    194         msg->pdata);
    195     dest_->SignalReadPacket(dest_, data->packet.data(),
    196                             data->packet.length(), 0);
    197     delete data;
    198   }
    199 
    200   bool SetLocalIdentity(talk_base::SSLIdentity* identity) {
    201     identity_ = identity;
    202 
    203     return true;
    204   }
    205 
    206   bool IsDtlsActive() const {
    207     return do_dtls_;
    208   }
    209 
    210   bool SetSrtpCiphers(const std::vector<std::string>& ciphers) {
    211     srtp_ciphers_ = ciphers;
    212     return true;
    213   }
    214 
    215   virtual bool GetSrtpCipher(std::string* cipher) {
    216     if (!chosen_srtp_cipher_.empty()) {
    217       *cipher = chosen_srtp_cipher_;
    218       return true;
    219     }
    220     return false;
    221   }
    222 
    223   virtual bool ExportKeyingMaterial(const std::string& label,
    224                                     const uint8* context,
    225                                     size_t context_len,
    226                                     bool use_context,
    227                                     uint8* result,
    228                                     size_t result_len) {
    229     if (!chosen_srtp_cipher_.empty()) {
    230       memset(result, 0xff, result_len);
    231       return true;
    232     }
    233 
    234     return false;
    235   }
    236 
    237   virtual void NegotiateSrtpCiphers() {
    238     for (std::vector<std::string>::const_iterator it1 = srtp_ciphers_.begin();
    239         it1 != srtp_ciphers_.end(); ++it1) {
    240       for (std::vector<std::string>::const_iterator it2 =
    241               dest_->srtp_ciphers_.begin();
    242           it2 != dest_->srtp_ciphers_.end(); ++it2) {
    243         if (*it1 == *it2) {
    244           chosen_srtp_cipher_ = *it1;
    245           dest_->chosen_srtp_cipher_ = *it2;
    246           return;
    247         }
    248       }
    249     }
    250   }
    251 
    252   virtual bool GetStats(ConnectionInfos* infos) OVERRIDE {
    253     ConnectionInfo info;
    254     infos->clear();
    255     infos->push_back(info);
    256     return true;
    257   }
    258 
    259  private:
    260   enum State { STATE_INIT, STATE_CONNECTING, STATE_CONNECTED };
    261   Transport* transport_;
    262   FakeTransportChannel* dest_;
    263   State state_;
    264   bool async_;
    265   talk_base::SSLIdentity* identity_;
    266   bool do_dtls_;
    267   std::vector<std::string> srtp_ciphers_;
    268   std::string chosen_srtp_cipher_;
    269   IceRole role_;
    270   uint64 tiebreaker_;
    271   IceProtocolType ice_proto_;
    272   std::string ice_ufrag_;
    273   std::string ice_pwd_;
    274   std::string remote_ice_ufrag_;
    275   std::string remote_ice_pwd_;
    276   IceMode remote_ice_mode_;
    277   talk_base::SSLFingerprint dtls_fingerprint_;
    278 };
    279 
    280 // Fake transport class, which can be passed to anything that needs a Transport.
    281 // Can be informed of another FakeTransport via SetDestination (low-tech way
    282 // of doing candidates)
    283 class FakeTransport : public Transport {
    284  public:
    285   typedef std::map<int, FakeTransportChannel*> ChannelMap;
    286   FakeTransport(talk_base::Thread* signaling_thread,
    287                 talk_base::Thread* worker_thread,
    288                 const std::string& content_name,
    289                 PortAllocator* alllocator = NULL)
    290       : Transport(signaling_thread, worker_thread,
    291                   content_name, "test_type", NULL),
    292       dest_(NULL),
    293       async_(false),
    294       identity_(NULL) {
    295   }
    296   ~FakeTransport() {
    297     DestroyAllChannels();
    298   }
    299 
    300   const ChannelMap& channels() const { return channels_; }
    301 
    302   void SetAsync(bool async) { async_ = async; }
    303   void SetDestination(FakeTransport* dest) {
    304     dest_ = dest;
    305     for (ChannelMap::iterator it = channels_.begin(); it != channels_.end();
    306          ++it) {
    307       it->second->SetLocalIdentity(identity_);
    308       SetChannelDestination(it->first, it->second);
    309     }
    310   }
    311 
    312   void SetWritable(bool writable) {
    313     for (ChannelMap::iterator it = channels_.begin(); it != channels_.end();
    314          ++it) {
    315       it->second->SetWritable(writable);
    316     }
    317   }
    318 
    319   void set_identity(talk_base::SSLIdentity* identity) {
    320     identity_ = identity;
    321   }
    322 
    323   using Transport::local_description;
    324   using Transport::remote_description;
    325 
    326  protected:
    327   virtual TransportChannelImpl* CreateTransportChannel(int component) {
    328     if (channels_.find(component) != channels_.end()) {
    329       return NULL;
    330     }
    331     FakeTransportChannel* channel =
    332         new FakeTransportChannel(this, content_name(), component);
    333     channel->SetAsync(async_);
    334     SetChannelDestination(component, channel);
    335     channels_[component] = channel;
    336     return channel;
    337   }
    338   virtual void DestroyTransportChannel(TransportChannelImpl* channel) {
    339     channels_.erase(channel->component());
    340     delete channel;
    341   }
    342 
    343  private:
    344   FakeTransportChannel* GetFakeChannel(int component) {
    345     ChannelMap::iterator it = channels_.find(component);
    346     return (it != channels_.end()) ? it->second : NULL;
    347   }
    348   void SetChannelDestination(int component,
    349                              FakeTransportChannel* channel) {
    350     FakeTransportChannel* dest_channel = NULL;
    351     if (dest_) {
    352       dest_channel = dest_->GetFakeChannel(component);
    353       if (dest_channel) {
    354         dest_channel->SetLocalIdentity(dest_->identity_);
    355       }
    356     }
    357     channel->SetDestination(dest_channel);
    358   }
    359 
    360   // Note, this is distinct from the Channel map owned by Transport.
    361   // This map just tracks the FakeTransportChannels created by this class.
    362   ChannelMap channels_;
    363   FakeTransport* dest_;
    364   bool async_;
    365   talk_base::SSLIdentity* identity_;
    366 };
    367 
    368 // Fake session class, which can be passed into a BaseChannel object for
    369 // test purposes. Can be connected to other FakeSessions via Connect().
    370 class FakeSession : public BaseSession {
    371  public:
    372   explicit FakeSession()
    373       : BaseSession(talk_base::Thread::Current(),
    374                     talk_base::Thread::Current(),
    375                     NULL, "", "", true),
    376       fail_create_channel_(false) {
    377   }
    378   explicit FakeSession(bool initiator)
    379       : BaseSession(talk_base::Thread::Current(),
    380                     talk_base::Thread::Current(),
    381                     NULL, "", "", initiator),
    382       fail_create_channel_(false) {
    383   }
    384 
    385   FakeTransport* GetTransport(const std::string& content_name) {
    386     return static_cast<FakeTransport*>(
    387         BaseSession::GetTransport(content_name));
    388   }
    389 
    390   void Connect(FakeSession* dest) {
    391     // Simulate the exchange of candidates.
    392     CompleteNegotiation();
    393     dest->CompleteNegotiation();
    394     for (TransportMap::const_iterator it = transport_proxies().begin();
    395         it != transport_proxies().end(); ++it) {
    396       static_cast<FakeTransport*>(it->second->impl())->SetDestination(
    397           dest->GetTransport(it->first));
    398     }
    399   }
    400 
    401   virtual TransportChannel* CreateChannel(
    402       const std::string& content_name,
    403       const std::string& channel_name,
    404       int component) {
    405     if (fail_create_channel_) {
    406       return NULL;
    407     }
    408     return BaseSession::CreateChannel(content_name, channel_name, component);
    409   }
    410 
    411   void set_fail_channel_creation(bool fail_channel_creation) {
    412     fail_create_channel_ = fail_channel_creation;
    413   }
    414 
    415   // TODO: Hoist this into Session when we re-work the Session code.
    416   void set_ssl_identity(talk_base::SSLIdentity* identity) {
    417     for (TransportMap::const_iterator it = transport_proxies().begin();
    418         it != transport_proxies().end(); ++it) {
    419       // We know that we have a FakeTransport*
    420 
    421       static_cast<FakeTransport*>(it->second->impl())->set_identity
    422           (identity);
    423     }
    424   }
    425 
    426  protected:
    427   virtual Transport* CreateTransport(const std::string& content_name) {
    428     return new FakeTransport(signaling_thread(), worker_thread(), content_name);
    429   }
    430 
    431   void CompleteNegotiation() {
    432     for (TransportMap::const_iterator it = transport_proxies().begin();
    433         it != transport_proxies().end(); ++it) {
    434       it->second->CompleteNegotiation();
    435       it->second->ConnectChannels();
    436     }
    437   }
    438 
    439  private:
    440   bool fail_create_channel_;
    441 };
    442 
    443 }  // namespace cricket
    444 
    445 #endif  // TALK_P2P_BASE_FAKESESSION_H_
    446