Home | History | Annotate | Download | only in protocol
      1 // Copyright 2014 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "remoting/protocol/fake_datagram_socket.h"
      6 
      7 #include "base/bind.h"
      8 #include "base/single_thread_task_runner.h"
      9 #include "base/thread_task_runner_handle.h"
     10 #include "net/base/address_list.h"
     11 #include "net/base/io_buffer.h"
     12 #include "net/base/net_errors.h"
     13 #include "net/base/net_util.h"
     14 #include "testing/gtest/include/gtest/gtest.h"
     15 
     16 namespace remoting {
     17 namespace protocol {
     18 
     19 FakeDatagramSocket::FakeDatagramSocket()
     20     : input_pos_(0),
     21       task_runner_(base::ThreadTaskRunnerHandle::Get()),
     22       weak_factory_(this) {
     23 }
     24 
     25 FakeDatagramSocket::~FakeDatagramSocket() {
     26   EXPECT_TRUE(task_runner_->BelongsToCurrentThread());
     27 }
     28 
     29 void FakeDatagramSocket::AppendInputPacket(const std::string& data) {
     30   EXPECT_TRUE(task_runner_->BelongsToCurrentThread());
     31   input_packets_.push_back(data);
     32 
     33   // Complete pending read if any.
     34   if (!read_callback_.is_null()) {
     35     DCHECK_EQ(input_pos_, static_cast<int>(input_packets_.size()) - 1);
     36     int result = CopyReadData(read_buffer_.get(), read_buffer_size_);
     37     read_buffer_ = NULL;
     38 
     39     net::CompletionCallback callback = read_callback_;
     40     read_callback_.Reset();
     41     callback.Run(result);
     42   }
     43 }
     44 
     45 void FakeDatagramSocket::PairWith(FakeDatagramSocket* peer_socket) {
     46   EXPECT_TRUE(task_runner_->BelongsToCurrentThread());
     47   peer_socket_ = peer_socket->GetWeakPtr();
     48   peer_socket->peer_socket_ = GetWeakPtr();
     49 }
     50 
     51 base::WeakPtr<FakeDatagramSocket> FakeDatagramSocket::GetWeakPtr() {
     52   return weak_factory_.GetWeakPtr();
     53 }
     54 
     55 int FakeDatagramSocket::Read(net::IOBuffer* buf, int buf_len,
     56                              const net::CompletionCallback& callback) {
     57   EXPECT_TRUE(task_runner_->BelongsToCurrentThread());
     58   if (input_pos_ < static_cast<int>(input_packets_.size())) {
     59     return CopyReadData(buf, buf_len);
     60   } else {
     61     read_buffer_ = buf;
     62     read_buffer_size_ = buf_len;
     63     read_callback_ = callback;
     64     return net::ERR_IO_PENDING;
     65   }
     66 }
     67 
     68 int FakeDatagramSocket::Write(net::IOBuffer* buf, int buf_len,
     69                          const net::CompletionCallback& callback) {
     70   EXPECT_TRUE(task_runner_->BelongsToCurrentThread());
     71   written_packets_.push_back(std::string());
     72   written_packets_.back().assign(buf->data(), buf->data() + buf_len);
     73 
     74   if (peer_socket_.get()) {
     75     task_runner_->PostTask(
     76         FROM_HERE,
     77         base::Bind(&FakeDatagramSocket::AppendInputPacket,
     78                    peer_socket_,
     79                    std::string(buf->data(), buf->data() + buf_len)));
     80   }
     81 
     82   return buf_len;
     83 }
     84 
     85 int FakeDatagramSocket::SetReceiveBufferSize(int32 size) {
     86   NOTIMPLEMENTED();
     87   return net::ERR_NOT_IMPLEMENTED;
     88 }
     89 
     90 int FakeDatagramSocket::SetSendBufferSize(int32 size) {
     91   NOTIMPLEMENTED();
     92   return net::ERR_NOT_IMPLEMENTED;
     93 }
     94 
     95 int FakeDatagramSocket::CopyReadData(net::IOBuffer* buf, int buf_len) {
     96   int size = std::min(
     97       buf_len, static_cast<int>(input_packets_[input_pos_].size()));
     98   memcpy(buf->data(), &(*input_packets_[input_pos_].begin()), size);
     99   ++input_pos_;
    100   return size;
    101 }
    102 
    103 FakeDatagramChannelFactory::FakeDatagramChannelFactory()
    104     : task_runner_(base::ThreadTaskRunnerHandle::Get()),
    105       asynchronous_create_(false),
    106       fail_create_(false),
    107       weak_factory_(this) {
    108 }
    109 
    110 FakeDatagramChannelFactory::~FakeDatagramChannelFactory() {
    111   for (ChannelsMap::iterator it = channels_.begin(); it != channels_.end();
    112        ++it) {
    113     EXPECT_TRUE(it->second == NULL);
    114   }
    115 }
    116 
    117 void FakeDatagramChannelFactory::PairWith(
    118     FakeDatagramChannelFactory* peer_factory) {
    119   peer_factory_ = peer_factory->weak_factory_.GetWeakPtr();
    120   peer_factory_->peer_factory_ = weak_factory_.GetWeakPtr();
    121 }
    122 
    123 FakeDatagramSocket* FakeDatagramChannelFactory::GetFakeChannel(
    124     const std::string& name) {
    125   return channels_[name].get();
    126 }
    127 
    128 void FakeDatagramChannelFactory::CreateChannel(
    129     const std::string& name,
    130     const ChannelCreatedCallback& callback) {
    131   EXPECT_TRUE(channels_[name] == NULL);
    132 
    133   scoped_ptr<FakeDatagramSocket> channel(new FakeDatagramSocket());
    134   channels_[name] = channel->GetWeakPtr();
    135 
    136   if (peer_factory_) {
    137     FakeDatagramSocket* peer_socket = peer_factory_->GetFakeChannel(name);
    138     if (peer_socket)
    139       channel->PairWith(peer_socket);
    140   }
    141 
    142   if (fail_create_)
    143     channel.reset();
    144 
    145   if (asynchronous_create_) {
    146     task_runner_->PostTask(
    147         FROM_HERE,
    148         base::Bind(&FakeDatagramChannelFactory::NotifyChannelCreated,
    149                    weak_factory_.GetWeakPtr(), base::Passed(&channel),
    150                    name, callback));
    151   } else {
    152     NotifyChannelCreated(channel.Pass(), name, callback);
    153   }
    154 }
    155 
    156 void FakeDatagramChannelFactory::NotifyChannelCreated(
    157     scoped_ptr<FakeDatagramSocket> owned_socket,
    158     const std::string& name,
    159     const ChannelCreatedCallback& callback) {
    160   if (channels_.find(name) != channels_.end())
    161     callback.Run(owned_socket.PassAs<net::Socket>());
    162 }
    163 
    164 void FakeDatagramChannelFactory::CancelChannelCreation(
    165     const std::string& name) {
    166   channels_.erase(name);
    167 }
    168 
    169 }  // namespace protocol
    170 }  // namespace remoting
    171