1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "remoting/protocol/channel_multiplexer.h" 6 7 #include <string.h> 8 9 #include "base/bind.h" 10 #include "base/callback.h" 11 #include "base/location.h" 12 #include "base/single_thread_task_runner.h" 13 #include "base/stl_util.h" 14 #include "base/thread_task_runner_handle.h" 15 #include "net/base/net_errors.h" 16 #include "net/socket/stream_socket.h" 17 #include "remoting/protocol/util.h" 18 19 namespace remoting { 20 namespace protocol { 21 22 namespace { 23 const int kChannelIdUnknown = -1; 24 const int kMaxPacketSize = 1024; 25 26 class PendingPacket { 27 public: 28 PendingPacket(scoped_ptr<MultiplexPacket> packet, 29 const base::Closure& done_task) 30 : packet(packet.Pass()), 31 done_task(done_task), 32 pos(0U) { 33 } 34 ~PendingPacket() { 35 done_task.Run(); 36 } 37 38 bool is_empty() { return pos >= packet->data().size(); } 39 40 int Read(char* buffer, size_t size) { 41 size = std::min(size, packet->data().size() - pos); 42 memcpy(buffer, packet->data().data() + pos, size); 43 pos += size; 44 return size; 45 } 46 47 private: 48 scoped_ptr<MultiplexPacket> packet; 49 base::Closure done_task; 50 size_t pos; 51 52 DISALLOW_COPY_AND_ASSIGN(PendingPacket); 53 }; 54 55 } // namespace 56 57 const char ChannelMultiplexer::kMuxChannelName[] = "mux"; 58 59 struct ChannelMultiplexer::PendingChannel { 60 PendingChannel(const std::string& name, 61 const StreamChannelCallback& callback) 62 : name(name), callback(callback) { 63 } 64 std::string name; 65 StreamChannelCallback callback; 66 }; 67 68 class ChannelMultiplexer::MuxChannel { 69 public: 70 MuxChannel(ChannelMultiplexer* multiplexer, const std::string& name, 71 int send_id); 72 ~MuxChannel(); 73 74 const std::string& name() { return name_; } 75 int receive_id() { return receive_id_; } 76 void set_receive_id(int id) { receive_id_ = id; } 77 78 // Called by ChannelMultiplexer. 79 scoped_ptr<net::StreamSocket> CreateSocket(); 80 void OnIncomingPacket(scoped_ptr<MultiplexPacket> packet, 81 const base::Closure& done_task); 82 void OnWriteFailed(); 83 84 // Called by MuxSocket. 85 void OnSocketDestroyed(); 86 bool DoWrite(scoped_ptr<MultiplexPacket> packet, 87 const base::Closure& done_task); 88 int DoRead(net::IOBuffer* buffer, int buffer_len); 89 90 private: 91 ChannelMultiplexer* multiplexer_; 92 std::string name_; 93 int send_id_; 94 bool id_sent_; 95 int receive_id_; 96 MuxSocket* socket_; 97 std::list<PendingPacket*> pending_packets_; 98 99 DISALLOW_COPY_AND_ASSIGN(MuxChannel); 100 }; 101 102 class ChannelMultiplexer::MuxSocket : public net::StreamSocket, 103 public base::NonThreadSafe, 104 public base::SupportsWeakPtr<MuxSocket> { 105 public: 106 MuxSocket(MuxChannel* channel); 107 virtual ~MuxSocket(); 108 109 void OnWriteComplete(); 110 void OnWriteFailed(); 111 void OnPacketReceived(); 112 113 // net::StreamSocket interface. 114 virtual int Read(net::IOBuffer* buffer, int buffer_len, 115 const net::CompletionCallback& callback) OVERRIDE; 116 virtual int Write(net::IOBuffer* buffer, int buffer_len, 117 const net::CompletionCallback& callback) OVERRIDE; 118 119 virtual bool SetReceiveBufferSize(int32 size) OVERRIDE { 120 NOTIMPLEMENTED(); 121 return false; 122 } 123 virtual bool SetSendBufferSize(int32 size) OVERRIDE { 124 NOTIMPLEMENTED(); 125 return false; 126 } 127 128 virtual int Connect(const net::CompletionCallback& callback) OVERRIDE { 129 NOTIMPLEMENTED(); 130 return net::ERR_FAILED; 131 } 132 virtual void Disconnect() OVERRIDE { 133 NOTIMPLEMENTED(); 134 } 135 virtual bool IsConnected() const OVERRIDE { 136 NOTIMPLEMENTED(); 137 return true; 138 } 139 virtual bool IsConnectedAndIdle() const OVERRIDE { 140 NOTIMPLEMENTED(); 141 return false; 142 } 143 virtual int GetPeerAddress(net::IPEndPoint* address) const OVERRIDE { 144 NOTIMPLEMENTED(); 145 return net::ERR_FAILED; 146 } 147 virtual int GetLocalAddress(net::IPEndPoint* address) const OVERRIDE { 148 NOTIMPLEMENTED(); 149 return net::ERR_FAILED; 150 } 151 virtual const net::BoundNetLog& NetLog() const OVERRIDE { 152 NOTIMPLEMENTED(); 153 return net_log_; 154 } 155 virtual void SetSubresourceSpeculation() OVERRIDE { 156 NOTIMPLEMENTED(); 157 } 158 virtual void SetOmniboxSpeculation() OVERRIDE { 159 NOTIMPLEMENTED(); 160 } 161 virtual bool WasEverUsed() const OVERRIDE { 162 return true; 163 } 164 virtual bool UsingTCPFastOpen() const OVERRIDE { 165 return false; 166 } 167 virtual bool WasNpnNegotiated() const OVERRIDE { 168 return false; 169 } 170 virtual net::NextProto GetNegotiatedProtocol() const OVERRIDE { 171 return net::kProtoUnknown; 172 } 173 virtual bool GetSSLInfo(net::SSLInfo* ssl_info) OVERRIDE { 174 NOTIMPLEMENTED(); 175 return false; 176 } 177 178 private: 179 MuxChannel* channel_; 180 181 net::CompletionCallback read_callback_; 182 scoped_refptr<net::IOBuffer> read_buffer_; 183 int read_buffer_size_; 184 185 bool write_pending_; 186 int write_result_; 187 net::CompletionCallback write_callback_; 188 189 net::BoundNetLog net_log_; 190 191 DISALLOW_COPY_AND_ASSIGN(MuxSocket); 192 }; 193 194 195 ChannelMultiplexer::MuxChannel::MuxChannel( 196 ChannelMultiplexer* multiplexer, 197 const std::string& name, 198 int send_id) 199 : multiplexer_(multiplexer), 200 name_(name), 201 send_id_(send_id), 202 id_sent_(false), 203 receive_id_(kChannelIdUnknown), 204 socket_(NULL) { 205 } 206 207 ChannelMultiplexer::MuxChannel::~MuxChannel() { 208 // Socket must be destroyed before the channel. 209 DCHECK(!socket_); 210 STLDeleteElements(&pending_packets_); 211 } 212 213 scoped_ptr<net::StreamSocket> ChannelMultiplexer::MuxChannel::CreateSocket() { 214 DCHECK(!socket_); // Can't create more than one socket per channel. 215 scoped_ptr<MuxSocket> result(new MuxSocket(this)); 216 socket_ = result.get(); 217 return result.PassAs<net::StreamSocket>(); 218 } 219 220 void ChannelMultiplexer::MuxChannel::OnIncomingPacket( 221 scoped_ptr<MultiplexPacket> packet, 222 const base::Closure& done_task) { 223 DCHECK_EQ(packet->channel_id(), receive_id_); 224 if (packet->data().size() > 0) { 225 pending_packets_.push_back(new PendingPacket(packet.Pass(), done_task)); 226 if (socket_) { 227 // Notify the socket that we have more data. 228 socket_->OnPacketReceived(); 229 } 230 } 231 } 232 233 void ChannelMultiplexer::MuxChannel::OnWriteFailed() { 234 if (socket_) 235 socket_->OnWriteFailed(); 236 } 237 238 void ChannelMultiplexer::MuxChannel::OnSocketDestroyed() { 239 DCHECK(socket_); 240 socket_ = NULL; 241 } 242 243 bool ChannelMultiplexer::MuxChannel::DoWrite( 244 scoped_ptr<MultiplexPacket> packet, 245 const base::Closure& done_task) { 246 packet->set_channel_id(send_id_); 247 if (!id_sent_) { 248 packet->set_channel_name(name_); 249 id_sent_ = true; 250 } 251 return multiplexer_->DoWrite(packet.Pass(), done_task); 252 } 253 254 int ChannelMultiplexer::MuxChannel::DoRead(net::IOBuffer* buffer, 255 int buffer_len) { 256 int pos = 0; 257 while (buffer_len > 0 && !pending_packets_.empty()) { 258 DCHECK(!pending_packets_.front()->is_empty()); 259 int result = pending_packets_.front()->Read( 260 buffer->data() + pos, buffer_len); 261 DCHECK_LE(result, buffer_len); 262 pos += result; 263 buffer_len -= pos; 264 if (pending_packets_.front()->is_empty()) { 265 delete pending_packets_.front(); 266 pending_packets_.erase(pending_packets_.begin()); 267 } 268 } 269 return pos; 270 } 271 272 ChannelMultiplexer::MuxSocket::MuxSocket(MuxChannel* channel) 273 : channel_(channel), 274 read_buffer_size_(0), 275 write_pending_(false), 276 write_result_(0) { 277 } 278 279 ChannelMultiplexer::MuxSocket::~MuxSocket() { 280 channel_->OnSocketDestroyed(); 281 } 282 283 int ChannelMultiplexer::MuxSocket::Read( 284 net::IOBuffer* buffer, int buffer_len, 285 const net::CompletionCallback& callback) { 286 DCHECK(CalledOnValidThread()); 287 DCHECK(read_callback_.is_null()); 288 289 int result = channel_->DoRead(buffer, buffer_len); 290 if (result == 0) { 291 read_buffer_ = buffer; 292 read_buffer_size_ = buffer_len; 293 read_callback_ = callback; 294 return net::ERR_IO_PENDING; 295 } 296 return result; 297 } 298 299 int ChannelMultiplexer::MuxSocket::Write( 300 net::IOBuffer* buffer, int buffer_len, 301 const net::CompletionCallback& callback) { 302 DCHECK(CalledOnValidThread()); 303 304 scoped_ptr<MultiplexPacket> packet(new MultiplexPacket()); 305 size_t size = std::min(kMaxPacketSize, buffer_len); 306 packet->mutable_data()->assign(buffer->data(), size); 307 308 write_pending_ = true; 309 bool result = channel_->DoWrite(packet.Pass(), base::Bind( 310 &ChannelMultiplexer::MuxSocket::OnWriteComplete, AsWeakPtr())); 311 312 if (!result) { 313 // Cannot complete the write, e.g. if the connection has been terminated. 314 return net::ERR_FAILED; 315 } 316 317 // OnWriteComplete() might be called above synchronously. 318 if (write_pending_) { 319 DCHECK(write_callback_.is_null()); 320 write_callback_ = callback; 321 write_result_ = size; 322 return net::ERR_IO_PENDING; 323 } 324 325 return size; 326 } 327 328 void ChannelMultiplexer::MuxSocket::OnWriteComplete() { 329 write_pending_ = false; 330 if (!write_callback_.is_null()) { 331 net::CompletionCallback cb; 332 std::swap(cb, write_callback_); 333 cb.Run(write_result_); 334 } 335 } 336 337 void ChannelMultiplexer::MuxSocket::OnWriteFailed() { 338 if (!write_callback_.is_null()) { 339 net::CompletionCallback cb; 340 std::swap(cb, write_callback_); 341 cb.Run(net::ERR_FAILED); 342 } 343 } 344 345 void ChannelMultiplexer::MuxSocket::OnPacketReceived() { 346 if (!read_callback_.is_null()) { 347 int result = channel_->DoRead(read_buffer_.get(), read_buffer_size_); 348 read_buffer_ = NULL; 349 DCHECK_GT(result, 0); 350 net::CompletionCallback cb; 351 std::swap(cb, read_callback_); 352 cb.Run(result); 353 } 354 } 355 356 ChannelMultiplexer::ChannelMultiplexer(ChannelFactory* factory, 357 const std::string& base_channel_name) 358 : base_channel_factory_(factory), 359 base_channel_name_(base_channel_name), 360 next_channel_id_(0), 361 weak_factory_(this) { 362 } 363 364 ChannelMultiplexer::~ChannelMultiplexer() { 365 DCHECK(pending_channels_.empty()); 366 STLDeleteValues(&channels_); 367 368 // Cancel creation of the base channel if it hasn't finished. 369 if (base_channel_factory_) 370 base_channel_factory_->CancelChannelCreation(base_channel_name_); 371 } 372 373 void ChannelMultiplexer::CreateStreamChannel( 374 const std::string& name, 375 const StreamChannelCallback& callback) { 376 if (base_channel_.get()) { 377 // Already have |base_channel_|. Create new multiplexed channel 378 // synchronously. 379 callback.Run(GetOrCreateChannel(name)->CreateSocket()); 380 } else if (!base_channel_.get() && !base_channel_factory_) { 381 // Fail synchronously if we failed to create |base_channel_|. 382 callback.Run(scoped_ptr<net::StreamSocket>()); 383 } else { 384 // Still waiting for the |base_channel_|. 385 pending_channels_.push_back(PendingChannel(name, callback)); 386 387 // If this is the first multiplexed channel then create the base channel. 388 if (pending_channels_.size() == 1U) { 389 base_channel_factory_->CreateStreamChannel( 390 base_channel_name_, 391 base::Bind(&ChannelMultiplexer::OnBaseChannelReady, 392 base::Unretained(this))); 393 } 394 } 395 } 396 397 void ChannelMultiplexer::CreateDatagramChannel( 398 const std::string& name, 399 const DatagramChannelCallback& callback) { 400 NOTIMPLEMENTED(); 401 callback.Run(scoped_ptr<net::Socket>()); 402 } 403 404 void ChannelMultiplexer::CancelChannelCreation(const std::string& name) { 405 for (std::list<PendingChannel>::iterator it = pending_channels_.begin(); 406 it != pending_channels_.end(); ++it) { 407 if (it->name == name) { 408 pending_channels_.erase(it); 409 return; 410 } 411 } 412 } 413 414 void ChannelMultiplexer::OnBaseChannelReady( 415 scoped_ptr<net::StreamSocket> socket) { 416 base_channel_factory_ = NULL; 417 base_channel_ = socket.Pass(); 418 419 if (base_channel_.get()) { 420 // Initialize reader and writer. 421 reader_.Init(base_channel_.get(), 422 base::Bind(&ChannelMultiplexer::OnIncomingPacket, 423 base::Unretained(this))); 424 writer_.Init(base_channel_.get(), 425 base::Bind(&ChannelMultiplexer::OnWriteFailed, 426 base::Unretained(this))); 427 } 428 429 DoCreatePendingChannels(); 430 } 431 432 void ChannelMultiplexer::DoCreatePendingChannels() { 433 if (pending_channels_.empty()) 434 return; 435 436 // Every time this function is called it connects a single channel and posts a 437 // separate task to connect other channels. This is necessary because the 438 // callback may destroy the multiplexer or somehow else modify 439 // |pending_channels_| list (e.g. call CancelChannelCreation()). 440 base::ThreadTaskRunnerHandle::Get()->PostTask( 441 FROM_HERE, base::Bind(&ChannelMultiplexer::DoCreatePendingChannels, 442 weak_factory_.GetWeakPtr())); 443 444 PendingChannel c = pending_channels_.front(); 445 pending_channels_.erase(pending_channels_.begin()); 446 scoped_ptr<net::StreamSocket> socket; 447 if (base_channel_.get()) 448 socket = GetOrCreateChannel(c.name)->CreateSocket(); 449 c.callback.Run(socket.Pass()); 450 } 451 452 ChannelMultiplexer::MuxChannel* ChannelMultiplexer::GetOrCreateChannel( 453 const std::string& name) { 454 // Check if we already have a channel with the requested name. 455 std::map<std::string, MuxChannel*>::iterator it = channels_.find(name); 456 if (it != channels_.end()) 457 return it->second; 458 459 // Create a new channel if we haven't found existing one. 460 MuxChannel* channel = new MuxChannel(this, name, next_channel_id_); 461 ++next_channel_id_; 462 channels_[channel->name()] = channel; 463 return channel; 464 } 465 466 467 void ChannelMultiplexer::OnWriteFailed(int error) { 468 for (std::map<std::string, MuxChannel*>::iterator it = channels_.begin(); 469 it != channels_.end(); ++it) { 470 base::ThreadTaskRunnerHandle::Get()->PostTask( 471 FROM_HERE, base::Bind(&ChannelMultiplexer::NotifyWriteFailed, 472 weak_factory_.GetWeakPtr(), it->second->name())); 473 } 474 } 475 476 void ChannelMultiplexer::NotifyWriteFailed(const std::string& name) { 477 std::map<std::string, MuxChannel*>::iterator it = channels_.find(name); 478 if (it != channels_.end()) { 479 it->second->OnWriteFailed(); 480 } 481 } 482 483 void ChannelMultiplexer::OnIncomingPacket(scoped_ptr<MultiplexPacket> packet, 484 const base::Closure& done_task) { 485 if (!packet->has_channel_id()) { 486 LOG(ERROR) << "Received packet without channel_id."; 487 done_task.Run(); 488 return; 489 } 490 491 int receive_id = packet->channel_id(); 492 MuxChannel* channel = NULL; 493 std::map<int, MuxChannel*>::iterator it = 494 channels_by_receive_id_.find(receive_id); 495 if (it != channels_by_receive_id_.end()) { 496 channel = it->second; 497 } else { 498 // This is a new |channel_id| we haven't seen before. Look it up by name. 499 if (!packet->has_channel_name()) { 500 LOG(ERROR) << "Received packet with unknown channel_id and " 501 "without channel_name."; 502 done_task.Run(); 503 return; 504 } 505 channel = GetOrCreateChannel(packet->channel_name()); 506 channel->set_receive_id(receive_id); 507 channels_by_receive_id_[receive_id] = channel; 508 } 509 510 channel->OnIncomingPacket(packet.Pass(), done_task); 511 } 512 513 bool ChannelMultiplexer::DoWrite(scoped_ptr<MultiplexPacket> packet, 514 const base::Closure& done_task) { 515 return writer_.Write(SerializeAndFrameMessage(*packet), done_task); 516 } 517 518 } // namespace protocol 519 } // namespace remoting 520