1 /* 2 * libjingle 3 * Copyright 2009, Google, Inc. 4 * 5 * Redistribution and use in source and binary forms, with or without 6 * modification, are permitted provided that the following conditions are met: 7 * 8 * 1. Redistributions of source code must retain the above copyright notice, 9 * this list of conditions and the following disclaimer. 10 * 2. Redistributions in binary form must reproduce the above copyright notice, 11 * this list of conditions and the following disclaimer in the documentation 12 * and/or other materials provided with the distribution. 13 * 3. The name of the author may not be used to endorse or promote products 14 * derived from this software without specific prior written permission. 15 * 16 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED 17 * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF 18 * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO 19 * EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 20 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 22 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 23 * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 24 * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 25 * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 */ 27 28 #ifndef TALK_P2P_BASE_FAKESESSION_H_ 29 #define TALK_P2P_BASE_FAKESESSION_H_ 30 31 #include <map> 32 #include <string> 33 #include <vector> 34 35 #include "talk/p2p/base/session.h" 36 #include "talk/p2p/base/transport.h" 37 #include "talk/p2p/base/transportchannel.h" 38 #include "talk/p2p/base/transportchannelimpl.h" 39 #include "webrtc/base/buffer.h" 40 #include "webrtc/base/fakesslidentity.h" 41 #include "webrtc/base/messagequeue.h" 42 #include "webrtc/base/sigslot.h" 43 #include "webrtc/base/sslfingerprint.h" 44 45 namespace cricket { 46 47 class FakeTransport; 48 49 struct PacketMessageData : public rtc::MessageData { 50 PacketMessageData(const char* data, size_t len) : packet(data, len) { 51 } 52 rtc::Buffer packet; 53 }; 54 55 // Fake transport channel class, which can be passed to anything that needs a 56 // transport channel. Can be informed of another FakeTransportChannel via 57 // SetDestination. 58 class FakeTransportChannel : public TransportChannelImpl, 59 public rtc::MessageHandler { 60 public: 61 explicit FakeTransportChannel(Transport* transport, 62 const std::string& content_name, 63 int component) 64 : TransportChannelImpl(content_name, component), 65 transport_(transport), 66 dest_(NULL), 67 state_(STATE_INIT), 68 async_(false), 69 identity_(NULL), 70 do_dtls_(false), 71 role_(ICEROLE_UNKNOWN), 72 tiebreaker_(0), 73 ice_proto_(ICEPROTO_HYBRID), 74 remote_ice_mode_(ICEMODE_FULL), 75 dtls_fingerprint_("", NULL, 0), 76 ssl_role_(rtc::SSL_CLIENT), 77 connection_count_(0) { 78 } 79 ~FakeTransportChannel() { 80 Reset(); 81 } 82 83 uint64 IceTiebreaker() const { return tiebreaker_; } 84 TransportProtocol protocol() const { return ice_proto_; } 85 IceMode remote_ice_mode() const { return remote_ice_mode_; } 86 const std::string& ice_ufrag() const { return ice_ufrag_; } 87 const std::string& ice_pwd() const { return ice_pwd_; } 88 const std::string& remote_ice_ufrag() const { return remote_ice_ufrag_; } 89 const std::string& remote_ice_pwd() const { return remote_ice_pwd_; } 90 const rtc::SSLFingerprint& dtls_fingerprint() const { 91 return dtls_fingerprint_; 92 } 93 94 void SetAsync(bool async) { 95 async_ = async; 96 } 97 98 virtual Transport* GetTransport() { 99 return transport_; 100 } 101 102 virtual void SetIceRole(IceRole role) { role_ = role; } 103 virtual IceRole GetIceRole() const { return role_; } 104 virtual size_t GetConnectionCount() const { return connection_count_; } 105 virtual void SetIceTiebreaker(uint64 tiebreaker) { tiebreaker_ = tiebreaker; } 106 virtual bool GetIceProtocolType(IceProtocolType* type) const { 107 *type = ice_proto_; 108 return true; 109 } 110 virtual void SetIceProtocolType(IceProtocolType type) { ice_proto_ = type; } 111 virtual void SetIceCredentials(const std::string& ice_ufrag, 112 const std::string& ice_pwd) { 113 ice_ufrag_ = ice_ufrag; 114 ice_pwd_ = ice_pwd; 115 } 116 virtual void SetRemoteIceCredentials(const std::string& ice_ufrag, 117 const std::string& ice_pwd) { 118 remote_ice_ufrag_ = ice_ufrag; 119 remote_ice_pwd_ = ice_pwd; 120 } 121 122 virtual void SetRemoteIceMode(IceMode mode) { remote_ice_mode_ = mode; } 123 virtual bool SetRemoteFingerprint(const std::string& alg, const uint8* digest, 124 size_t digest_len) { 125 dtls_fingerprint_ = rtc::SSLFingerprint(alg, digest, digest_len); 126 return true; 127 } 128 virtual bool SetSslRole(rtc::SSLRole role) { 129 ssl_role_ = role; 130 return true; 131 } 132 virtual bool GetSslRole(rtc::SSLRole* role) const { 133 *role = ssl_role_; 134 return true; 135 } 136 137 virtual void Connect() { 138 if (state_ == STATE_INIT) { 139 state_ = STATE_CONNECTING; 140 } 141 } 142 virtual void Reset() { 143 if (state_ != STATE_INIT) { 144 state_ = STATE_INIT; 145 if (dest_) { 146 dest_->state_ = STATE_INIT; 147 dest_->dest_ = NULL; 148 dest_ = NULL; 149 } 150 } 151 } 152 153 void SetWritable(bool writable) { 154 set_writable(writable); 155 } 156 157 void SetDestination(FakeTransportChannel* dest) { 158 if (state_ == STATE_CONNECTING && dest) { 159 // This simulates the delivery of candidates. 160 dest_ = dest; 161 dest_->dest_ = this; 162 if (identity_ && dest_->identity_) { 163 do_dtls_ = true; 164 dest_->do_dtls_ = true; 165 NegotiateSrtpCiphers(); 166 } 167 state_ = STATE_CONNECTED; 168 dest_->state_ = STATE_CONNECTED; 169 set_writable(true); 170 dest_->set_writable(true); 171 } else if (state_ == STATE_CONNECTED && !dest) { 172 // Simulates loss of connectivity, by asymmetrically forgetting dest_. 173 dest_ = NULL; 174 state_ = STATE_CONNECTING; 175 set_writable(false); 176 } 177 } 178 179 void SetConnectionCount(size_t connection_count) { 180 size_t old_connection_count = connection_count_; 181 connection_count_ = connection_count; 182 if (connection_count_ < old_connection_count) 183 SignalConnectionRemoved(this); 184 } 185 186 virtual int SendPacket(const char* data, size_t len, 187 const rtc::PacketOptions& options, int flags) { 188 if (state_ != STATE_CONNECTED) { 189 return -1; 190 } 191 192 if (flags != PF_SRTP_BYPASS && flags != 0) { 193 return -1; 194 } 195 196 PacketMessageData* packet = new PacketMessageData(data, len); 197 if (async_) { 198 rtc::Thread::Current()->Post(this, 0, packet); 199 } else { 200 rtc::Thread::Current()->Send(this, 0, packet); 201 } 202 return static_cast<int>(len); 203 } 204 virtual int SetOption(rtc::Socket::Option opt, int value) { 205 return true; 206 } 207 virtual int GetError() { 208 return 0; 209 } 210 211 virtual void OnSignalingReady() { 212 } 213 virtual void OnCandidate(const Candidate& candidate) { 214 } 215 216 virtual void OnMessage(rtc::Message* msg) { 217 PacketMessageData* data = static_cast<PacketMessageData*>( 218 msg->pdata); 219 dest_->SignalReadPacket(dest_, data->packet.data(), 220 data->packet.length(), 221 rtc::CreatePacketTime(0), 0); 222 delete data; 223 } 224 225 bool SetLocalIdentity(rtc::SSLIdentity* identity) { 226 identity_ = identity; 227 return true; 228 } 229 230 231 void SetRemoteCertificate(rtc::FakeSSLCertificate* cert) { 232 remote_cert_ = cert; 233 } 234 235 virtual bool IsDtlsActive() const { 236 return do_dtls_; 237 } 238 239 virtual bool SetSrtpCiphers(const std::vector<std::string>& ciphers) { 240 srtp_ciphers_ = ciphers; 241 return true; 242 } 243 244 virtual bool GetSrtpCipher(std::string* cipher) { 245 if (!chosen_srtp_cipher_.empty()) { 246 *cipher = chosen_srtp_cipher_; 247 return true; 248 } 249 return false; 250 } 251 252 virtual bool GetLocalIdentity(rtc::SSLIdentity** identity) const { 253 if (!identity_) 254 return false; 255 256 *identity = identity_->GetReference(); 257 return true; 258 } 259 260 virtual bool GetRemoteCertificate(rtc::SSLCertificate** cert) const { 261 if (!remote_cert_) 262 return false; 263 264 *cert = remote_cert_->GetReference(); 265 return true; 266 } 267 268 virtual bool ExportKeyingMaterial(const std::string& label, 269 const uint8* context, 270 size_t context_len, 271 bool use_context, 272 uint8* result, 273 size_t result_len) { 274 if (!chosen_srtp_cipher_.empty()) { 275 memset(result, 0xff, result_len); 276 return true; 277 } 278 279 return false; 280 } 281 282 virtual void NegotiateSrtpCiphers() { 283 for (std::vector<std::string>::const_iterator it1 = srtp_ciphers_.begin(); 284 it1 != srtp_ciphers_.end(); ++it1) { 285 for (std::vector<std::string>::const_iterator it2 = 286 dest_->srtp_ciphers_.begin(); 287 it2 != dest_->srtp_ciphers_.end(); ++it2) { 288 if (*it1 == *it2) { 289 chosen_srtp_cipher_ = *it1; 290 dest_->chosen_srtp_cipher_ = *it2; 291 return; 292 } 293 } 294 } 295 } 296 297 virtual bool GetStats(ConnectionInfos* infos) OVERRIDE { 298 ConnectionInfo info; 299 infos->clear(); 300 infos->push_back(info); 301 return true; 302 } 303 304 private: 305 enum State { STATE_INIT, STATE_CONNECTING, STATE_CONNECTED }; 306 Transport* transport_; 307 FakeTransportChannel* dest_; 308 State state_; 309 bool async_; 310 rtc::SSLIdentity* identity_; 311 rtc::FakeSSLCertificate* remote_cert_; 312 bool do_dtls_; 313 std::vector<std::string> srtp_ciphers_; 314 std::string chosen_srtp_cipher_; 315 IceRole role_; 316 uint64 tiebreaker_; 317 IceProtocolType ice_proto_; 318 std::string ice_ufrag_; 319 std::string ice_pwd_; 320 std::string remote_ice_ufrag_; 321 std::string remote_ice_pwd_; 322 IceMode remote_ice_mode_; 323 rtc::SSLFingerprint dtls_fingerprint_; 324 rtc::SSLRole ssl_role_; 325 size_t connection_count_; 326 }; 327 328 // Fake transport class, which can be passed to anything that needs a Transport. 329 // Can be informed of another FakeTransport via SetDestination (low-tech way 330 // of doing candidates) 331 class FakeTransport : public Transport { 332 public: 333 typedef std::map<int, FakeTransportChannel*> ChannelMap; 334 FakeTransport(rtc::Thread* signaling_thread, 335 rtc::Thread* worker_thread, 336 const std::string& content_name, 337 PortAllocator* alllocator = NULL) 338 : Transport(signaling_thread, worker_thread, 339 content_name, "test_type", NULL), 340 dest_(NULL), 341 async_(false), 342 identity_(NULL) { 343 } 344 ~FakeTransport() { 345 DestroyAllChannels(); 346 } 347 348 const ChannelMap& channels() const { return channels_; } 349 350 void SetAsync(bool async) { async_ = async; } 351 void SetDestination(FakeTransport* dest) { 352 dest_ = dest; 353 for (ChannelMap::iterator it = channels_.begin(); it != channels_.end(); 354 ++it) { 355 it->second->SetLocalIdentity(identity_); 356 SetChannelDestination(it->first, it->second); 357 } 358 } 359 360 void SetWritable(bool writable) { 361 for (ChannelMap::iterator it = channels_.begin(); it != channels_.end(); 362 ++it) { 363 it->second->SetWritable(writable); 364 } 365 } 366 367 void set_identity(rtc::SSLIdentity* identity) { 368 identity_ = identity; 369 } 370 371 using Transport::local_description; 372 using Transport::remote_description; 373 374 protected: 375 virtual TransportChannelImpl* CreateTransportChannel(int component) { 376 if (channels_.find(component) != channels_.end()) { 377 return NULL; 378 } 379 FakeTransportChannel* channel = 380 new FakeTransportChannel(this, content_name(), component); 381 channel->SetAsync(async_); 382 SetChannelDestination(component, channel); 383 channels_[component] = channel; 384 return channel; 385 } 386 virtual void DestroyTransportChannel(TransportChannelImpl* channel) { 387 channels_.erase(channel->component()); 388 delete channel; 389 } 390 virtual void SetIdentity_w(rtc::SSLIdentity* identity) { 391 identity_ = identity; 392 } 393 virtual bool GetIdentity_w(rtc::SSLIdentity** identity) { 394 if (!identity_) 395 return false; 396 397 *identity = identity_->GetReference(); 398 return true; 399 } 400 401 private: 402 FakeTransportChannel* GetFakeChannel(int component) { 403 ChannelMap::iterator it = channels_.find(component); 404 return (it != channels_.end()) ? it->second : NULL; 405 } 406 void SetChannelDestination(int component, 407 FakeTransportChannel* channel) { 408 FakeTransportChannel* dest_channel = NULL; 409 if (dest_) { 410 dest_channel = dest_->GetFakeChannel(component); 411 if (dest_channel) { 412 dest_channel->SetLocalIdentity(dest_->identity_); 413 } 414 } 415 channel->SetDestination(dest_channel); 416 } 417 418 // Note, this is distinct from the Channel map owned by Transport. 419 // This map just tracks the FakeTransportChannels created by this class. 420 ChannelMap channels_; 421 FakeTransport* dest_; 422 bool async_; 423 rtc::SSLIdentity* identity_; 424 }; 425 426 // Fake session class, which can be passed into a BaseChannel object for 427 // test purposes. Can be connected to other FakeSessions via Connect(). 428 class FakeSession : public BaseSession { 429 public: 430 explicit FakeSession() 431 : BaseSession(rtc::Thread::Current(), 432 rtc::Thread::Current(), 433 NULL, "", "", true), 434 fail_create_channel_(false) { 435 } 436 explicit FakeSession(bool initiator) 437 : BaseSession(rtc::Thread::Current(), 438 rtc::Thread::Current(), 439 NULL, "", "", initiator), 440 fail_create_channel_(false) { 441 } 442 FakeSession(rtc::Thread* worker_thread, bool initiator) 443 : BaseSession(rtc::Thread::Current(), 444 worker_thread, 445 NULL, "", "", initiator), 446 fail_create_channel_(false) { 447 } 448 449 FakeTransport* GetTransport(const std::string& content_name) { 450 return static_cast<FakeTransport*>( 451 BaseSession::GetTransport(content_name)); 452 } 453 454 void Connect(FakeSession* dest) { 455 // Simulate the exchange of candidates. 456 CompleteNegotiation(); 457 dest->CompleteNegotiation(); 458 for (TransportMap::const_iterator it = transport_proxies().begin(); 459 it != transport_proxies().end(); ++it) { 460 static_cast<FakeTransport*>(it->second->impl())->SetDestination( 461 dest->GetTransport(it->first)); 462 } 463 } 464 465 virtual TransportChannel* CreateChannel( 466 const std::string& content_name, 467 const std::string& channel_name, 468 int component) { 469 if (fail_create_channel_) { 470 return NULL; 471 } 472 return BaseSession::CreateChannel(content_name, channel_name, component); 473 } 474 475 void set_fail_channel_creation(bool fail_channel_creation) { 476 fail_create_channel_ = fail_channel_creation; 477 } 478 479 // TODO: Hoist this into Session when we re-work the Session code. 480 void set_ssl_identity(rtc::SSLIdentity* identity) { 481 for (TransportMap::const_iterator it = transport_proxies().begin(); 482 it != transport_proxies().end(); ++it) { 483 // We know that we have a FakeTransport* 484 485 static_cast<FakeTransport*>(it->second->impl())->set_identity 486 (identity); 487 } 488 } 489 490 protected: 491 virtual Transport* CreateTransport(const std::string& content_name) { 492 return new FakeTransport(signaling_thread(), worker_thread(), content_name); 493 } 494 495 void CompleteNegotiation() { 496 for (TransportMap::const_iterator it = transport_proxies().begin(); 497 it != transport_proxies().end(); ++it) { 498 it->second->CompleteNegotiation(); 499 it->second->ConnectChannels(); 500 } 501 } 502 503 private: 504 bool fail_create_channel_; 505 }; 506 507 } // namespace cricket 508 509 #endif // TALK_P2P_BASE_FAKESESSION_H_ 510