Home | History | Annotate | Download | only in base
      1 /*
      2  *  Copyright 2015 The WebRTC Project Authors. All rights reserved.
      3  *
      4  *  Use of this source code is governed by a BSD-style license
      5  *  that can be found in the LICENSE file in the root of the source
      6  *  tree. An additional intellectual property rights grant can be found
      7  *  in the file PATENTS.  All contributing project authors may
      8  *  be found in the AUTHORS file in the root of the source tree.
      9  */
     10 
     11 #include "webrtc/p2p/base/transportcontroller.h"
     12 
     13 #include <algorithm>
     14 
     15 #include "webrtc/base/bind.h"
     16 #include "webrtc/base/checks.h"
     17 #include "webrtc/base/thread.h"
     18 #include "webrtc/p2p/base/dtlstransport.h"
     19 #include "webrtc/p2p/base/p2ptransport.h"
     20 #include "webrtc/p2p/base/port.h"
     21 
     22 namespace cricket {
     23 
     24 enum {
     25   MSG_ICECONNECTIONSTATE,
     26   MSG_RECEIVING,
     27   MSG_ICEGATHERINGSTATE,
     28   MSG_CANDIDATESGATHERED,
     29 };
     30 
     31 struct CandidatesData : public rtc::MessageData {
     32   CandidatesData(const std::string& transport_name,
     33                  const Candidates& candidates)
     34       : transport_name(transport_name), candidates(candidates) {}
     35 
     36   std::string transport_name;
     37   Candidates candidates;
     38 };
     39 
// Stores the two threads and the port allocator this controller works with.
// Nothing else happens here: transports are created later, on demand, by
// GetOrCreateTransport_w.
TransportController::TransportController(rtc::Thread* signaling_thread,
                                         rtc::Thread* worker_thread,
                                         PortAllocator* port_allocator)
    : signaling_thread_(signaling_thread),
      worker_thread_(worker_thread),
      port_allocator_(port_allocator) {}
     46 
     47 TransportController::~TransportController() {
     48   worker_thread_->Invoke<void>(
     49       rtc::Bind(&TransportController::DestroyAllTransports_w, this));
     50   signaling_thread_->Clear(this);
     51 }
     52 
     53 bool TransportController::SetSslMaxProtocolVersion(
     54     rtc::SSLProtocolVersion version) {
     55   return worker_thread_->Invoke<bool>(rtc::Bind(
     56       &TransportController::SetSslMaxProtocolVersion_w, this, version));
     57 }
     58 
     59 void TransportController::SetIceConfig(const IceConfig& config) {
     60   worker_thread_->Invoke<void>(
     61       rtc::Bind(&TransportController::SetIceConfig_w, this, config));
     62 }
     63 
     64 void TransportController::SetIceRole(IceRole ice_role) {
     65   worker_thread_->Invoke<void>(
     66       rtc::Bind(&TransportController::SetIceRole_w, this, ice_role));
     67 }
     68 
     69 bool TransportController::GetSslRole(const std::string& transport_name,
     70                                      rtc::SSLRole* role) {
     71   return worker_thread_->Invoke<bool>(rtc::Bind(
     72       &TransportController::GetSslRole_w, this, transport_name, role));
     73 }
     74 
     75 bool TransportController::SetLocalCertificate(
     76     const rtc::scoped_refptr<rtc::RTCCertificate>& certificate) {
     77   return worker_thread_->Invoke<bool>(rtc::Bind(
     78       &TransportController::SetLocalCertificate_w, this, certificate));
     79 }
     80 
     81 bool TransportController::GetLocalCertificate(
     82     const std::string& transport_name,
     83     rtc::scoped_refptr<rtc::RTCCertificate>* certificate) {
     84   return worker_thread_->Invoke<bool>(
     85       rtc::Bind(&TransportController::GetLocalCertificate_w, this,
     86                 transport_name, certificate));
     87 }
     88 
     89 bool TransportController::GetRemoteSSLCertificate(
     90     const std::string& transport_name,
     91     rtc::SSLCertificate** cert) {
     92   return worker_thread_->Invoke<bool>(
     93       rtc::Bind(&TransportController::GetRemoteSSLCertificate_w, this,
     94                 transport_name, cert));
     95 }
     96 
     97 bool TransportController::SetLocalTransportDescription(
     98     const std::string& transport_name,
     99     const TransportDescription& tdesc,
    100     ContentAction action,
    101     std::string* err) {
    102   return worker_thread_->Invoke<bool>(
    103       rtc::Bind(&TransportController::SetLocalTransportDescription_w, this,
    104                 transport_name, tdesc, action, err));
    105 }
    106 
    107 bool TransportController::SetRemoteTransportDescription(
    108     const std::string& transport_name,
    109     const TransportDescription& tdesc,
    110     ContentAction action,
    111     std::string* err) {
    112   return worker_thread_->Invoke<bool>(
    113       rtc::Bind(&TransportController::SetRemoteTransportDescription_w, this,
    114                 transport_name, tdesc, action, err));
    115 }
    116 
    117 void TransportController::MaybeStartGathering() {
    118   worker_thread_->Invoke<void>(
    119       rtc::Bind(&TransportController::MaybeStartGathering_w, this));
    120 }
    121 
    122 bool TransportController::AddRemoteCandidates(const std::string& transport_name,
    123                                               const Candidates& candidates,
    124                                               std::string* err) {
    125   return worker_thread_->Invoke<bool>(
    126       rtc::Bind(&TransportController::AddRemoteCandidates_w, this,
    127                 transport_name, candidates, err));
    128 }
    129 
    130 bool TransportController::ReadyForRemoteCandidates(
    131     const std::string& transport_name) {
    132   return worker_thread_->Invoke<bool>(rtc::Bind(
    133       &TransportController::ReadyForRemoteCandidates_w, this, transport_name));
    134 }
    135 
    136 bool TransportController::GetStats(const std::string& transport_name,
    137                                    TransportStats* stats) {
    138   return worker_thread_->Invoke<bool>(
    139       rtc::Bind(&TransportController::GetStats_w, this, transport_name, stats));
    140 }
    141 
    142 TransportChannel* TransportController::CreateTransportChannel_w(
    143     const std::string& transport_name,
    144     int component) {
    145   RTC_DCHECK(worker_thread_->IsCurrent());
    146 
    147   auto it = FindChannel_w(transport_name, component);
    148   if (it != channels_.end()) {
    149     // Channel already exists; increment reference count and return.
    150     it->AddRef();
    151     return it->get();
    152   }
    153 
    154   // Need to create a new channel.
    155   Transport* transport = GetOrCreateTransport_w(transport_name);
    156   TransportChannelImpl* channel = transport->CreateChannel(component);
    157   channel->SignalWritableState.connect(
    158       this, &TransportController::OnChannelWritableState_w);
    159   channel->SignalReceivingState.connect(
    160       this, &TransportController::OnChannelReceivingState_w);
    161   channel->SignalGatheringState.connect(
    162       this, &TransportController::OnChannelGatheringState_w);
    163   channel->SignalCandidateGathered.connect(
    164       this, &TransportController::OnChannelCandidateGathered_w);
    165   channel->SignalRoleConflict.connect(
    166       this, &TransportController::OnChannelRoleConflict_w);
    167   channel->SignalConnectionRemoved.connect(
    168       this, &TransportController::OnChannelConnectionRemoved_w);
    169   channels_.insert(channels_.end(), RefCountedChannel(channel))->AddRef();
    170   // Adding a channel could cause aggregate state to change.
    171   UpdateAggregateStates_w();
    172   return channel;
    173 }
    174 
    175 void TransportController::DestroyTransportChannel_w(
    176     const std::string& transport_name,
    177     int component) {
    178   RTC_DCHECK(worker_thread_->IsCurrent());
    179 
    180   auto it = FindChannel_w(transport_name, component);
    181   if (it == channels_.end()) {
    182     LOG(LS_WARNING) << "Attempting to delete " << transport_name
    183                     << " TransportChannel " << component
    184                     << ", which doesn't exist.";
    185     return;
    186   }
    187 
    188   it->DecRef();
    189   if (it->ref() > 0) {
    190     return;
    191   }
    192 
    193   channels_.erase(it);
    194   Transport* transport = GetTransport_w(transport_name);
    195   transport->DestroyChannel(component);
    196   // Just as we create a Transport when its first channel is created,
    197   // we delete it when its last channel is deleted.
    198   if (!transport->HasChannels()) {
    199     DestroyTransport_w(transport_name);
    200   }
    201   // Removing a channel could cause aggregate state to change.
    202   UpdateAggregateStates_w();
    203 }
    204 
// Test-only accessor: returns a reference to the stored local certificate
// (certificate_), which is null until SetLocalCertificate_w succeeds.
const rtc::scoped_refptr<rtc::RTCCertificate>&
TransportController::certificate_for_testing() {
  return certificate_;
}
    209 
    210 Transport* TransportController::CreateTransport_w(
    211     const std::string& transport_name) {
    212   RTC_DCHECK(worker_thread_->IsCurrent());
    213 
    214   Transport* transport = new DtlsTransport<P2PTransport>(
    215       transport_name, port_allocator(), certificate_);
    216   return transport;
    217 }
    218 
    219 Transport* TransportController::GetTransport_w(
    220     const std::string& transport_name) {
    221   RTC_DCHECK(worker_thread_->IsCurrent());
    222 
    223   auto iter = transports_.find(transport_name);
    224   return (iter != transports_.end()) ? iter->second : nullptr;
    225 }
    226 
    227 void TransportController::OnMessage(rtc::Message* pmsg) {
    228   RTC_DCHECK(signaling_thread_->IsCurrent());
    229 
    230   switch (pmsg->message_id) {
    231     case MSG_ICECONNECTIONSTATE: {
    232       rtc::TypedMessageData<IceConnectionState>* data =
    233           static_cast<rtc::TypedMessageData<IceConnectionState>*>(pmsg->pdata);
    234       SignalConnectionState(data->data());
    235       delete data;
    236       break;
    237     }
    238     case MSG_RECEIVING: {
    239       rtc::TypedMessageData<bool>* data =
    240           static_cast<rtc::TypedMessageData<bool>*>(pmsg->pdata);
    241       SignalReceiving(data->data());
    242       delete data;
    243       break;
    244     }
    245     case MSG_ICEGATHERINGSTATE: {
    246       rtc::TypedMessageData<IceGatheringState>* data =
    247           static_cast<rtc::TypedMessageData<IceGatheringState>*>(pmsg->pdata);
    248       SignalGatheringState(data->data());
    249       delete data;
    250       break;
    251     }
    252     case MSG_CANDIDATESGATHERED: {
    253       CandidatesData* data = static_cast<CandidatesData*>(pmsg->pdata);
    254       SignalCandidatesGathered(data->transport_name, data->candidates);
    255       delete data;
    256       break;
    257     }
    258     default:
    259       ASSERT(false);
    260   }
    261 }
    262 
    263 std::vector<TransportController::RefCountedChannel>::iterator
    264 TransportController::FindChannel_w(const std::string& transport_name,
    265                                    int component) {
    266   return std::find_if(
    267       channels_.begin(), channels_.end(),
    268       [transport_name, component](const RefCountedChannel& channel) {
    269         return channel->transport_name() == transport_name &&
    270                channel->component() == component;
    271       });
    272 }
    273 
    274 Transport* TransportController::GetOrCreateTransport_w(
    275     const std::string& transport_name) {
    276   RTC_DCHECK(worker_thread_->IsCurrent());
    277 
    278   Transport* transport = GetTransport_w(transport_name);
    279   if (transport) {
    280     return transport;
    281   }
    282 
    283   transport = CreateTransport_w(transport_name);
    284   // The stuff below happens outside of CreateTransport_w so that unit tests
    285   // can override CreateTransport_w to return a different type of transport.
    286   transport->SetSslMaxProtocolVersion(ssl_max_version_);
    287   transport->SetIceConfig(ice_config_);
    288   transport->SetIceRole(ice_role_);
    289   transport->SetIceTiebreaker(ice_tiebreaker_);
    290   if (certificate_) {
    291     transport->SetLocalCertificate(certificate_);
    292   }
    293   transports_[transport_name] = transport;
    294 
    295   return transport;
    296 }
    297 
    298 void TransportController::DestroyTransport_w(
    299     const std::string& transport_name) {
    300   RTC_DCHECK(worker_thread_->IsCurrent());
    301 
    302   auto iter = transports_.find(transport_name);
    303   if (iter != transports_.end()) {
    304     delete iter->second;
    305     transports_.erase(transport_name);
    306   }
    307 }
    308 
    309 void TransportController::DestroyAllTransports_w() {
    310   RTC_DCHECK(worker_thread_->IsCurrent());
    311 
    312   for (const auto& kv : transports_) {
    313     delete kv.second;
    314   }
    315   transports_.clear();
    316 }
    317 
    318 bool TransportController::SetSslMaxProtocolVersion_w(
    319     rtc::SSLProtocolVersion version) {
    320   RTC_DCHECK(worker_thread_->IsCurrent());
    321 
    322   // Max SSL version can only be set before transports are created.
    323   if (!transports_.empty()) {
    324     return false;
    325   }
    326 
    327   ssl_max_version_ = version;
    328   return true;
    329 }
    330 
    331 void TransportController::SetIceConfig_w(const IceConfig& config) {
    332   RTC_DCHECK(worker_thread_->IsCurrent());
    333   ice_config_ = config;
    334   for (const auto& kv : transports_) {
    335     kv.second->SetIceConfig(ice_config_);
    336   }
    337 }
    338 
    339 void TransportController::SetIceRole_w(IceRole ice_role) {
    340   RTC_DCHECK(worker_thread_->IsCurrent());
    341   ice_role_ = ice_role;
    342   for (const auto& kv : transports_) {
    343     kv.second->SetIceRole(ice_role_);
    344   }
    345 }
    346 
    347 bool TransportController::GetSslRole_w(const std::string& transport_name,
    348                                        rtc::SSLRole* role) {
    349   RTC_DCHECK(worker_thread()->IsCurrent());
    350 
    351   Transport* t = GetTransport_w(transport_name);
    352   if (!t) {
    353     return false;
    354   }
    355 
    356   return t->GetSslRole(role);
    357 }
    358 
    359 bool TransportController::SetLocalCertificate_w(
    360     const rtc::scoped_refptr<rtc::RTCCertificate>& certificate) {
    361   RTC_DCHECK(worker_thread_->IsCurrent());
    362 
    363   if (certificate_) {
    364     return false;
    365   }
    366   if (!certificate) {
    367     return false;
    368   }
    369   certificate_ = certificate;
    370 
    371   for (const auto& kv : transports_) {
    372     kv.second->SetLocalCertificate(certificate_);
    373   }
    374   return true;
    375 }
    376 
    377 bool TransportController::GetLocalCertificate_w(
    378     const std::string& transport_name,
    379     rtc::scoped_refptr<rtc::RTCCertificate>* certificate) {
    380   RTC_DCHECK(worker_thread_->IsCurrent());
    381 
    382   Transport* t = GetTransport_w(transport_name);
    383   if (!t) {
    384     return false;
    385   }
    386 
    387   return t->GetLocalCertificate(certificate);
    388 }
    389 
    390 bool TransportController::GetRemoteSSLCertificate_w(
    391     const std::string& transport_name,
    392     rtc::SSLCertificate** cert) {
    393   RTC_DCHECK(worker_thread_->IsCurrent());
    394 
    395   Transport* t = GetTransport_w(transport_name);
    396   if (!t) {
    397     return false;
    398   }
    399 
    400   return t->GetRemoteSSLCertificate(cert);
    401 }
    402 
    403 bool TransportController::SetLocalTransportDescription_w(
    404     const std::string& transport_name,
    405     const TransportDescription& tdesc,
    406     ContentAction action,
    407     std::string* err) {
    408   RTC_DCHECK(worker_thread()->IsCurrent());
    409 
    410   Transport* transport = GetTransport_w(transport_name);
    411   if (!transport) {
    412     // If we didn't find a transport, that's not an error;
    413     // it could have been deleted as a result of bundling.
    414     // TODO(deadbeef): Make callers smarter so they won't attempt to set a
    415     // description on a deleted transport.
    416     return true;
    417   }
    418 
    419   return transport->SetLocalTransportDescription(tdesc, action, err);
    420 }
    421 
    422 bool TransportController::SetRemoteTransportDescription_w(
    423     const std::string& transport_name,
    424     const TransportDescription& tdesc,
    425     ContentAction action,
    426     std::string* err) {
    427   RTC_DCHECK(worker_thread()->IsCurrent());
    428 
    429   Transport* transport = GetTransport_w(transport_name);
    430   if (!transport) {
    431     // If we didn't find a transport, that's not an error;
    432     // it could have been deleted as a result of bundling.
    433     // TODO(deadbeef): Make callers smarter so they won't attempt to set a
    434     // description on a deleted transport.
    435     return true;
    436   }
    437 
    438   return transport->SetRemoteTransportDescription(tdesc, action, err);
    439 }
    440 
    441 void TransportController::MaybeStartGathering_w() {
    442   for (const auto& kv : transports_) {
    443     kv.second->MaybeStartGathering();
    444   }
    445 }
    446 
    447 bool TransportController::AddRemoteCandidates_w(
    448     const std::string& transport_name,
    449     const Candidates& candidates,
    450     std::string* err) {
    451   RTC_DCHECK(worker_thread()->IsCurrent());
    452 
    453   Transport* transport = GetTransport_w(transport_name);
    454   if (!transport) {
    455     // If we didn't find a transport, that's not an error;
    456     // it could have been deleted as a result of bundling.
    457     return true;
    458   }
    459 
    460   return transport->AddRemoteCandidates(candidates, err);
    461 }
    462 
    463 bool TransportController::ReadyForRemoteCandidates_w(
    464     const std::string& transport_name) {
    465   RTC_DCHECK(worker_thread()->IsCurrent());
    466 
    467   Transport* transport = GetTransport_w(transport_name);
    468   if (!transport) {
    469     return false;
    470   }
    471   return transport->ready_for_remote_candidates();
    472 }
    473 
    474 bool TransportController::GetStats_w(const std::string& transport_name,
    475                                      TransportStats* stats) {
    476   RTC_DCHECK(worker_thread()->IsCurrent());
    477 
    478   Transport* transport = GetTransport_w(transport_name);
    479   if (!transport) {
    480     return false;
    481   }
    482   return transport->GetStats(stats);
    483 }
    484 
// Handler connected to a channel's SignalWritableState in
// CreateTransportChannel_w. Logs the channel's new writability and recomputes
// the aggregate states. Must run on the worker thread.
void TransportController::OnChannelWritableState_w(TransportChannel* channel) {
  RTC_DCHECK(worker_thread_->IsCurrent());
  LOG(LS_INFO) << channel->transport_name() << " TransportChannel "
               << channel->component() << " writability changed to "
               << channel->writable() << ".";
  UpdateAggregateStates_w();
}
    492 
// Handler connected to a channel's SignalReceivingState in
// CreateTransportChannel_w. |channel| is unused; the aggregate states are
// recomputed from all channels. Must run on the worker thread.
void TransportController::OnChannelReceivingState_w(TransportChannel* channel) {
  RTC_DCHECK(worker_thread_->IsCurrent());
  UpdateAggregateStates_w();
}
    497 
// Handler connected to a channel's SignalGatheringState in
// CreateTransportChannel_w. |channel| is unused; the aggregate states are
// recomputed from all channels. Must run on the worker thread.
void TransportController::OnChannelGatheringState_w(
    TransportChannelImpl* channel) {
  RTC_DCHECK(worker_thread_->IsCurrent());
  UpdateAggregateStates_w();
}
    503 
    504 void TransportController::OnChannelCandidateGathered_w(
    505     TransportChannelImpl* channel,
    506     const Candidate& candidate) {
    507   RTC_DCHECK(worker_thread_->IsCurrent());
    508 
    509   // We should never signal peer-reflexive candidates.
    510   if (candidate.type() == PRFLX_PORT_TYPE) {
    511     RTC_DCHECK(false);
    512     return;
    513   }
    514   std::vector<Candidate> candidates;
    515   candidates.push_back(candidate);
    516   CandidatesData* data =
    517       new CandidatesData(channel->transport_name(), candidates);
    518   signaling_thread_->Post(this, MSG_CANDIDATESGATHERED, data);
    519 }
    520 
    521 void TransportController::OnChannelRoleConflict_w(
    522     TransportChannelImpl* channel) {
    523   RTC_DCHECK(worker_thread_->IsCurrent());
    524 
    525   if (ice_role_switch_) {
    526     LOG(LS_WARNING)
    527         << "Repeat of role conflict signal from TransportChannelImpl.";
    528     return;
    529   }
    530 
    531   ice_role_switch_ = true;
    532   IceRole reversed_role = (ice_role_ == ICEROLE_CONTROLLING)
    533                               ? ICEROLE_CONTROLLED
    534                               : ICEROLE_CONTROLLING;
    535   for (const auto& kv : transports_) {
    536     kv.second->SetIceRole(reversed_role);
    537   }
    538 }
    539 
// Handler connected to a channel's SignalConnectionRemoved in
// CreateTransportChannel_w. Logs the removal and recomputes the aggregate
// states. Must run on the worker thread.
void TransportController::OnChannelConnectionRemoved_w(
    TransportChannelImpl* channel) {
  RTC_DCHECK(worker_thread_->IsCurrent());
  LOG(LS_INFO) << channel->transport_name() << " TransportChannel "
               << channel->component()
               << " connection removed. Check if state is complete.";
  UpdateAggregateStates_w();
}
    548 
    549 void TransportController::UpdateAggregateStates_w() {
    550   RTC_DCHECK(worker_thread_->IsCurrent());
    551 
    552   IceConnectionState new_connection_state = kIceConnectionConnecting;
    553   IceGatheringState new_gathering_state = kIceGatheringNew;
    554   bool any_receiving = false;
    555   bool any_failed = false;
    556   bool all_connected = !channels_.empty();
    557   bool all_completed = !channels_.empty();
    558   bool any_gathering = false;
    559   bool all_done_gathering = !channels_.empty();
    560   for (const auto& channel : channels_) {
    561     any_receiving = any_receiving || channel->receiving();
    562     any_failed = any_failed ||
    563                  channel->GetState() == TransportChannelState::STATE_FAILED;
    564     all_connected = all_connected && channel->writable();
    565     all_completed =
    566         all_completed && channel->writable() &&
    567         channel->GetState() == TransportChannelState::STATE_COMPLETED &&
    568         channel->GetIceRole() == ICEROLE_CONTROLLING &&
    569         channel->gathering_state() == kIceGatheringComplete;
    570     any_gathering =
    571         any_gathering || channel->gathering_state() != kIceGatheringNew;
    572     all_done_gathering = all_done_gathering &&
    573                          channel->gathering_state() == kIceGatheringComplete;
    574   }
    575 
    576   if (any_failed) {
    577     new_connection_state = kIceConnectionFailed;
    578   } else if (all_completed) {
    579     new_connection_state = kIceConnectionCompleted;
    580   } else if (all_connected) {
    581     new_connection_state = kIceConnectionConnected;
    582   }
    583   if (connection_state_ != new_connection_state) {
    584     connection_state_ = new_connection_state;
    585     signaling_thread_->Post(
    586         this, MSG_ICECONNECTIONSTATE,
    587         new rtc::TypedMessageData<IceConnectionState>(new_connection_state));
    588   }
    589 
    590   if (receiving_ != any_receiving) {
    591     receiving_ = any_receiving;
    592     signaling_thread_->Post(this, MSG_RECEIVING,
    593                             new rtc::TypedMessageData<bool>(any_receiving));
    594   }
    595 
    596   if (all_done_gathering) {
    597     new_gathering_state = kIceGatheringComplete;
    598   } else if (any_gathering) {
    599     new_gathering_state = kIceGatheringGathering;
    600   }
    601   if (gathering_state_ != new_gathering_state) {
    602     gathering_state_ = new_gathering_state;
    603     signaling_thread_->Post(
    604         this, MSG_ICEGATHERINGSTATE,
    605         new rtc::TypedMessageData<IceGatheringState>(new_gathering_state));
    606   }
    607 }
    608 
    609 }  // namespace cricket
    610