Home | History | Annotate | Download | only in protocol
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "remoting/protocol/buffered_socket_writer.h"
      6 
      7 #include "base/bind.h"
      8 #include "base/location.h"
      9 #include "base/single_thread_task_runner.h"
     10 #include "base/stl_util.h"
     11 #include "base/thread_task_runner_handle.h"
     12 #include "net/base/net_errors.h"
     13 
     14 namespace remoting {
     15 namespace protocol {
     16 
     17 struct BufferedSocketWriterBase::PendingPacket {
     18   PendingPacket(scoped_refptr<net::IOBufferWithSize> data,
     19                 const base::Closure& done_task)
     20       : data(data),
     21         done_task(done_task) {
     22   }
     23 
     24   scoped_refptr<net::IOBufferWithSize> data;
     25   base::Closure done_task;
     26 };
     27 
     28 BufferedSocketWriterBase::BufferedSocketWriterBase()
     29     : buffer_size_(0),
     30       socket_(NULL),
     31       write_pending_(false),
     32       closed_(false),
     33       destroyed_flag_(NULL) {
     34 }
     35 
     36 void BufferedSocketWriterBase::Init(net::Socket* socket,
     37                                     const WriteFailedCallback& callback) {
     38   DCHECK(CalledOnValidThread());
     39   DCHECK(socket);
     40   socket_ = socket;
     41   write_failed_callback_ = callback;
     42 }
     43 
     44 bool BufferedSocketWriterBase::Write(
     45     scoped_refptr<net::IOBufferWithSize> data, const base::Closure& done_task) {
     46   DCHECK(CalledOnValidThread());
     47   DCHECK(socket_);
     48   DCHECK(data.get());
     49 
     50   // Don't write after Close().
     51   if (closed_)
     52     return false;
     53 
     54   queue_.push_back(new PendingPacket(data, done_task));
     55   buffer_size_ += data->size();
     56 
     57   DoWrite();
     58 
     59   // DoWrite() may trigger OnWriteError() to be called.
     60   return !closed_;
     61 }
     62 
     63 void BufferedSocketWriterBase::DoWrite() {
     64   DCHECK(CalledOnValidThread());
     65   DCHECK(socket_);
     66 
     67   // Don't try to write if there is another write pending.
     68   if (write_pending_)
     69     return;
     70 
     71   // Don't write after Close().
     72   if (closed_)
     73     return;
     74 
     75   while (true) {
     76     net::IOBuffer* current_packet;
     77     int current_packet_size;
     78     GetNextPacket(&current_packet, &current_packet_size);
     79 
     80     // Return if the queue is empty.
     81     if (!current_packet)
     82       return;
     83 
     84     int result = socket_->Write(
     85         current_packet, current_packet_size,
     86         base::Bind(&BufferedSocketWriterBase::OnWritten,
     87                    base::Unretained(this)));
     88     bool write_again = false;
     89     HandleWriteResult(result, &write_again);
     90     if (!write_again)
     91       return;
     92   }
     93 }
     94 
     95 void BufferedSocketWriterBase::HandleWriteResult(int result,
     96                                                  bool* write_again) {
     97   *write_again = false;
     98   if (result < 0) {
     99     if (result == net::ERR_IO_PENDING) {
    100       write_pending_ = true;
    101     } else {
    102       HandleError(result);
    103       if (!write_failed_callback_.is_null())
    104         write_failed_callback_.Run(result);
    105     }
    106     return;
    107   }
    108 
    109   base::Closure done_task = AdvanceBufferPosition(result);
    110   if (!done_task.is_null()) {
    111     bool destroyed = false;
    112     destroyed_flag_ = &destroyed;
    113     done_task.Run();
    114     if (destroyed) {
    115       // Stop doing anything if we've been destroyed by the callback.
    116       return;
    117     }
    118     destroyed_flag_ = NULL;
    119   }
    120 
    121   *write_again = true;
    122 }
    123 
    124 void BufferedSocketWriterBase::OnWritten(int result) {
    125   DCHECK(CalledOnValidThread());
    126   DCHECK(write_pending_);
    127   write_pending_ = false;
    128 
    129   bool write_again;
    130   HandleWriteResult(result, &write_again);
    131   if (write_again)
    132     DoWrite();
    133 }
    134 
    135 void BufferedSocketWriterBase::HandleError(int result) {
    136   DCHECK(CalledOnValidThread());
    137 
    138   closed_ = true;
    139 
    140   STLDeleteElements(&queue_);
    141 
    142   // Notify subclass that an error is received.
    143   OnError(result);
    144 }
    145 
    146 int BufferedSocketWriterBase::GetBufferSize() {
    147   return buffer_size_;
    148 }
    149 
    150 int BufferedSocketWriterBase::GetBufferChunks() {
    151   return queue_.size();
    152 }
    153 
    154 void BufferedSocketWriterBase::Close() {
    155   DCHECK(CalledOnValidThread());
    156   closed_ = true;
    157 }
    158 
    159 BufferedSocketWriterBase::~BufferedSocketWriterBase() {
    160   if (destroyed_flag_)
    161     *destroyed_flag_ = true;
    162 
    163   STLDeleteElements(&queue_);
    164 }
    165 
    166 base::Closure BufferedSocketWriterBase::PopQueue() {
    167   base::Closure result = queue_.front()->done_task;
    168   delete queue_.front();
    169   queue_.pop_front();
    170   return result;
    171 }
    172 
    173 BufferedSocketWriter::BufferedSocketWriter() {
    174 }
    175 
    176 void BufferedSocketWriter::GetNextPacket(
    177     net::IOBuffer** buffer, int* size) {
    178   if (!current_buf_.get()) {
    179     if (queue_.empty()) {
    180       *buffer = NULL;
    181       return;  // Nothing to write.
    182     }
    183     current_buf_ = new net::DrainableIOBuffer(queue_.front()->data.get(),
    184                                               queue_.front()->data->size());
    185   }
    186 
    187   *buffer = current_buf_.get();
    188   *size = current_buf_->BytesRemaining();
    189 }
    190 
    191 base::Closure BufferedSocketWriter::AdvanceBufferPosition(int written) {
    192   buffer_size_ -= written;
    193   current_buf_->DidConsume(written);
    194 
    195   if (current_buf_->BytesRemaining() == 0) {
    196     current_buf_ = NULL;
    197     return PopQueue();
    198   }
    199   return base::Closure();
    200 }
    201 
    202 void BufferedSocketWriter::OnError(int result) {
    203   current_buf_ = NULL;
    204 }
    205 
    206 BufferedSocketWriter::~BufferedSocketWriter() {
    207 }
    208 
    209 BufferedDatagramWriter::BufferedDatagramWriter() {
    210 }
    211 
    212 void BufferedDatagramWriter::GetNextPacket(
    213     net::IOBuffer** buffer, int* size) {
    214   if (queue_.empty()) {
    215     *buffer = NULL;
    216     return;  // Nothing to write.
    217   }
    218   *buffer = queue_.front()->data.get();
    219   *size = queue_.front()->data->size();
    220 }
    221 
    222 base::Closure BufferedDatagramWriter::AdvanceBufferPosition(int written) {
    223   DCHECK_EQ(written, queue_.front()->data->size());
    224   buffer_size_ -= queue_.front()->data->size();
    225   return PopQueue();
    226 }
    227 
    228 void BufferedDatagramWriter::OnError(int result) {
    229   // Nothing to do here.
    230 }
    231 
    232 BufferedDatagramWriter::~BufferedDatagramWriter() {
    233 }
    234 
    235 }  // namespace protocol
    236 }  // namespace remoting
    237