Home | History | Annotate | Download | only in protocol
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "remoting/protocol/channel_multiplexer.h"
      6 
      7 #include <string.h>
      8 
      9 #include "base/bind.h"
     10 #include "base/callback.h"
     11 #include "base/location.h"
     12 #include "base/single_thread_task_runner.h"
     13 #include "base/stl_util.h"
     14 #include "base/thread_task_runner_handle.h"
     15 #include "net/base/net_errors.h"
     16 #include "net/socket/stream_socket.h"
     17 #include "remoting/protocol/util.h"
     18 
     19 namespace remoting {
     20 namespace protocol {
     21 
     22 namespace {
     23 const int kChannelIdUnknown = -1;
     24 const int kMaxPacketSize = 1024;
     25 
     26 class PendingPacket {
     27  public:
     28   PendingPacket(scoped_ptr<MultiplexPacket> packet,
     29                 const base::Closure& done_task)
     30       : packet(packet.Pass()),
     31         done_task(done_task),
     32         pos(0U) {
     33   }
     34   ~PendingPacket() {
     35     done_task.Run();
     36   }
     37 
     38   bool is_empty() { return pos >= packet->data().size(); }
     39 
     40   int Read(char* buffer, size_t size) {
     41     size = std::min(size, packet->data().size() - pos);
     42     memcpy(buffer, packet->data().data() + pos, size);
     43     pos += size;
     44     return size;
     45   }
     46 
     47  private:
     48   scoped_ptr<MultiplexPacket> packet;
     49   base::Closure done_task;
     50   size_t pos;
     51 
     52   DISALLOW_COPY_AND_ASSIGN(PendingPacket);
     53 };
     54 
     55 }  // namespace
     56 
     57 const char ChannelMultiplexer::kMuxChannelName[] = "mux";
     58 
     59 struct ChannelMultiplexer::PendingChannel {
     60   PendingChannel(const std::string& name,
     61                  const StreamChannelCallback& callback)
     62       : name(name), callback(callback) {
     63   }
     64   std::string name;
     65   StreamChannelCallback callback;
     66 };
     67 
     68 class ChannelMultiplexer::MuxChannel {
     69  public:
     70   MuxChannel(ChannelMultiplexer* multiplexer, const std::string& name,
     71              int send_id);
     72   ~MuxChannel();
     73 
     74   const std::string& name() { return name_; }
     75   int receive_id() { return receive_id_; }
     76   void set_receive_id(int id) { receive_id_ = id; }
     77 
     78   // Called by ChannelMultiplexer.
     79   scoped_ptr<net::StreamSocket> CreateSocket();
     80   void OnIncomingPacket(scoped_ptr<MultiplexPacket> packet,
     81                         const base::Closure& done_task);
     82   void OnWriteFailed();
     83 
     84   // Called by MuxSocket.
     85   void OnSocketDestroyed();
     86   bool DoWrite(scoped_ptr<MultiplexPacket> packet,
     87                const base::Closure& done_task);
     88   int DoRead(net::IOBuffer* buffer, int buffer_len);
     89 
     90  private:
     91   ChannelMultiplexer* multiplexer_;
     92   std::string name_;
     93   int send_id_;
     94   bool id_sent_;
     95   int receive_id_;
     96   MuxSocket* socket_;
     97   std::list<PendingPacket*> pending_packets_;
     98 
     99   DISALLOW_COPY_AND_ASSIGN(MuxChannel);
    100 };
    101 
    102 class ChannelMultiplexer::MuxSocket : public net::StreamSocket,
    103                                       public base::NonThreadSafe,
    104                                       public base::SupportsWeakPtr<MuxSocket> {
    105  public:
    106   MuxSocket(MuxChannel* channel);
    107   virtual ~MuxSocket();
    108 
    109   void OnWriteComplete();
    110   void OnWriteFailed();
    111   void OnPacketReceived();
    112 
    113   // net::StreamSocket interface.
    114   virtual int Read(net::IOBuffer* buffer, int buffer_len,
    115                    const net::CompletionCallback& callback) OVERRIDE;
    116   virtual int Write(net::IOBuffer* buffer, int buffer_len,
    117                     const net::CompletionCallback& callback) OVERRIDE;
    118 
    119   virtual bool SetReceiveBufferSize(int32 size) OVERRIDE {
    120     NOTIMPLEMENTED();
    121     return false;
    122   }
    123   virtual bool SetSendBufferSize(int32 size) OVERRIDE {
    124     NOTIMPLEMENTED();
    125     return false;
    126   }
    127 
    128   virtual int Connect(const net::CompletionCallback& callback) OVERRIDE {
    129     NOTIMPLEMENTED();
    130     return net::ERR_FAILED;
    131   }
    132   virtual void Disconnect() OVERRIDE {
    133     NOTIMPLEMENTED();
    134   }
    135   virtual bool IsConnected() const OVERRIDE {
    136     NOTIMPLEMENTED();
    137     return true;
    138   }
    139   virtual bool IsConnectedAndIdle() const OVERRIDE {
    140     NOTIMPLEMENTED();
    141     return false;
    142   }
    143   virtual int GetPeerAddress(net::IPEndPoint* address) const OVERRIDE {
    144     NOTIMPLEMENTED();
    145     return net::ERR_FAILED;
    146   }
    147   virtual int GetLocalAddress(net::IPEndPoint* address) const OVERRIDE {
    148     NOTIMPLEMENTED();
    149     return net::ERR_FAILED;
    150   }
    151   virtual const net::BoundNetLog& NetLog() const OVERRIDE {
    152     NOTIMPLEMENTED();
    153     return net_log_;
    154   }
    155   virtual void SetSubresourceSpeculation() OVERRIDE {
    156     NOTIMPLEMENTED();
    157   }
    158   virtual void SetOmniboxSpeculation() OVERRIDE {
    159     NOTIMPLEMENTED();
    160   }
    161   virtual bool WasEverUsed() const OVERRIDE {
    162     return true;
    163   }
    164   virtual bool UsingTCPFastOpen() const OVERRIDE {
    165     return false;
    166   }
    167   virtual bool WasNpnNegotiated() const OVERRIDE {
    168     return false;
    169   }
    170   virtual net::NextProto GetNegotiatedProtocol() const OVERRIDE {
    171     return net::kProtoUnknown;
    172   }
    173   virtual bool GetSSLInfo(net::SSLInfo* ssl_info) OVERRIDE {
    174     NOTIMPLEMENTED();
    175     return false;
    176   }
    177 
    178  private:
    179   MuxChannel* channel_;
    180 
    181   net::CompletionCallback read_callback_;
    182   scoped_refptr<net::IOBuffer> read_buffer_;
    183   int read_buffer_size_;
    184 
    185   bool write_pending_;
    186   int write_result_;
    187   net::CompletionCallback write_callback_;
    188 
    189   net::BoundNetLog net_log_;
    190 
    191   DISALLOW_COPY_AND_ASSIGN(MuxSocket);
    192 };
    193 
    194 
    195 ChannelMultiplexer::MuxChannel::MuxChannel(
    196     ChannelMultiplexer* multiplexer,
    197     const std::string& name,
    198     int send_id)
    199     : multiplexer_(multiplexer),
    200       name_(name),
    201       send_id_(send_id),
    202       id_sent_(false),
    203       receive_id_(kChannelIdUnknown),
    204       socket_(NULL) {
    205 }
    206 
    207 ChannelMultiplexer::MuxChannel::~MuxChannel() {
    208   // Socket must be destroyed before the channel.
    209   DCHECK(!socket_);
    210   STLDeleteElements(&pending_packets_);
    211 }
    212 
    213 scoped_ptr<net::StreamSocket> ChannelMultiplexer::MuxChannel::CreateSocket() {
    214   DCHECK(!socket_);  // Can't create more than one socket per channel.
    215   scoped_ptr<MuxSocket> result(new MuxSocket(this));
    216   socket_ = result.get();
    217   return result.PassAs<net::StreamSocket>();
    218 }
    219 
    220 void ChannelMultiplexer::MuxChannel::OnIncomingPacket(
    221     scoped_ptr<MultiplexPacket> packet,
    222     const base::Closure& done_task) {
    223   DCHECK_EQ(packet->channel_id(), receive_id_);
    224   if (packet->data().size() > 0) {
    225     pending_packets_.push_back(new PendingPacket(packet.Pass(), done_task));
    226     if (socket_) {
    227       // Notify the socket that we have more data.
    228       socket_->OnPacketReceived();
    229     }
    230   }
    231 }
    232 
    233 void ChannelMultiplexer::MuxChannel::OnWriteFailed() {
    234   if (socket_)
    235     socket_->OnWriteFailed();
    236 }
    237 
    238 void ChannelMultiplexer::MuxChannel::OnSocketDestroyed() {
    239   DCHECK(socket_);
    240   socket_ = NULL;
    241 }
    242 
    243 bool ChannelMultiplexer::MuxChannel::DoWrite(
    244     scoped_ptr<MultiplexPacket> packet,
    245     const base::Closure& done_task) {
    246   packet->set_channel_id(send_id_);
    247   if (!id_sent_) {
    248     packet->set_channel_name(name_);
    249     id_sent_ = true;
    250   }
    251   return multiplexer_->DoWrite(packet.Pass(), done_task);
    252 }
    253 
    254 int ChannelMultiplexer::MuxChannel::DoRead(net::IOBuffer* buffer,
    255                                            int buffer_len) {
    256   int pos = 0;
    257   while (buffer_len > 0 && !pending_packets_.empty()) {
    258     DCHECK(!pending_packets_.front()->is_empty());
    259     int result = pending_packets_.front()->Read(
    260         buffer->data() + pos, buffer_len);
    261     DCHECK_LE(result, buffer_len);
    262     pos += result;
    263     buffer_len -= pos;
    264     if (pending_packets_.front()->is_empty()) {
    265       delete pending_packets_.front();
    266       pending_packets_.erase(pending_packets_.begin());
    267     }
    268   }
    269   return pos;
    270 }
    271 
    272 ChannelMultiplexer::MuxSocket::MuxSocket(MuxChannel* channel)
    273     : channel_(channel),
    274       read_buffer_size_(0),
    275       write_pending_(false),
    276       write_result_(0) {
    277 }
    278 
    279 ChannelMultiplexer::MuxSocket::~MuxSocket() {
    280   channel_->OnSocketDestroyed();
    281 }
    282 
    283 int ChannelMultiplexer::MuxSocket::Read(
    284     net::IOBuffer* buffer, int buffer_len,
    285     const net::CompletionCallback& callback) {
    286   DCHECK(CalledOnValidThread());
    287   DCHECK(read_callback_.is_null());
    288 
    289   int result = channel_->DoRead(buffer, buffer_len);
    290   if (result == 0) {
    291     read_buffer_ = buffer;
    292     read_buffer_size_ = buffer_len;
    293     read_callback_ = callback;
    294     return net::ERR_IO_PENDING;
    295   }
    296   return result;
    297 }
    298 
    299 int ChannelMultiplexer::MuxSocket::Write(
    300     net::IOBuffer* buffer, int buffer_len,
    301     const net::CompletionCallback& callback) {
    302   DCHECK(CalledOnValidThread());
    303 
    304   scoped_ptr<MultiplexPacket> packet(new MultiplexPacket());
    305   size_t size = std::min(kMaxPacketSize, buffer_len);
    306   packet->mutable_data()->assign(buffer->data(), size);
    307 
    308   write_pending_ = true;
    309   bool result = channel_->DoWrite(packet.Pass(), base::Bind(
    310       &ChannelMultiplexer::MuxSocket::OnWriteComplete, AsWeakPtr()));
    311 
    312   if (!result) {
    313     // Cannot complete the write, e.g. if the connection has been terminated.
    314     return net::ERR_FAILED;
    315   }
    316 
    317   // OnWriteComplete() might be called above synchronously.
    318   if (write_pending_) {
    319     DCHECK(write_callback_.is_null());
    320     write_callback_ = callback;
    321     write_result_ = size;
    322     return net::ERR_IO_PENDING;
    323   }
    324 
    325   return size;
    326 }
    327 
    328 void ChannelMultiplexer::MuxSocket::OnWriteComplete() {
    329   write_pending_ = false;
    330   if (!write_callback_.is_null()) {
    331     net::CompletionCallback cb;
    332     std::swap(cb, write_callback_);
    333     cb.Run(write_result_);
    334   }
    335 }
    336 
    337 void ChannelMultiplexer::MuxSocket::OnWriteFailed() {
    338   if (!write_callback_.is_null()) {
    339     net::CompletionCallback cb;
    340     std::swap(cb, write_callback_);
    341     cb.Run(net::ERR_FAILED);
    342   }
    343 }
    344 
    345 void ChannelMultiplexer::MuxSocket::OnPacketReceived() {
    346   if (!read_callback_.is_null()) {
    347     int result = channel_->DoRead(read_buffer_.get(), read_buffer_size_);
    348     read_buffer_ = NULL;
    349     DCHECK_GT(result, 0);
    350     net::CompletionCallback cb;
    351     std::swap(cb, read_callback_);
    352     cb.Run(result);
    353   }
    354 }
    355 
    356 ChannelMultiplexer::ChannelMultiplexer(ChannelFactory* factory,
    357                                        const std::string& base_channel_name)
    358     : base_channel_factory_(factory),
    359       base_channel_name_(base_channel_name),
    360       next_channel_id_(0),
    361       weak_factory_(this) {
    362 }
    363 
    364 ChannelMultiplexer::~ChannelMultiplexer() {
    365   DCHECK(pending_channels_.empty());
    366   STLDeleteValues(&channels_);
    367 
    368   // Cancel creation of the base channel if it hasn't finished.
    369   if (base_channel_factory_)
    370     base_channel_factory_->CancelChannelCreation(base_channel_name_);
    371 }
    372 
    373 void ChannelMultiplexer::CreateStreamChannel(
    374     const std::string& name,
    375     const StreamChannelCallback& callback) {
    376   if (base_channel_.get()) {
    377     // Already have |base_channel_|. Create new multiplexed channel
    378     // synchronously.
    379     callback.Run(GetOrCreateChannel(name)->CreateSocket());
    380   } else if (!base_channel_.get() && !base_channel_factory_) {
    381     // Fail synchronously if we failed to create |base_channel_|.
    382     callback.Run(scoped_ptr<net::StreamSocket>());
    383   } else {
    384     // Still waiting for the |base_channel_|.
    385     pending_channels_.push_back(PendingChannel(name, callback));
    386 
    387     // If this is the first multiplexed channel then create the base channel.
    388     if (pending_channels_.size() == 1U) {
    389       base_channel_factory_->CreateStreamChannel(
    390           base_channel_name_,
    391           base::Bind(&ChannelMultiplexer::OnBaseChannelReady,
    392                      base::Unretained(this)));
    393     }
    394   }
    395 }
    396 
    397 void ChannelMultiplexer::CreateDatagramChannel(
    398     const std::string& name,
    399     const DatagramChannelCallback& callback) {
    400   NOTIMPLEMENTED();
    401   callback.Run(scoped_ptr<net::Socket>());
    402 }
    403 
    404 void ChannelMultiplexer::CancelChannelCreation(const std::string& name) {
    405   for (std::list<PendingChannel>::iterator it = pending_channels_.begin();
    406        it != pending_channels_.end(); ++it) {
    407     if (it->name == name) {
    408       pending_channels_.erase(it);
    409       return;
    410     }
    411   }
    412 }
    413 
    414 void ChannelMultiplexer::OnBaseChannelReady(
    415     scoped_ptr<net::StreamSocket> socket) {
    416   base_channel_factory_ = NULL;
    417   base_channel_ = socket.Pass();
    418 
    419   if (base_channel_.get()) {
    420     // Initialize reader and writer.
    421     reader_.Init(base_channel_.get(),
    422                  base::Bind(&ChannelMultiplexer::OnIncomingPacket,
    423                             base::Unretained(this)));
    424     writer_.Init(base_channel_.get(),
    425                  base::Bind(&ChannelMultiplexer::OnWriteFailed,
    426                             base::Unretained(this)));
    427   }
    428 
    429   DoCreatePendingChannels();
    430 }
    431 
    432 void ChannelMultiplexer::DoCreatePendingChannels() {
    433   if (pending_channels_.empty())
    434     return;
    435 
    436   // Every time this function is called it connects a single channel and posts a
    437   // separate task to connect other channels. This is necessary because the
    438   // callback may destroy the multiplexer or somehow else modify
    439   // |pending_channels_| list (e.g. call CancelChannelCreation()).
    440   base::ThreadTaskRunnerHandle::Get()->PostTask(
    441       FROM_HERE, base::Bind(&ChannelMultiplexer::DoCreatePendingChannels,
    442                             weak_factory_.GetWeakPtr()));
    443 
    444   PendingChannel c = pending_channels_.front();
    445   pending_channels_.erase(pending_channels_.begin());
    446   scoped_ptr<net::StreamSocket> socket;
    447   if (base_channel_.get())
    448     socket = GetOrCreateChannel(c.name)->CreateSocket();
    449   c.callback.Run(socket.Pass());
    450 }
    451 
    452 ChannelMultiplexer::MuxChannel* ChannelMultiplexer::GetOrCreateChannel(
    453     const std::string& name) {
    454   // Check if we already have a channel with the requested name.
    455   std::map<std::string, MuxChannel*>::iterator it = channels_.find(name);
    456   if (it != channels_.end())
    457     return it->second;
    458 
    459   // Create a new channel if we haven't found existing one.
    460   MuxChannel* channel = new MuxChannel(this, name, next_channel_id_);
    461   ++next_channel_id_;
    462   channels_[channel->name()] = channel;
    463   return channel;
    464 }
    465 
    466 
    467 void ChannelMultiplexer::OnWriteFailed(int error) {
    468   for (std::map<std::string, MuxChannel*>::iterator it = channels_.begin();
    469        it != channels_.end(); ++it) {
    470     base::ThreadTaskRunnerHandle::Get()->PostTask(
    471         FROM_HERE, base::Bind(&ChannelMultiplexer::NotifyWriteFailed,
    472                               weak_factory_.GetWeakPtr(), it->second->name()));
    473   }
    474 }
    475 
    476 void ChannelMultiplexer::NotifyWriteFailed(const std::string& name) {
    477   std::map<std::string, MuxChannel*>::iterator it = channels_.find(name);
    478   if (it != channels_.end()) {
    479     it->second->OnWriteFailed();
    480   }
    481 }
    482 
    483 void ChannelMultiplexer::OnIncomingPacket(scoped_ptr<MultiplexPacket> packet,
    484                                           const base::Closure& done_task) {
    485   if (!packet->has_channel_id()) {
    486     LOG(ERROR) << "Received packet without channel_id.";
    487     done_task.Run();
    488     return;
    489   }
    490 
    491   int receive_id = packet->channel_id();
    492   MuxChannel* channel = NULL;
    493   std::map<int, MuxChannel*>::iterator it =
    494       channels_by_receive_id_.find(receive_id);
    495   if (it != channels_by_receive_id_.end()) {
    496     channel = it->second;
    497   } else {
    498     // This is a new |channel_id| we haven't seen before. Look it up by name.
    499     if (!packet->has_channel_name()) {
    500       LOG(ERROR) << "Received packet with unknown channel_id and "
    501           "without channel_name.";
    502       done_task.Run();
    503       return;
    504     }
    505     channel = GetOrCreateChannel(packet->channel_name());
    506     channel->set_receive_id(receive_id);
    507     channels_by_receive_id_[receive_id] = channel;
    508   }
    509 
    510   channel->OnIncomingPacket(packet.Pass(), done_task);
    511 }
    512 
    513 bool ChannelMultiplexer::DoWrite(scoped_ptr<MultiplexPacket> packet,
    514                                  const base::Closure& done_task) {
    515   return writer_.Write(SerializeAndFrameMessage(*packet), done_task);
    516 }
    517 
    518 }  // namespace protocol
    519 }  // namespace remoting
    520