/*
 * libjingle
 * Copyright 2009, Google, Inc.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 * 3. The name of the author may not be used to endorse or promote products
 * derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED
 * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
 * EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
 * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
 * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
 * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 */ 27 28 #ifndef TALK_P2P_BASE_FAKESESSION_H_ 29 #define TALK_P2P_BASE_FAKESESSION_H_ 30 31 #include <map> 32 #include <string> 33 #include <vector> 34 35 #include "talk/base/buffer.h" 36 #include "talk/base/fakesslidentity.h" 37 #include "talk/base/sigslot.h" 38 #include "talk/base/sslfingerprint.h" 39 #include "talk/base/messagequeue.h" 40 #include "talk/p2p/base/session.h" 41 #include "talk/p2p/base/transport.h" 42 #include "talk/p2p/base/transportchannel.h" 43 #include "talk/p2p/base/transportchannelimpl.h" 44 45 namespace cricket { 46 47 class FakeTransport; 48 49 struct PacketMessageData : public talk_base::MessageData { 50 PacketMessageData(const char* data, size_t len) : packet(data, len) { 51 } 52 talk_base::Buffer packet; 53 }; 54 55 // Fake transport channel class, which can be passed to anything that needs a 56 // transport channel. Can be informed of another FakeTransportChannel via 57 // SetDestination. 58 class FakeTransportChannel : public TransportChannelImpl, 59 public talk_base::MessageHandler { 60 public: 61 explicit FakeTransportChannel(Transport* transport, 62 const std::string& content_name, 63 int component) 64 : TransportChannelImpl(content_name, component), 65 transport_(transport), 66 dest_(NULL), 67 state_(STATE_INIT), 68 async_(false), 69 identity_(NULL), 70 do_dtls_(false), 71 role_(ICEROLE_UNKNOWN), 72 tiebreaker_(0), 73 ice_proto_(ICEPROTO_HYBRID), 74 remote_ice_mode_(ICEMODE_FULL), 75 dtls_fingerprint_("", NULL, 0), 76 ssl_role_(talk_base::SSL_CLIENT) { 77 } 78 ~FakeTransportChannel() { 79 Reset(); 80 } 81 82 uint64 IceTiebreaker() const { return tiebreaker_; } 83 TransportProtocol protocol() const { return ice_proto_; } 84 IceMode remote_ice_mode() const { return remote_ice_mode_; } 85 const std::string& ice_ufrag() const { return ice_ufrag_; } 86 const std::string& ice_pwd() const { return ice_pwd_; } 87 const std::string& remote_ice_ufrag() const { return remote_ice_ufrag_; } 88 const std::string& remote_ice_pwd() const { 
return remote_ice_pwd_; } 89 const talk_base::SSLFingerprint& dtls_fingerprint() const { 90 return dtls_fingerprint_; 91 } 92 93 void SetAsync(bool async) { 94 async_ = async; 95 } 96 97 virtual Transport* GetTransport() { 98 return transport_; 99 } 100 101 virtual void SetIceRole(IceRole role) { role_ = role; } 102 virtual IceRole GetIceRole() const { return role_; } 103 virtual void SetIceTiebreaker(uint64 tiebreaker) { tiebreaker_ = tiebreaker; } 104 virtual void SetIceProtocolType(IceProtocolType type) { ice_proto_ = type; } 105 virtual void SetIceCredentials(const std::string& ice_ufrag, 106 const std::string& ice_pwd) { 107 ice_ufrag_ = ice_ufrag; 108 ice_pwd_ = ice_pwd; 109 } 110 virtual void SetRemoteIceCredentials(const std::string& ice_ufrag, 111 const std::string& ice_pwd) { 112 remote_ice_ufrag_ = ice_ufrag; 113 remote_ice_pwd_ = ice_pwd; 114 } 115 116 virtual void SetRemoteIceMode(IceMode mode) { remote_ice_mode_ = mode; } 117 virtual bool SetRemoteFingerprint(const std::string& alg, const uint8* digest, 118 size_t digest_len) { 119 dtls_fingerprint_ = talk_base::SSLFingerprint(alg, digest, digest_len); 120 return true; 121 } 122 virtual bool SetSslRole(talk_base::SSLRole role) { 123 ssl_role_ = role; 124 return true; 125 } 126 virtual bool GetSslRole(talk_base::SSLRole* role) const { 127 *role = ssl_role_; 128 return true; 129 } 130 131 virtual void Connect() { 132 if (state_ == STATE_INIT) { 133 state_ = STATE_CONNECTING; 134 } 135 } 136 virtual void Reset() { 137 if (state_ != STATE_INIT) { 138 state_ = STATE_INIT; 139 if (dest_) { 140 dest_->state_ = STATE_INIT; 141 dest_->dest_ = NULL; 142 dest_ = NULL; 143 } 144 } 145 } 146 147 void SetWritable(bool writable) { 148 set_writable(writable); 149 } 150 151 void SetDestination(FakeTransportChannel* dest) { 152 if (state_ == STATE_CONNECTING && dest) { 153 // This simulates the delivery of candidates. 
154 dest_ = dest; 155 dest_->dest_ = this; 156 if (identity_ && dest_->identity_) { 157 do_dtls_ = true; 158 dest_->do_dtls_ = true; 159 NegotiateSrtpCiphers(); 160 } 161 state_ = STATE_CONNECTED; 162 dest_->state_ = STATE_CONNECTED; 163 set_writable(true); 164 dest_->set_writable(true); 165 } else if (state_ == STATE_CONNECTED && !dest) { 166 // Simulates loss of connectivity, by asymmetrically forgetting dest_. 167 dest_ = NULL; 168 state_ = STATE_CONNECTING; 169 set_writable(false); 170 } 171 } 172 173 virtual int SendPacket(const char* data, size_t len, 174 talk_base::DiffServCodePoint dscp, int flags) { 175 if (state_ != STATE_CONNECTED) { 176 return -1; 177 } 178 179 if (flags != PF_SRTP_BYPASS && flags != 0) { 180 return -1; 181 } 182 183 PacketMessageData* packet = new PacketMessageData(data, len); 184 if (async_) { 185 talk_base::Thread::Current()->Post(this, 0, packet); 186 } else { 187 talk_base::Thread::Current()->Send(this, 0, packet); 188 } 189 return static_cast<int>(len); 190 } 191 virtual int SetOption(talk_base::Socket::Option opt, int value) { 192 return true; 193 } 194 virtual int GetError() { 195 return 0; 196 } 197 198 virtual void OnSignalingReady() { 199 } 200 virtual void OnCandidate(const Candidate& candidate) { 201 } 202 203 virtual void OnMessage(talk_base::Message* msg) { 204 PacketMessageData* data = static_cast<PacketMessageData*>( 205 msg->pdata); 206 dest_->SignalReadPacket(dest_, data->packet.data(), 207 data->packet.length(), 208 talk_base::CreatePacketTime(0), 0); 209 delete data; 210 } 211 212 bool SetLocalIdentity(talk_base::SSLIdentity* identity) { 213 identity_ = identity; 214 return true; 215 } 216 217 218 void SetRemoteCertificate(talk_base::FakeSSLCertificate* cert) { 219 remote_cert_ = cert; 220 } 221 222 virtual bool IsDtlsActive() const { 223 return do_dtls_; 224 } 225 226 virtual bool SetSrtpCiphers(const std::vector<std::string>& ciphers) { 227 srtp_ciphers_ = ciphers; 228 return true; 229 } 230 231 virtual bool 
GetSrtpCipher(std::string* cipher) { 232 if (!chosen_srtp_cipher_.empty()) { 233 *cipher = chosen_srtp_cipher_; 234 return true; 235 } 236 return false; 237 } 238 239 virtual bool GetLocalIdentity(talk_base::SSLIdentity** identity) const { 240 if (!identity_) 241 return false; 242 243 *identity = identity_->GetReference(); 244 return true; 245 } 246 247 virtual bool GetRemoteCertificate(talk_base::SSLCertificate** cert) const { 248 if (!remote_cert_) 249 return false; 250 251 *cert = remote_cert_->GetReference(); 252 return true; 253 } 254 255 virtual bool ExportKeyingMaterial(const std::string& label, 256 const uint8* context, 257 size_t context_len, 258 bool use_context, 259 uint8* result, 260 size_t result_len) { 261 if (!chosen_srtp_cipher_.empty()) { 262 memset(result, 0xff, result_len); 263 return true; 264 } 265 266 return false; 267 } 268 269 virtual void NegotiateSrtpCiphers() { 270 for (std::vector<std::string>::const_iterator it1 = srtp_ciphers_.begin(); 271 it1 != srtp_ciphers_.end(); ++it1) { 272 for (std::vector<std::string>::const_iterator it2 = 273 dest_->srtp_ciphers_.begin(); 274 it2 != dest_->srtp_ciphers_.end(); ++it2) { 275 if (*it1 == *it2) { 276 chosen_srtp_cipher_ = *it1; 277 dest_->chosen_srtp_cipher_ = *it2; 278 return; 279 } 280 } 281 } 282 } 283 284 virtual bool GetStats(ConnectionInfos* infos) OVERRIDE { 285 ConnectionInfo info; 286 infos->clear(); 287 infos->push_back(info); 288 return true; 289 } 290 291 private: 292 enum State { STATE_INIT, STATE_CONNECTING, STATE_CONNECTED }; 293 Transport* transport_; 294 FakeTransportChannel* dest_; 295 State state_; 296 bool async_; 297 talk_base::SSLIdentity* identity_; 298 talk_base::FakeSSLCertificate* remote_cert_; 299 bool do_dtls_; 300 std::vector<std::string> srtp_ciphers_; 301 std::string chosen_srtp_cipher_; 302 IceRole role_; 303 uint64 tiebreaker_; 304 IceProtocolType ice_proto_; 305 std::string ice_ufrag_; 306 std::string ice_pwd_; 307 std::string remote_ice_ufrag_; 308 std::string 
remote_ice_pwd_; 309 IceMode remote_ice_mode_; 310 talk_base::SSLFingerprint dtls_fingerprint_; 311 talk_base::SSLRole ssl_role_; 312 }; 313 314 // Fake transport class, which can be passed to anything that needs a Transport. 315 // Can be informed of another FakeTransport via SetDestination (low-tech way 316 // of doing candidates) 317 class FakeTransport : public Transport { 318 public: 319 typedef std::map<int, FakeTransportChannel*> ChannelMap; 320 FakeTransport(talk_base::Thread* signaling_thread, 321 talk_base::Thread* worker_thread, 322 const std::string& content_name, 323 PortAllocator* alllocator = NULL) 324 : Transport(signaling_thread, worker_thread, 325 content_name, "test_type", NULL), 326 dest_(NULL), 327 async_(false), 328 identity_(NULL) { 329 } 330 ~FakeTransport() { 331 DestroyAllChannels(); 332 } 333 334 const ChannelMap& channels() const { return channels_; } 335 336 void SetAsync(bool async) { async_ = async; } 337 void SetDestination(FakeTransport* dest) { 338 dest_ = dest; 339 for (ChannelMap::iterator it = channels_.begin(); it != channels_.end(); 340 ++it) { 341 it->second->SetLocalIdentity(identity_); 342 SetChannelDestination(it->first, it->second); 343 } 344 } 345 346 void SetWritable(bool writable) { 347 for (ChannelMap::iterator it = channels_.begin(); it != channels_.end(); 348 ++it) { 349 it->second->SetWritable(writable); 350 } 351 } 352 353 void set_identity(talk_base::SSLIdentity* identity) { 354 identity_ = identity; 355 } 356 357 using Transport::local_description; 358 using Transport::remote_description; 359 360 protected: 361 virtual TransportChannelImpl* CreateTransportChannel(int component) { 362 if (channels_.find(component) != channels_.end()) { 363 return NULL; 364 } 365 FakeTransportChannel* channel = 366 new FakeTransportChannel(this, content_name(), component); 367 channel->SetAsync(async_); 368 SetChannelDestination(component, channel); 369 channels_[component] = channel; 370 return channel; 371 } 372 virtual void 
DestroyTransportChannel(TransportChannelImpl* channel) { 373 channels_.erase(channel->component()); 374 delete channel; 375 } 376 virtual void SetIdentity_w(talk_base::SSLIdentity* identity) { 377 identity_ = identity; 378 } 379 virtual bool GetIdentity_w(talk_base::SSLIdentity** identity) { 380 if (!identity_) 381 return false; 382 383 *identity = identity_->GetReference(); 384 return true; 385 } 386 387 private: 388 FakeTransportChannel* GetFakeChannel(int component) { 389 ChannelMap::iterator it = channels_.find(component); 390 return (it != channels_.end()) ? it->second : NULL; 391 } 392 void SetChannelDestination(int component, 393 FakeTransportChannel* channel) { 394 FakeTransportChannel* dest_channel = NULL; 395 if (dest_) { 396 dest_channel = dest_->GetFakeChannel(component); 397 if (dest_channel) { 398 dest_channel->SetLocalIdentity(dest_->identity_); 399 } 400 } 401 channel->SetDestination(dest_channel); 402 } 403 404 // Note, this is distinct from the Channel map owned by Transport. 405 // This map just tracks the FakeTransportChannels created by this class. 406 ChannelMap channels_; 407 FakeTransport* dest_; 408 bool async_; 409 talk_base::SSLIdentity* identity_; 410 }; 411 412 // Fake session class, which can be passed into a BaseChannel object for 413 // test purposes. Can be connected to other FakeSessions via Connect(). 
414 class FakeSession : public BaseSession { 415 public: 416 explicit FakeSession() 417 : BaseSession(talk_base::Thread::Current(), 418 talk_base::Thread::Current(), 419 NULL, "", "", true), 420 fail_create_channel_(false) { 421 } 422 explicit FakeSession(bool initiator) 423 : BaseSession(talk_base::Thread::Current(), 424 talk_base::Thread::Current(), 425 NULL, "", "", initiator), 426 fail_create_channel_(false) { 427 } 428 FakeSession(talk_base::Thread* worker_thread, bool initiator) 429 : BaseSession(talk_base::Thread::Current(), 430 worker_thread, 431 NULL, "", "", initiator), 432 fail_create_channel_(false) { 433 } 434 435 FakeTransport* GetTransport(const std::string& content_name) { 436 return static_cast<FakeTransport*>( 437 BaseSession::GetTransport(content_name)); 438 } 439 440 void Connect(FakeSession* dest) { 441 // Simulate the exchange of candidates. 442 CompleteNegotiation(); 443 dest->CompleteNegotiation(); 444 for (TransportMap::const_iterator it = transport_proxies().begin(); 445 it != transport_proxies().end(); ++it) { 446 static_cast<FakeTransport*>(it->second->impl())->SetDestination( 447 dest->GetTransport(it->first)); 448 } 449 } 450 451 virtual TransportChannel* CreateChannel( 452 const std::string& content_name, 453 const std::string& channel_name, 454 int component) { 455 if (fail_create_channel_) { 456 return NULL; 457 } 458 return BaseSession::CreateChannel(content_name, channel_name, component); 459 } 460 461 void set_fail_channel_creation(bool fail_channel_creation) { 462 fail_create_channel_ = fail_channel_creation; 463 } 464 465 // TODO: Hoist this into Session when we re-work the Session code. 
466 void set_ssl_identity(talk_base::SSLIdentity* identity) { 467 for (TransportMap::const_iterator it = transport_proxies().begin(); 468 it != transport_proxies().end(); ++it) { 469 // We know that we have a FakeTransport* 470 471 static_cast<FakeTransport*>(it->second->impl())->set_identity 472 (identity); 473 } 474 } 475 476 protected: 477 virtual Transport* CreateTransport(const std::string& content_name) { 478 return new FakeTransport(signaling_thread(), worker_thread(), content_name); 479 } 480 481 void CompleteNegotiation() { 482 for (TransportMap::const_iterator it = transport_proxies().begin(); 483 it != transport_proxies().end(); ++it) { 484 it->second->CompleteNegotiation(); 485 it->second->ConnectChannels(); 486 } 487 } 488 489 private: 490 bool fail_create_channel_; 491 }; 492 493 } // namespace cricket 494 495 #endif // TALK_P2P_BASE_FAKESESSION_H_ 496