Home | History | Annotate | Download | only in base
      1 /*
      2  * libjingle
      3  * Copyright 2009, Google, Inc.
      4  *
      5  * Redistribution and use in source and binary forms, with or without
      6  * modification, are permitted provided that the following conditions are met:
      7  *
      8  *  1. Redistributions of source code must retain the above copyright notice,
      9  *     this list of conditions and the following disclaimer.
     10  *  2. Redistributions in binary form must reproduce the above copyright notice,
     11  *     this list of conditions and the following disclaimer in the documentation
     12  *     and/or other materials provided with the distribution.
     13  *  3. The name of the author may not be used to endorse or promote products
     14  *     derived from this software without specific prior written permission.
     15  *
     16  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED
     17  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
     18  * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
     19  * EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
     20  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
     21  * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
     22  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
     23  * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
     24  * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
     25  * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
     26  */
     27 
     28 #ifndef TALK_P2P_BASE_FAKESESSION_H_
     29 #define TALK_P2P_BASE_FAKESESSION_H_
     30 
     31 #include <map>
     32 #include <string>
     33 #include <vector>
     34 
     35 #include "talk/p2p/base/session.h"
     36 #include "talk/p2p/base/transport.h"
     37 #include "talk/p2p/base/transportchannel.h"
     38 #include "talk/p2p/base/transportchannelimpl.h"
     39 #include "webrtc/base/buffer.h"
     40 #include "webrtc/base/fakesslidentity.h"
     41 #include "webrtc/base/messagequeue.h"
     42 #include "webrtc/base/sigslot.h"
     43 #include "webrtc/base/sslfingerprint.h"
     44 
     45 namespace cricket {
     46 
     47 class FakeTransport;
     48 
     49 struct PacketMessageData : public rtc::MessageData {
     50   PacketMessageData(const char* data, size_t len) : packet(data, len) {
     51   }
     52   rtc::Buffer packet;
     53 };
     54 
     55 // Fake transport channel class, which can be passed to anything that needs a
     56 // transport channel. Can be informed of another FakeTransportChannel via
     57 // SetDestination.
     58 class FakeTransportChannel : public TransportChannelImpl,
     59                              public rtc::MessageHandler {
     60  public:
     61   explicit FakeTransportChannel(Transport* transport,
     62                                 const std::string& content_name,
     63                                 int component)
     64       : TransportChannelImpl(content_name, component),
     65         transport_(transport),
     66         dest_(NULL),
     67         state_(STATE_INIT),
     68         async_(false),
     69         identity_(NULL),
     70         do_dtls_(false),
     71         role_(ICEROLE_UNKNOWN),
     72         tiebreaker_(0),
     73         ice_proto_(ICEPROTO_HYBRID),
     74         remote_ice_mode_(ICEMODE_FULL),
     75         dtls_fingerprint_("", NULL, 0),
     76         ssl_role_(rtc::SSL_CLIENT),
     77         connection_count_(0) {
     78   }
     79   ~FakeTransportChannel() {
     80     Reset();
     81   }
     82 
     83   uint64 IceTiebreaker() const { return tiebreaker_; }
     84   TransportProtocol protocol() const { return ice_proto_; }
     85   IceMode remote_ice_mode() const { return remote_ice_mode_; }
     86   const std::string& ice_ufrag() const { return ice_ufrag_; }
     87   const std::string& ice_pwd() const { return ice_pwd_; }
     88   const std::string& remote_ice_ufrag() const { return remote_ice_ufrag_; }
     89   const std::string& remote_ice_pwd() const { return remote_ice_pwd_; }
     90   const rtc::SSLFingerprint& dtls_fingerprint() const {
     91     return dtls_fingerprint_;
     92   }
     93 
     94   void SetAsync(bool async) {
     95     async_ = async;
     96   }
     97 
     98   virtual Transport* GetTransport() {
     99     return transport_;
    100   }
    101 
    102   virtual void SetIceRole(IceRole role) { role_ = role; }
    103   virtual IceRole GetIceRole() const { return role_; }
    104   virtual size_t GetConnectionCount() const { return connection_count_; }
    105   virtual void SetIceTiebreaker(uint64 tiebreaker) { tiebreaker_ = tiebreaker; }
    106   virtual bool GetIceProtocolType(IceProtocolType* type) const {
    107     *type = ice_proto_;
    108     return true;
    109   }
    110   virtual void SetIceProtocolType(IceProtocolType type) { ice_proto_ = type; }
    111   virtual void SetIceCredentials(const std::string& ice_ufrag,
    112                                  const std::string& ice_pwd) {
    113     ice_ufrag_ = ice_ufrag;
    114     ice_pwd_ = ice_pwd;
    115   }
    116   virtual void SetRemoteIceCredentials(const std::string& ice_ufrag,
    117                                        const std::string& ice_pwd) {
    118     remote_ice_ufrag_ = ice_ufrag;
    119     remote_ice_pwd_ = ice_pwd;
    120   }
    121 
    122   virtual void SetRemoteIceMode(IceMode mode) { remote_ice_mode_ = mode; }
    123   virtual bool SetRemoteFingerprint(const std::string& alg, const uint8* digest,
    124                                     size_t digest_len) {
    125     dtls_fingerprint_ = rtc::SSLFingerprint(alg, digest, digest_len);
    126     return true;
    127   }
    128   virtual bool SetSslRole(rtc::SSLRole role) {
    129     ssl_role_ = role;
    130     return true;
    131   }
    132   virtual bool GetSslRole(rtc::SSLRole* role) const {
    133     *role = ssl_role_;
    134     return true;
    135   }
    136 
    137   virtual void Connect() {
    138     if (state_ == STATE_INIT) {
    139       state_ = STATE_CONNECTING;
    140     }
    141   }
    142   virtual void Reset() {
    143     if (state_ != STATE_INIT) {
    144       state_ = STATE_INIT;
    145       if (dest_) {
    146         dest_->state_ = STATE_INIT;
    147         dest_->dest_ = NULL;
    148         dest_ = NULL;
    149       }
    150     }
    151   }
    152 
    153   void SetWritable(bool writable) {
    154     set_writable(writable);
    155   }
    156 
    157   void SetDestination(FakeTransportChannel* dest) {
    158     if (state_ == STATE_CONNECTING && dest) {
    159       // This simulates the delivery of candidates.
    160       dest_ = dest;
    161       dest_->dest_ = this;
    162       if (identity_ && dest_->identity_) {
    163         do_dtls_ = true;
    164         dest_->do_dtls_ = true;
    165         NegotiateSrtpCiphers();
    166       }
    167       state_ = STATE_CONNECTED;
    168       dest_->state_ = STATE_CONNECTED;
    169       set_writable(true);
    170       dest_->set_writable(true);
    171     } else if (state_ == STATE_CONNECTED && !dest) {
    172       // Simulates loss of connectivity, by asymmetrically forgetting dest_.
    173       dest_ = NULL;
    174       state_ = STATE_CONNECTING;
    175       set_writable(false);
    176     }
    177   }
    178 
    179   void SetConnectionCount(size_t connection_count) {
    180     size_t old_connection_count = connection_count_;
    181     connection_count_ = connection_count;
    182     if (connection_count_ < old_connection_count)
    183       SignalConnectionRemoved(this);
    184   }
    185 
    186   virtual int SendPacket(const char* data, size_t len,
    187                          const rtc::PacketOptions& options, int flags) {
    188     if (state_ != STATE_CONNECTED) {
    189       return -1;
    190     }
    191 
    192     if (flags != PF_SRTP_BYPASS && flags != 0) {
    193       return -1;
    194     }
    195 
    196     PacketMessageData* packet = new PacketMessageData(data, len);
    197     if (async_) {
    198       rtc::Thread::Current()->Post(this, 0, packet);
    199     } else {
    200       rtc::Thread::Current()->Send(this, 0, packet);
    201     }
    202     return static_cast<int>(len);
    203   }
    204   virtual int SetOption(rtc::Socket::Option opt, int value) {
    205     return true;
    206   }
    207   virtual int GetError() {
    208     return 0;
    209   }
    210 
    211   virtual void OnSignalingReady() {
    212   }
    213   virtual void OnCandidate(const Candidate& candidate) {
    214   }
    215 
    216   virtual void OnMessage(rtc::Message* msg) {
    217     PacketMessageData* data = static_cast<PacketMessageData*>(
    218         msg->pdata);
    219     dest_->SignalReadPacket(dest_, data->packet.data(),
    220                             data->packet.length(),
    221                             rtc::CreatePacketTime(0), 0);
    222     delete data;
    223   }
    224 
    225   bool SetLocalIdentity(rtc::SSLIdentity* identity) {
    226     identity_ = identity;
    227     return true;
    228   }
    229 
    230 
    231   void SetRemoteCertificate(rtc::FakeSSLCertificate* cert) {
    232     remote_cert_ = cert;
    233   }
    234 
    235   virtual bool IsDtlsActive() const {
    236     return do_dtls_;
    237   }
    238 
    239   virtual bool SetSrtpCiphers(const std::vector<std::string>& ciphers) {
    240     srtp_ciphers_ = ciphers;
    241     return true;
    242   }
    243 
    244   virtual bool GetSrtpCipher(std::string* cipher) {
    245     if (!chosen_srtp_cipher_.empty()) {
    246       *cipher = chosen_srtp_cipher_;
    247       return true;
    248     }
    249     return false;
    250   }
    251 
    252   virtual bool GetLocalIdentity(rtc::SSLIdentity** identity) const {
    253     if (!identity_)
    254       return false;
    255 
    256     *identity = identity_->GetReference();
    257     return true;
    258   }
    259 
    260   virtual bool GetRemoteCertificate(rtc::SSLCertificate** cert) const {
    261     if (!remote_cert_)
    262       return false;
    263 
    264     *cert = remote_cert_->GetReference();
    265     return true;
    266   }
    267 
    268   virtual bool ExportKeyingMaterial(const std::string& label,
    269                                     const uint8* context,
    270                                     size_t context_len,
    271                                     bool use_context,
    272                                     uint8* result,
    273                                     size_t result_len) {
    274     if (!chosen_srtp_cipher_.empty()) {
    275       memset(result, 0xff, result_len);
    276       return true;
    277     }
    278 
    279     return false;
    280   }
    281 
    282   virtual void NegotiateSrtpCiphers() {
    283     for (std::vector<std::string>::const_iterator it1 = srtp_ciphers_.begin();
    284         it1 != srtp_ciphers_.end(); ++it1) {
    285       for (std::vector<std::string>::const_iterator it2 =
    286               dest_->srtp_ciphers_.begin();
    287           it2 != dest_->srtp_ciphers_.end(); ++it2) {
    288         if (*it1 == *it2) {
    289           chosen_srtp_cipher_ = *it1;
    290           dest_->chosen_srtp_cipher_ = *it2;
    291           return;
    292         }
    293       }
    294     }
    295   }
    296 
    297   virtual bool GetStats(ConnectionInfos* infos) OVERRIDE {
    298     ConnectionInfo info;
    299     infos->clear();
    300     infos->push_back(info);
    301     return true;
    302   }
    303 
    304  private:
    305   enum State { STATE_INIT, STATE_CONNECTING, STATE_CONNECTED };
    306   Transport* transport_;
    307   FakeTransportChannel* dest_;
    308   State state_;
    309   bool async_;
    310   rtc::SSLIdentity* identity_;
    311   rtc::FakeSSLCertificate* remote_cert_;
    312   bool do_dtls_;
    313   std::vector<std::string> srtp_ciphers_;
    314   std::string chosen_srtp_cipher_;
    315   IceRole role_;
    316   uint64 tiebreaker_;
    317   IceProtocolType ice_proto_;
    318   std::string ice_ufrag_;
    319   std::string ice_pwd_;
    320   std::string remote_ice_ufrag_;
    321   std::string remote_ice_pwd_;
    322   IceMode remote_ice_mode_;
    323   rtc::SSLFingerprint dtls_fingerprint_;
    324   rtc::SSLRole ssl_role_;
    325   size_t connection_count_;
    326 };
    327 
    328 // Fake transport class, which can be passed to anything that needs a Transport.
    329 // Can be informed of another FakeTransport via SetDestination (low-tech way
    330 // of doing candidates)
    331 class FakeTransport : public Transport {
    332  public:
    333   typedef std::map<int, FakeTransportChannel*> ChannelMap;
    334   FakeTransport(rtc::Thread* signaling_thread,
    335                 rtc::Thread* worker_thread,
    336                 const std::string& content_name,
    337                 PortAllocator* alllocator = NULL)
    338       : Transport(signaling_thread, worker_thread,
    339                   content_name, "test_type", NULL),
    340       dest_(NULL),
    341       async_(false),
    342       identity_(NULL) {
    343   }
    344   ~FakeTransport() {
    345     DestroyAllChannels();
    346   }
    347 
    348   const ChannelMap& channels() const { return channels_; }
    349 
    350   void SetAsync(bool async) { async_ = async; }
    351   void SetDestination(FakeTransport* dest) {
    352     dest_ = dest;
    353     for (ChannelMap::iterator it = channels_.begin(); it != channels_.end();
    354          ++it) {
    355       it->second->SetLocalIdentity(identity_);
    356       SetChannelDestination(it->first, it->second);
    357     }
    358   }
    359 
    360   void SetWritable(bool writable) {
    361     for (ChannelMap::iterator it = channels_.begin(); it != channels_.end();
    362          ++it) {
    363       it->second->SetWritable(writable);
    364     }
    365   }
    366 
    367   void set_identity(rtc::SSLIdentity* identity) {
    368     identity_ = identity;
    369   }
    370 
    371   using Transport::local_description;
    372   using Transport::remote_description;
    373 
    374  protected:
    375   virtual TransportChannelImpl* CreateTransportChannel(int component) {
    376     if (channels_.find(component) != channels_.end()) {
    377       return NULL;
    378     }
    379     FakeTransportChannel* channel =
    380         new FakeTransportChannel(this, content_name(), component);
    381     channel->SetAsync(async_);
    382     SetChannelDestination(component, channel);
    383     channels_[component] = channel;
    384     return channel;
    385   }
    386   virtual void DestroyTransportChannel(TransportChannelImpl* channel) {
    387     channels_.erase(channel->component());
    388     delete channel;
    389   }
    390   virtual void SetIdentity_w(rtc::SSLIdentity* identity) {
    391     identity_ = identity;
    392   }
    393   virtual bool GetIdentity_w(rtc::SSLIdentity** identity) {
    394     if (!identity_)
    395       return false;
    396 
    397     *identity = identity_->GetReference();
    398     return true;
    399   }
    400 
    401  private:
    402   FakeTransportChannel* GetFakeChannel(int component) {
    403     ChannelMap::iterator it = channels_.find(component);
    404     return (it != channels_.end()) ? it->second : NULL;
    405   }
    406   void SetChannelDestination(int component,
    407                              FakeTransportChannel* channel) {
    408     FakeTransportChannel* dest_channel = NULL;
    409     if (dest_) {
    410       dest_channel = dest_->GetFakeChannel(component);
    411       if (dest_channel) {
    412         dest_channel->SetLocalIdentity(dest_->identity_);
    413       }
    414     }
    415     channel->SetDestination(dest_channel);
    416   }
    417 
    418   // Note, this is distinct from the Channel map owned by Transport.
    419   // This map just tracks the FakeTransportChannels created by this class.
    420   ChannelMap channels_;
    421   FakeTransport* dest_;
    422   bool async_;
    423   rtc::SSLIdentity* identity_;
    424 };
    425 
    426 // Fake session class, which can be passed into a BaseChannel object for
    427 // test purposes. Can be connected to other FakeSessions via Connect().
    428 class FakeSession : public BaseSession {
    429  public:
    430   explicit FakeSession()
    431       : BaseSession(rtc::Thread::Current(),
    432                     rtc::Thread::Current(),
    433                     NULL, "", "", true),
    434       fail_create_channel_(false) {
    435   }
    436   explicit FakeSession(bool initiator)
    437       : BaseSession(rtc::Thread::Current(),
    438                     rtc::Thread::Current(),
    439                     NULL, "", "", initiator),
    440       fail_create_channel_(false) {
    441   }
    442   FakeSession(rtc::Thread* worker_thread, bool initiator)
    443       : BaseSession(rtc::Thread::Current(),
    444                     worker_thread,
    445                     NULL, "", "", initiator),
    446       fail_create_channel_(false) {
    447   }
    448 
    449   FakeTransport* GetTransport(const std::string& content_name) {
    450     return static_cast<FakeTransport*>(
    451         BaseSession::GetTransport(content_name));
    452   }
    453 
    454   void Connect(FakeSession* dest) {
    455     // Simulate the exchange of candidates.
    456     CompleteNegotiation();
    457     dest->CompleteNegotiation();
    458     for (TransportMap::const_iterator it = transport_proxies().begin();
    459         it != transport_proxies().end(); ++it) {
    460       static_cast<FakeTransport*>(it->second->impl())->SetDestination(
    461           dest->GetTransport(it->first));
    462     }
    463   }
    464 
    465   virtual TransportChannel* CreateChannel(
    466       const std::string& content_name,
    467       const std::string& channel_name,
    468       int component) {
    469     if (fail_create_channel_) {
    470       return NULL;
    471     }
    472     return BaseSession::CreateChannel(content_name, channel_name, component);
    473   }
    474 
    475   void set_fail_channel_creation(bool fail_channel_creation) {
    476     fail_create_channel_ = fail_channel_creation;
    477   }
    478 
    479   // TODO: Hoist this into Session when we re-work the Session code.
    480   void set_ssl_identity(rtc::SSLIdentity* identity) {
    481     for (TransportMap::const_iterator it = transport_proxies().begin();
    482         it != transport_proxies().end(); ++it) {
    483       // We know that we have a FakeTransport*
    484 
    485       static_cast<FakeTransport*>(it->second->impl())->set_identity
    486           (identity);
    487     }
    488   }
    489 
    490  protected:
    491   virtual Transport* CreateTransport(const std::string& content_name) {
    492     return new FakeTransport(signaling_thread(), worker_thread(), content_name);
    493   }
    494 
    495   void CompleteNegotiation() {
    496     for (TransportMap::const_iterator it = transport_proxies().begin();
    497         it != transport_proxies().end(); ++it) {
    498       it->second->CompleteNegotiation();
    499       it->second->ConnectChannels();
    500     }
    501   }
    502 
    503  private:
    504   bool fail_create_channel_;
    505 };
    506 
    507 }  // namespace cricket
    508 
    509 #endif  // TALK_P2P_BASE_FAKESESSION_H_
    510