1 /* 2 * Copyright 2009 The WebRTC Project Authors. All rights reserved. 3 * 4 * Use of this source code is governed by a BSD-style license 5 * that can be found in the LICENSE file in the root of the source 6 * tree. An additional intellectual property rights grant can be found 7 * in the file PATENTS. All contributing project authors may 8 * be found in the AUTHORS file in the root of the source tree. 9 */ 10 11 #ifndef WEBRTC_P2P_BASE_FAKETRANSPORTCONTROLLER_H_ 12 #define WEBRTC_P2P_BASE_FAKETRANSPORTCONTROLLER_H_ 13 14 #include <map> 15 #include <string> 16 #include <vector> 17 18 #include "webrtc/p2p/base/transport.h" 19 #include "webrtc/p2p/base/transportchannel.h" 20 #include "webrtc/p2p/base/transportcontroller.h" 21 #include "webrtc/p2p/base/transportchannelimpl.h" 22 #include "webrtc/base/bind.h" 23 #include "webrtc/base/buffer.h" 24 #include "webrtc/base/fakesslidentity.h" 25 #include "webrtc/base/messagequeue.h" 26 #include "webrtc/base/sigslot.h" 27 #include "webrtc/base/sslfingerprint.h" 28 #include "webrtc/base/thread.h" 29 30 namespace cricket { 31 32 class FakeTransport; 33 34 namespace { 35 struct PacketMessageData : public rtc::MessageData { 36 PacketMessageData(const char* data, size_t len) : packet(data, len) {} 37 rtc::Buffer packet; 38 }; 39 } // namespace 40 41 // Fake transport channel class, which can be passed to anything that needs a 42 // transport channel. Can be informed of another FakeTransportChannel via 43 // SetDestination. 44 // TODO(hbos): Move implementation to .cc file, this and other classes in file. 45 class FakeTransportChannel : public TransportChannelImpl, 46 public rtc::MessageHandler { 47 public: 48 explicit FakeTransportChannel(Transport* transport, 49 const std::string& name, 50 int component) 51 : TransportChannelImpl(name, component), 52 transport_(transport), 53 dtls_fingerprint_("", nullptr, 0) {} 54 ~FakeTransportChannel() { Reset(); } 55 56 uint64_t IceTiebreaker() const { return tiebreaker_; } 57 IceMode remote_ice_mode() const { return remote_ice_mode_; } 58 const std::string& ice_ufrag() const { return ice_ufrag_; } 59 const std::string& ice_pwd() const { return ice_pwd_; } 60 const std::string& remote_ice_ufrag() const { return remote_ice_ufrag_; } 61 const std::string& remote_ice_pwd() const { return remote_ice_pwd_; } 62 const rtc::SSLFingerprint& dtls_fingerprint() const { 63 return dtls_fingerprint_; 64 } 65 66 // If async, will send packets by "Post"-ing to message queue instead of 67 // synchronously "Send"-ing. 68 void SetAsync(bool async) { async_ = async; } 69 70 Transport* GetTransport() override { return transport_; } 71 72 TransportChannelState GetState() const override { 73 if (connection_count_ == 0) { 74 return had_connection_ ? TransportChannelState::STATE_FAILED 75 : TransportChannelState::STATE_INIT; 76 } 77 78 if (connection_count_ == 1) { 79 return TransportChannelState::STATE_COMPLETED; 80 } 81 82 return TransportChannelState::STATE_CONNECTING; 83 } 84 85 void SetIceRole(IceRole role) override { role_ = role; } 86 IceRole GetIceRole() const override { return role_; } 87 void SetIceTiebreaker(uint64_t tiebreaker) override { 88 tiebreaker_ = tiebreaker; 89 } 90 void SetIceCredentials(const std::string& ice_ufrag, 91 const std::string& ice_pwd) override { 92 ice_ufrag_ = ice_ufrag; 93 ice_pwd_ = ice_pwd; 94 } 95 void SetRemoteIceCredentials(const std::string& ice_ufrag, 96 const std::string& ice_pwd) override { 97 remote_ice_ufrag_ = ice_ufrag; 98 remote_ice_pwd_ = ice_pwd; 99 } 100 101 void SetRemoteIceMode(IceMode mode) override { remote_ice_mode_ = mode; } 102 bool SetRemoteFingerprint(const std::string& alg, 103 const uint8_t* digest, 104 size_t digest_len) override { 105 dtls_fingerprint_ = rtc::SSLFingerprint(alg, digest, digest_len); 106 return true; 107 } 108 bool SetSslRole(rtc::SSLRole role) override { 109 ssl_role_ = role; 110 return true; 111 } 112 bool GetSslRole(rtc::SSLRole* role) const override { 113 *role = ssl_role_; 114 return true; 115 } 116 117 void Connect() override { 118 if (state_ == STATE_INIT) { 119 state_ = STATE_CONNECTING; 120 } 121 } 122 123 void MaybeStartGathering() override { 124 if (gathering_state_ == kIceGatheringNew) { 125 gathering_state_ = kIceGatheringGathering; 126 SignalGatheringState(this); 127 } 128 } 129 130 IceGatheringState gathering_state() const override { 131 return gathering_state_; 132 } 133 134 void Reset() { 135 if (state_ != STATE_INIT) { 136 state_ = STATE_INIT; 137 if (dest_) { 138 dest_->state_ = STATE_INIT; 139 dest_->dest_ = nullptr; 140 dest_ = nullptr; 141 } 142 } 143 } 144 145 void SetWritable(bool writable) { set_writable(writable); } 146 147 void SetDestination(FakeTransportChannel* dest) { 148 if (state_ == STATE_CONNECTING && dest) { 149 // This simulates the delivery of candidates. 150 dest_ = dest; 151 dest_->dest_ = this; 152 if (local_cert_ && dest_->local_cert_) { 153 do_dtls_ = true; 154 dest_->do_dtls_ = true; 155 NegotiateSrtpCiphers(); 156 } 157 state_ = STATE_CONNECTED; 158 dest_->state_ = STATE_CONNECTED; 159 set_writable(true); 160 dest_->set_writable(true); 161 } else if (state_ == STATE_CONNECTED && !dest) { 162 // Simulates loss of connectivity, by asymmetrically forgetting dest_. 163 dest_ = nullptr; 164 state_ = STATE_CONNECTING; 165 set_writable(false); 166 } 167 } 168 169 void SetConnectionCount(size_t connection_count) { 170 size_t old_connection_count = connection_count_; 171 connection_count_ = connection_count; 172 if (connection_count) 173 had_connection_ = true; 174 if (connection_count_ < old_connection_count) 175 SignalConnectionRemoved(this); 176 } 177 178 void SetCandidatesGatheringComplete() { 179 if (gathering_state_ != kIceGatheringComplete) { 180 gathering_state_ = kIceGatheringComplete; 181 SignalGatheringState(this); 182 } 183 } 184 185 void SetReceiving(bool receiving) { set_receiving(receiving); } 186 187 void SetIceConfig(const IceConfig& config) override { 188 receiving_timeout_ = config.receiving_timeout_ms; 189 gather_continually_ = config.gather_continually; 190 } 191 192 int receiving_timeout() const { return receiving_timeout_; } 193 bool gather_continually() const { return gather_continually_; } 194 195 int SendPacket(const char* data, 196 size_t len, 197 const rtc::PacketOptions& options, 198 int flags) override { 199 if (state_ != STATE_CONNECTED) { 200 return -1; 201 } 202 203 if (flags != PF_SRTP_BYPASS && flags != 0) { 204 return -1; 205 } 206 207 PacketMessageData* packet = new PacketMessageData(data, len); 208 if (async_) { 209 rtc::Thread::Current()->Post(this, 0, packet); 210 } else { 211 rtc::Thread::Current()->Send(this, 0, packet); 212 } 213 rtc::SentPacket sent_packet(options.packet_id, rtc::Time()); 214 SignalSentPacket(this, sent_packet); 215 return static_cast<int>(len); 216 } 217 int SetOption(rtc::Socket::Option opt, int value) override { return true; } 218 bool GetOption(rtc::Socket::Option opt, int* value) override { return true; } 219 int GetError() override { return 0; } 220 221 void AddRemoteCandidate(const Candidate& candidate) override { 222 remote_candidates_.push_back(candidate); 223 } 224 const Candidates& remote_candidates() const { return remote_candidates_; } 225 226 void OnMessage(rtc::Message* msg) override { 227 PacketMessageData* data = static_cast<PacketMessageData*>(msg->pdata); 228 dest_->SignalReadPacket(dest_, data->packet.data<char>(), 229 data->packet.size(), rtc::CreatePacketTime(0), 0); 230 delete data; 231 } 232 233 bool SetLocalCertificate( 234 const rtc::scoped_refptr<rtc::RTCCertificate>& certificate) { 235 local_cert_ = certificate; 236 return true; 237 } 238 239 void SetRemoteSSLCertificate(rtc::FakeSSLCertificate* cert) { 240 remote_cert_ = cert; 241 } 242 243 bool IsDtlsActive() const override { return do_dtls_; } 244 245 bool SetSrtpCryptoSuites(const std::vector<int>& ciphers) override { 246 srtp_ciphers_ = ciphers; 247 return true; 248 } 249 250 bool GetSrtpCryptoSuite(int* crypto_suite) override { 251 if (chosen_crypto_suite_ != rtc::SRTP_INVALID_CRYPTO_SUITE) { 252 *crypto_suite = chosen_crypto_suite_; 253 return true; 254 } 255 return false; 256 } 257 258 bool GetSslCipherSuite(int* cipher_suite) override { return false; } 259 260 rtc::scoped_refptr<rtc::RTCCertificate> GetLocalCertificate() const { 261 return local_cert_; 262 } 263 264 bool GetRemoteSSLCertificate(rtc::SSLCertificate** cert) const override { 265 if (!remote_cert_) 266 return false; 267 268 *cert = remote_cert_->GetReference(); 269 return true; 270 } 271 272 bool ExportKeyingMaterial(const std::string& label, 273 const uint8_t* context, 274 size_t context_len, 275 bool use_context, 276 uint8_t* result, 277 size_t result_len) override { 278 if (chosen_crypto_suite_ != rtc::SRTP_INVALID_CRYPTO_SUITE) { 279 memset(result, 0xff, result_len); 280 return true; 281 } 282 283 return false; 284 } 285 286 void NegotiateSrtpCiphers() { 287 for (std::vector<int>::const_iterator it1 = srtp_ciphers_.begin(); 288 it1 != srtp_ciphers_.end(); ++it1) { 289 for (std::vector<int>::const_iterator it2 = dest_->srtp_ciphers_.begin(); 290 it2 != dest_->srtp_ciphers_.end(); ++it2) { 291 if (*it1 == *it2) { 292 chosen_crypto_suite_ = *it1; 293 dest_->chosen_crypto_suite_ = *it2; 294 return; 295 } 296 } 297 } 298 } 299 300 bool GetStats(ConnectionInfos* infos) override { 301 ConnectionInfo info; 302 infos->clear(); 303 infos->push_back(info); 304 return true; 305 } 306 307 void set_ssl_max_protocol_version(rtc::SSLProtocolVersion version) { 308 ssl_max_version_ = version; 309 } 310 rtc::SSLProtocolVersion ssl_max_protocol_version() const { 311 return ssl_max_version_; 312 } 313 314 private: 315 enum State { STATE_INIT, STATE_CONNECTING, STATE_CONNECTED }; 316 Transport* transport_; 317 FakeTransportChannel* dest_ = nullptr; 318 State state_ = STATE_INIT; 319 bool async_ = false; 320 Candidates remote_candidates_; 321 rtc::scoped_refptr<rtc::RTCCertificate> local_cert_; 322 rtc::FakeSSLCertificate* remote_cert_ = nullptr; 323 bool do_dtls_ = false; 324 std::vector<int> srtp_ciphers_; 325 int chosen_crypto_suite_ = rtc::SRTP_INVALID_CRYPTO_SUITE; 326 int receiving_timeout_ = -1; 327 bool gather_continually_ = false; 328 IceRole role_ = ICEROLE_UNKNOWN; 329 uint64_t tiebreaker_ = 0; 330 std::string ice_ufrag_; 331 std::string ice_pwd_; 332 std::string remote_ice_ufrag_; 333 std::string remote_ice_pwd_; 334 IceMode remote_ice_mode_ = ICEMODE_FULL; 335 rtc::SSLProtocolVersion ssl_max_version_ = rtc::SSL_PROTOCOL_DTLS_12; 336 rtc::SSLFingerprint dtls_fingerprint_; 337 rtc::SSLRole ssl_role_ = rtc::SSL_CLIENT; 338 size_t connection_count_ = 0; 339 IceGatheringState gathering_state_ = kIceGatheringNew; 340 bool had_connection_ = false; 341 }; 342 343 // Fake transport class, which can be passed to anything that needs a Transport. 344 // Can be informed of another FakeTransport via SetDestination (low-tech way 345 // of doing candidates) 346 class FakeTransport : public Transport { 347 public: 348 typedef std::map<int, FakeTransportChannel*> ChannelMap; 349 350 explicit FakeTransport(const std::string& name) : Transport(name, nullptr) {} 351 352 // Note that we only have a constructor with the allocator parameter so it can 353 // be wrapped by a DtlsTransport. 354 FakeTransport(const std::string& name, PortAllocator* allocator) 355 : Transport(name, nullptr) {} 356 357 ~FakeTransport() { DestroyAllChannels(); } 358 359 const ChannelMap& channels() const { return channels_; } 360 361 // If async, will send packets by "Post"-ing to message queue instead of 362 // synchronously "Send"-ing. 363 void SetAsync(bool async) { async_ = async; } 364 void SetDestination(FakeTransport* dest) { 365 dest_ = dest; 366 for (const auto& kv : channels_) { 367 kv.second->SetLocalCertificate(certificate_); 368 SetChannelDestination(kv.first, kv.second); 369 } 370 } 371 372 void SetWritable(bool writable) { 373 for (const auto& kv : channels_) { 374 kv.second->SetWritable(writable); 375 } 376 } 377 378 void SetLocalCertificate( 379 const rtc::scoped_refptr<rtc::RTCCertificate>& certificate) override { 380 certificate_ = certificate; 381 } 382 bool GetLocalCertificate( 383 rtc::scoped_refptr<rtc::RTCCertificate>* certificate) override { 384 if (!certificate_) 385 return false; 386 387 *certificate = certificate_; 388 return true; 389 } 390 391 bool GetSslRole(rtc::SSLRole* role) const override { 392 if (channels_.empty()) { 393 return false; 394 } 395 return channels_.begin()->second->GetSslRole(role); 396 } 397 398 bool SetSslMaxProtocolVersion(rtc::SSLProtocolVersion version) override { 399 ssl_max_version_ = version; 400 for (const auto& kv : channels_) { 401 kv.second->set_ssl_max_protocol_version(ssl_max_version_); 402 } 403 return true; 404 } 405 rtc::SSLProtocolVersion ssl_max_protocol_version() const { 406 return ssl_max_version_; 407 } 408 409 using Transport::local_description; 410 using Transport::remote_description; 411 412 protected: 413 TransportChannelImpl* CreateTransportChannel(int component) override { 414 if (channels_.find(component) != channels_.end()) { 415 return nullptr; 416 } 417 FakeTransportChannel* channel = 418 new FakeTransportChannel(this, name(), component); 419 channel->set_ssl_max_protocol_version(ssl_max_version_); 420 channel->SetAsync(async_); 421 SetChannelDestination(component, channel); 422 channels_[component] = channel; 423 return channel; 424 } 425 426 void DestroyTransportChannel(TransportChannelImpl* channel) override { 427 channels_.erase(channel->component()); 428 delete channel; 429 } 430 431 private: 432 FakeTransportChannel* GetFakeChannel(int component) { 433 auto it = channels_.find(component); 434 return (it != channels_.end()) ? it->second : nullptr; 435 } 436 437 void SetChannelDestination(int component, FakeTransportChannel* channel) { 438 FakeTransportChannel* dest_channel = nullptr; 439 if (dest_) { 440 dest_channel = dest_->GetFakeChannel(component); 441 if (dest_channel) { 442 dest_channel->SetLocalCertificate(dest_->certificate_); 443 } 444 } 445 channel->SetDestination(dest_channel); 446 } 447 448 // Note, this is distinct from the Channel map owned by Transport. 449 // This map just tracks the FakeTransportChannels created by this class. 450 // It's mainly needed so that we can access a FakeTransportChannel directly, 451 // even if wrapped by a DtlsTransportChannelWrapper. 452 ChannelMap channels_; 453 FakeTransport* dest_ = nullptr; 454 bool async_ = false; 455 rtc::scoped_refptr<rtc::RTCCertificate> certificate_; 456 rtc::SSLProtocolVersion ssl_max_version_ = rtc::SSL_PROTOCOL_DTLS_12; 457 }; 458 459 // Fake TransportController class, which can be passed into a BaseChannel object 460 // for test purposes. Can be connected to other FakeTransportControllers via 461 // Connect(). 462 // 463 // This fake is unusual in that for the most part, it's implemented with the 464 // real TransportController code, but with fake TransportChannels underneath. 465 class FakeTransportController : public TransportController { 466 public: 467 FakeTransportController() 468 : TransportController(rtc::Thread::Current(), 469 rtc::Thread::Current(), 470 nullptr), 471 fail_create_channel_(false) {} 472 473 explicit FakeTransportController(IceRole role) 474 : TransportController(rtc::Thread::Current(), 475 rtc::Thread::Current(), 476 nullptr), 477 fail_create_channel_(false) { 478 SetIceRole(role); 479 } 480 481 explicit FakeTransportController(rtc::Thread* worker_thread) 482 : TransportController(rtc::Thread::Current(), worker_thread, nullptr), 483 fail_create_channel_(false) {} 484 485 FakeTransportController(rtc::Thread* worker_thread, IceRole role) 486 : TransportController(rtc::Thread::Current(), worker_thread, nullptr), 487 fail_create_channel_(false) { 488 SetIceRole(role); 489 } 490 491 FakeTransport* GetTransport_w(const std::string& transport_name) { 492 return static_cast<FakeTransport*>( 493 TransportController::GetTransport_w(transport_name)); 494 } 495 496 void Connect(FakeTransportController* dest) { 497 worker_thread()->Invoke<void>( 498 rtc::Bind(&FakeTransportController::Connect_w, this, dest)); 499 } 500 501 TransportChannel* CreateTransportChannel_w(const std::string& transport_name, 502 int component) override { 503 if (fail_create_channel_) { 504 return nullptr; 505 } 506 return TransportController::CreateTransportChannel_w(transport_name, 507 component); 508 } 509 510 void set_fail_channel_creation(bool fail_channel_creation) { 511 fail_create_channel_ = fail_channel_creation; 512 } 513 514 protected: 515 Transport* CreateTransport_w(const std::string& transport_name) override { 516 return new FakeTransport(transport_name); 517 } 518 519 void Connect_w(FakeTransportController* dest) { 520 // Simulate the exchange of candidates. 521 ConnectChannels_w(); 522 dest->ConnectChannels_w(); 523 for (auto& kv : transports()) { 524 FakeTransport* transport = static_cast<FakeTransport*>(kv.second); 525 transport->SetDestination(dest->GetTransport_w(kv.first)); 526 } 527 } 528 529 void ConnectChannels_w() { 530 for (auto& kv : transports()) { 531 FakeTransport* transport = static_cast<FakeTransport*>(kv.second); 532 transport->ConnectChannels(); 533 transport->MaybeStartGathering(); 534 } 535 } 536 537 private: 538 bool fail_create_channel_; 539 }; 540 541 } // namespace cricket 542 543 #endif // WEBRTC_P2P_BASE_FAKETRANSPORTCONTROLLER_H_ 544