Home | History | Annotate | Download | only in base
      1 /*
      2  * libjingle
      3  * Copyright 2009, Google, Inc.
      4  *
      5  * Redistribution and use in source and binary forms, with or without
      6  * modification, are permitted provided that the following conditions are met:
      7  *
      8  *  1. Redistributions of source code must retain the above copyright notice,
      9  *     this list of conditions and the following disclaimer.
     10  *  2. Redistributions in binary form must reproduce the above copyright notice,
     11  *     this list of conditions and the following disclaimer in the documentation
     12  *     and/or other materials provided with the distribution.
     13  *  3. The name of the author may not be used to endorse or promote products
     14  *     derived from this software without specific prior written permission.
     15  *
     16  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED
     17  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
     18  * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
     19  * EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
     20  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
     21  * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
     22  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
     23  * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
     24  * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
     25  * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
     26  */
     27 
     28 #ifndef TALK_P2P_BASE_FAKESESSION_H_
     29 #define TALK_P2P_BASE_FAKESESSION_H_
     30 
     31 #include <map>
     32 #include <string>
     33 #include <vector>
     34 
     35 #include "talk/base/buffer.h"
     36 #include "talk/base/fakesslidentity.h"
     37 #include "talk/base/sigslot.h"
     38 #include "talk/base/sslfingerprint.h"
     39 #include "talk/base/messagequeue.h"
     40 #include "talk/p2p/base/session.h"
     41 #include "talk/p2p/base/transport.h"
     42 #include "talk/p2p/base/transportchannel.h"
     43 #include "talk/p2p/base/transportchannelimpl.h"
     44 
     45 namespace cricket {
     46 
     47 class FakeTransport;
     48 
     49 struct PacketMessageData : public talk_base::MessageData {
     50   PacketMessageData(const char* data, size_t len) : packet(data, len) {
     51   }
     52   talk_base::Buffer packet;
     53 };
     54 
     55 // Fake transport channel class, which can be passed to anything that needs a
     56 // transport channel. Can be informed of another FakeTransportChannel via
     57 // SetDestination.
     58 class FakeTransportChannel : public TransportChannelImpl,
     59                              public talk_base::MessageHandler {
     60  public:
     61   explicit FakeTransportChannel(Transport* transport,
     62                                 const std::string& content_name,
     63                                 int component)
     64       : TransportChannelImpl(content_name, component),
     65         transport_(transport),
     66         dest_(NULL),
     67         state_(STATE_INIT),
     68         async_(false),
     69         identity_(NULL),
     70         do_dtls_(false),
     71         role_(ICEROLE_UNKNOWN),
     72         tiebreaker_(0),
     73         ice_proto_(ICEPROTO_HYBRID),
     74         remote_ice_mode_(ICEMODE_FULL),
     75         dtls_fingerprint_("", NULL, 0),
     76         ssl_role_(talk_base::SSL_CLIENT) {
     77   }
     78   ~FakeTransportChannel() {
     79     Reset();
     80   }
     81 
     82   uint64 IceTiebreaker() const { return tiebreaker_; }
     83   TransportProtocol protocol() const { return ice_proto_; }
     84   IceMode remote_ice_mode() const { return remote_ice_mode_; }
     85   const std::string& ice_ufrag() const { return ice_ufrag_; }
     86   const std::string& ice_pwd() const { return ice_pwd_; }
     87   const std::string& remote_ice_ufrag() const { return remote_ice_ufrag_; }
     88   const std::string& remote_ice_pwd() const { return remote_ice_pwd_; }
     89   const talk_base::SSLFingerprint& dtls_fingerprint() const {
     90     return dtls_fingerprint_;
     91   }
     92 
     93   void SetAsync(bool async) {
     94     async_ = async;
     95   }
     96 
     97   virtual Transport* GetTransport() {
     98     return transport_;
     99   }
    100 
    101   virtual void SetIceRole(IceRole role) { role_ = role; }
    102   virtual IceRole GetIceRole() const { return role_; }
    103   virtual void SetIceTiebreaker(uint64 tiebreaker) { tiebreaker_ = tiebreaker; }
    104   virtual void SetIceProtocolType(IceProtocolType type) { ice_proto_ = type; }
    105   virtual void SetIceCredentials(const std::string& ice_ufrag,
    106                                  const std::string& ice_pwd) {
    107     ice_ufrag_ = ice_ufrag;
    108     ice_pwd_ = ice_pwd;
    109   }
    110   virtual void SetRemoteIceCredentials(const std::string& ice_ufrag,
    111                                        const std::string& ice_pwd) {
    112     remote_ice_ufrag_ = ice_ufrag;
    113     remote_ice_pwd_ = ice_pwd;
    114   }
    115 
    116   virtual void SetRemoteIceMode(IceMode mode) { remote_ice_mode_ = mode; }
    117   virtual bool SetRemoteFingerprint(const std::string& alg, const uint8* digest,
    118                                     size_t digest_len) {
    119     dtls_fingerprint_ = talk_base::SSLFingerprint(alg, digest, digest_len);
    120     return true;
    121   }
    122   virtual bool SetSslRole(talk_base::SSLRole role) {
    123     ssl_role_ = role;
    124     return true;
    125   }
    126   virtual bool GetSslRole(talk_base::SSLRole* role) const {
    127     *role = ssl_role_;
    128     return true;
    129   }
    130 
    131   virtual void Connect() {
    132     if (state_ == STATE_INIT) {
    133       state_ = STATE_CONNECTING;
    134     }
    135   }
    136   virtual void Reset() {
    137     if (state_ != STATE_INIT) {
    138       state_ = STATE_INIT;
    139       if (dest_) {
    140         dest_->state_ = STATE_INIT;
    141         dest_->dest_ = NULL;
    142         dest_ = NULL;
    143       }
    144     }
    145   }
    146 
    147   void SetWritable(bool writable) {
    148     set_writable(writable);
    149   }
    150 
    151   void SetDestination(FakeTransportChannel* dest) {
    152     if (state_ == STATE_CONNECTING && dest) {
    153       // This simulates the delivery of candidates.
    154       dest_ = dest;
    155       dest_->dest_ = this;
    156       if (identity_ && dest_->identity_) {
    157         do_dtls_ = true;
    158         dest_->do_dtls_ = true;
    159         NegotiateSrtpCiphers();
    160       }
    161       state_ = STATE_CONNECTED;
    162       dest_->state_ = STATE_CONNECTED;
    163       set_writable(true);
    164       dest_->set_writable(true);
    165     } else if (state_ == STATE_CONNECTED && !dest) {
    166       // Simulates loss of connectivity, by asymmetrically forgetting dest_.
    167       dest_ = NULL;
    168       state_ = STATE_CONNECTING;
    169       set_writable(false);
    170     }
    171   }
    172 
    173   virtual int SendPacket(const char* data, size_t len,
    174                          talk_base::DiffServCodePoint dscp, int flags) {
    175     if (state_ != STATE_CONNECTED) {
    176       return -1;
    177     }
    178 
    179     if (flags != PF_SRTP_BYPASS && flags != 0) {
    180       return -1;
    181     }
    182 
    183     PacketMessageData* packet = new PacketMessageData(data, len);
    184     if (async_) {
    185       talk_base::Thread::Current()->Post(this, 0, packet);
    186     } else {
    187       talk_base::Thread::Current()->Send(this, 0, packet);
    188     }
    189     return static_cast<int>(len);
    190   }
    191   virtual int SetOption(talk_base::Socket::Option opt, int value) {
    192     return true;
    193   }
    194   virtual int GetError() {
    195     return 0;
    196   }
    197 
    198   virtual void OnSignalingReady() {
    199   }
    200   virtual void OnCandidate(const Candidate& candidate) {
    201   }
    202 
    203   virtual void OnMessage(talk_base::Message* msg) {
    204     PacketMessageData* data = static_cast<PacketMessageData*>(
    205         msg->pdata);
    206     dest_->SignalReadPacket(dest_, data->packet.data(),
    207                             data->packet.length(),
    208                             talk_base::CreatePacketTime(0), 0);
    209     delete data;
    210   }
    211 
    212   bool SetLocalIdentity(talk_base::SSLIdentity* identity) {
    213     identity_ = identity;
    214     return true;
    215   }
    216 
    217 
    218   void SetRemoteCertificate(talk_base::FakeSSLCertificate* cert) {
    219     remote_cert_ = cert;
    220   }
    221 
    222   virtual bool IsDtlsActive() const {
    223     return do_dtls_;
    224   }
    225 
    226   virtual bool SetSrtpCiphers(const std::vector<std::string>& ciphers) {
    227     srtp_ciphers_ = ciphers;
    228     return true;
    229   }
    230 
    231   virtual bool GetSrtpCipher(std::string* cipher) {
    232     if (!chosen_srtp_cipher_.empty()) {
    233       *cipher = chosen_srtp_cipher_;
    234       return true;
    235     }
    236     return false;
    237   }
    238 
    239   virtual bool GetLocalIdentity(talk_base::SSLIdentity** identity) const {
    240     if (!identity_)
    241       return false;
    242 
    243     *identity = identity_->GetReference();
    244     return true;
    245   }
    246 
    247   virtual bool GetRemoteCertificate(talk_base::SSLCertificate** cert) const {
    248     if (!remote_cert_)
    249       return false;
    250 
    251     *cert = remote_cert_->GetReference();
    252     return true;
    253   }
    254 
    255   virtual bool ExportKeyingMaterial(const std::string& label,
    256                                     const uint8* context,
    257                                     size_t context_len,
    258                                     bool use_context,
    259                                     uint8* result,
    260                                     size_t result_len) {
    261     if (!chosen_srtp_cipher_.empty()) {
    262       memset(result, 0xff, result_len);
    263       return true;
    264     }
    265 
    266     return false;
    267   }
    268 
    269   virtual void NegotiateSrtpCiphers() {
    270     for (std::vector<std::string>::const_iterator it1 = srtp_ciphers_.begin();
    271         it1 != srtp_ciphers_.end(); ++it1) {
    272       for (std::vector<std::string>::const_iterator it2 =
    273               dest_->srtp_ciphers_.begin();
    274           it2 != dest_->srtp_ciphers_.end(); ++it2) {
    275         if (*it1 == *it2) {
    276           chosen_srtp_cipher_ = *it1;
    277           dest_->chosen_srtp_cipher_ = *it2;
    278           return;
    279         }
    280       }
    281     }
    282   }
    283 
    284   virtual bool GetStats(ConnectionInfos* infos) OVERRIDE {
    285     ConnectionInfo info;
    286     infos->clear();
    287     infos->push_back(info);
    288     return true;
    289   }
    290 
    291  private:
    292   enum State { STATE_INIT, STATE_CONNECTING, STATE_CONNECTED };
    293   Transport* transport_;
    294   FakeTransportChannel* dest_;
    295   State state_;
    296   bool async_;
    297   talk_base::SSLIdentity* identity_;
    298   talk_base::FakeSSLCertificate* remote_cert_;
    299   bool do_dtls_;
    300   std::vector<std::string> srtp_ciphers_;
    301   std::string chosen_srtp_cipher_;
    302   IceRole role_;
    303   uint64 tiebreaker_;
    304   IceProtocolType ice_proto_;
    305   std::string ice_ufrag_;
    306   std::string ice_pwd_;
    307   std::string remote_ice_ufrag_;
    308   std::string remote_ice_pwd_;
    309   IceMode remote_ice_mode_;
    310   talk_base::SSLFingerprint dtls_fingerprint_;
    311   talk_base::SSLRole ssl_role_;
    312 };
    313 
    314 // Fake transport class, which can be passed to anything that needs a Transport.
    315 // Can be informed of another FakeTransport via SetDestination (low-tech way
    316 // of doing candidates)
    317 class FakeTransport : public Transport {
    318  public:
    319   typedef std::map<int, FakeTransportChannel*> ChannelMap;
    320   FakeTransport(talk_base::Thread* signaling_thread,
    321                 talk_base::Thread* worker_thread,
    322                 const std::string& content_name,
    323                 PortAllocator* alllocator = NULL)
    324       : Transport(signaling_thread, worker_thread,
    325                   content_name, "test_type", NULL),
    326       dest_(NULL),
    327       async_(false),
    328       identity_(NULL) {
    329   }
    330   ~FakeTransport() {
    331     DestroyAllChannels();
    332   }
    333 
    334   const ChannelMap& channels() const { return channels_; }
    335 
    336   void SetAsync(bool async) { async_ = async; }
    337   void SetDestination(FakeTransport* dest) {
    338     dest_ = dest;
    339     for (ChannelMap::iterator it = channels_.begin(); it != channels_.end();
    340          ++it) {
    341       it->second->SetLocalIdentity(identity_);
    342       SetChannelDestination(it->first, it->second);
    343     }
    344   }
    345 
    346   void SetWritable(bool writable) {
    347     for (ChannelMap::iterator it = channels_.begin(); it != channels_.end();
    348          ++it) {
    349       it->second->SetWritable(writable);
    350     }
    351   }
    352 
    353   void set_identity(talk_base::SSLIdentity* identity) {
    354     identity_ = identity;
    355   }
    356 
    357   using Transport::local_description;
    358   using Transport::remote_description;
    359 
    360  protected:
    361   virtual TransportChannelImpl* CreateTransportChannel(int component) {
    362     if (channels_.find(component) != channels_.end()) {
    363       return NULL;
    364     }
    365     FakeTransportChannel* channel =
    366         new FakeTransportChannel(this, content_name(), component);
    367     channel->SetAsync(async_);
    368     SetChannelDestination(component, channel);
    369     channels_[component] = channel;
    370     return channel;
    371   }
    372   virtual void DestroyTransportChannel(TransportChannelImpl* channel) {
    373     channels_.erase(channel->component());
    374     delete channel;
    375   }
    376   virtual void SetIdentity_w(talk_base::SSLIdentity* identity) {
    377     identity_ = identity;
    378   }
    379   virtual bool GetIdentity_w(talk_base::SSLIdentity** identity) {
    380     if (!identity_)
    381       return false;
    382 
    383     *identity = identity_->GetReference();
    384     return true;
    385   }
    386 
    387  private:
    388   FakeTransportChannel* GetFakeChannel(int component) {
    389     ChannelMap::iterator it = channels_.find(component);
    390     return (it != channels_.end()) ? it->second : NULL;
    391   }
    392   void SetChannelDestination(int component,
    393                              FakeTransportChannel* channel) {
    394     FakeTransportChannel* dest_channel = NULL;
    395     if (dest_) {
    396       dest_channel = dest_->GetFakeChannel(component);
    397       if (dest_channel) {
    398         dest_channel->SetLocalIdentity(dest_->identity_);
    399       }
    400     }
    401     channel->SetDestination(dest_channel);
    402   }
    403 
    404   // Note, this is distinct from the Channel map owned by Transport.
    405   // This map just tracks the FakeTransportChannels created by this class.
    406   ChannelMap channels_;
    407   FakeTransport* dest_;
    408   bool async_;
    409   talk_base::SSLIdentity* identity_;
    410 };
    411 
    412 // Fake session class, which can be passed into a BaseChannel object for
    413 // test purposes. Can be connected to other FakeSessions via Connect().
    414 class FakeSession : public BaseSession {
    415  public:
    416   explicit FakeSession()
    417       : BaseSession(talk_base::Thread::Current(),
    418                     talk_base::Thread::Current(),
    419                     NULL, "", "", true),
    420       fail_create_channel_(false) {
    421   }
    422   explicit FakeSession(bool initiator)
    423       : BaseSession(talk_base::Thread::Current(),
    424                     talk_base::Thread::Current(),
    425                     NULL, "", "", initiator),
    426       fail_create_channel_(false) {
    427   }
    428   FakeSession(talk_base::Thread* worker_thread, bool initiator)
    429       : BaseSession(talk_base::Thread::Current(),
    430                     worker_thread,
    431                     NULL, "", "", initiator),
    432       fail_create_channel_(false) {
    433   }
    434 
    435   FakeTransport* GetTransport(const std::string& content_name) {
    436     return static_cast<FakeTransport*>(
    437         BaseSession::GetTransport(content_name));
    438   }
    439 
    440   void Connect(FakeSession* dest) {
    441     // Simulate the exchange of candidates.
    442     CompleteNegotiation();
    443     dest->CompleteNegotiation();
    444     for (TransportMap::const_iterator it = transport_proxies().begin();
    445         it != transport_proxies().end(); ++it) {
    446       static_cast<FakeTransport*>(it->second->impl())->SetDestination(
    447           dest->GetTransport(it->first));
    448     }
    449   }
    450 
    451   virtual TransportChannel* CreateChannel(
    452       const std::string& content_name,
    453       const std::string& channel_name,
    454       int component) {
    455     if (fail_create_channel_) {
    456       return NULL;
    457     }
    458     return BaseSession::CreateChannel(content_name, channel_name, component);
    459   }
    460 
    461   void set_fail_channel_creation(bool fail_channel_creation) {
    462     fail_create_channel_ = fail_channel_creation;
    463   }
    464 
    465   // TODO: Hoist this into Session when we re-work the Session code.
    466   void set_ssl_identity(talk_base::SSLIdentity* identity) {
    467     for (TransportMap::const_iterator it = transport_proxies().begin();
    468         it != transport_proxies().end(); ++it) {
    469       // We know that we have a FakeTransport*
    470 
    471       static_cast<FakeTransport*>(it->second->impl())->set_identity
    472           (identity);
    473     }
    474   }
    475 
    476  protected:
    477   virtual Transport* CreateTransport(const std::string& content_name) {
    478     return new FakeTransport(signaling_thread(), worker_thread(), content_name);
    479   }
    480 
    481   void CompleteNegotiation() {
    482     for (TransportMap::const_iterator it = transport_proxies().begin();
    483         it != transport_proxies().end(); ++it) {
    484       it->second->CompleteNegotiation();
    485       it->second->ConnectChannels();
    486     }
    487   }
    488 
    489  private:
    490   bool fail_create_channel_;
    491 };
    492 
    493 }  // namespace cricket
    494 
    495 #endif  // TALK_P2P_BASE_FAKESESSION_H_
    496