1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "remoting/protocol/buffered_socket_writer.h" 6 7 #include "base/bind.h" 8 #include "base/location.h" 9 #include "base/single_thread_task_runner.h" 10 #include "base/stl_util.h" 11 #include "base/thread_task_runner_handle.h" 12 #include "net/base/net_errors.h" 13 14 namespace remoting { 15 namespace protocol { 16 17 struct BufferedSocketWriterBase::PendingPacket { 18 PendingPacket(scoped_refptr<net::IOBufferWithSize> data, 19 const base::Closure& done_task) 20 : data(data), 21 done_task(done_task) { 22 } 23 24 scoped_refptr<net::IOBufferWithSize> data; 25 base::Closure done_task; 26 }; 27 28 BufferedSocketWriterBase::BufferedSocketWriterBase() 29 : buffer_size_(0), 30 socket_(NULL), 31 write_pending_(false), 32 closed_(false), 33 destroyed_flag_(NULL) { 34 } 35 36 void BufferedSocketWriterBase::Init(net::Socket* socket, 37 const WriteFailedCallback& callback) { 38 DCHECK(CalledOnValidThread()); 39 DCHECK(socket); 40 socket_ = socket; 41 write_failed_callback_ = callback; 42 } 43 44 bool BufferedSocketWriterBase::Write( 45 scoped_refptr<net::IOBufferWithSize> data, const base::Closure& done_task) { 46 DCHECK(CalledOnValidThread()); 47 DCHECK(socket_); 48 DCHECK(data.get()); 49 50 // Don't write after Close(). 51 if (closed_) 52 return false; 53 54 queue_.push_back(new PendingPacket(data, done_task)); 55 buffer_size_ += data->size(); 56 57 DoWrite(); 58 59 // DoWrite() may trigger OnWriteError() to be called. 60 return !closed_; 61 } 62 63 void BufferedSocketWriterBase::DoWrite() { 64 DCHECK(CalledOnValidThread()); 65 DCHECK(socket_); 66 67 // Don't try to write if there is another write pending. 68 if (write_pending_) 69 return; 70 71 // Don't write after Close(). 72 if (closed_) 73 return; 74 75 while (true) { 76 net::IOBuffer* current_packet; 77 int current_packet_size; 78 GetNextPacket(¤t_packet, ¤t_packet_size); 79 80 // Return if the queue is empty. 81 if (!current_packet) 82 return; 83 84 int result = socket_->Write( 85 current_packet, current_packet_size, 86 base::Bind(&BufferedSocketWriterBase::OnWritten, 87 base::Unretained(this))); 88 bool write_again = false; 89 HandleWriteResult(result, &write_again); 90 if (!write_again) 91 return; 92 } 93 } 94 95 void BufferedSocketWriterBase::HandleWriteResult(int result, 96 bool* write_again) { 97 *write_again = false; 98 if (result < 0) { 99 if (result == net::ERR_IO_PENDING) { 100 write_pending_ = true; 101 } else { 102 HandleError(result); 103 if (!write_failed_callback_.is_null()) 104 write_failed_callback_.Run(result); 105 } 106 return; 107 } 108 109 base::Closure done_task = AdvanceBufferPosition(result); 110 if (!done_task.is_null()) { 111 bool destroyed = false; 112 destroyed_flag_ = &destroyed; 113 done_task.Run(); 114 if (destroyed) { 115 // Stop doing anything if we've been destroyed by the callback. 116 return; 117 } 118 destroyed_flag_ = NULL; 119 } 120 121 *write_again = true; 122 } 123 124 void BufferedSocketWriterBase::OnWritten(int result) { 125 DCHECK(CalledOnValidThread()); 126 DCHECK(write_pending_); 127 write_pending_ = false; 128 129 bool write_again; 130 HandleWriteResult(result, &write_again); 131 if (write_again) 132 DoWrite(); 133 } 134 135 void BufferedSocketWriterBase::HandleError(int result) { 136 DCHECK(CalledOnValidThread()); 137 138 closed_ = true; 139 140 STLDeleteElements(&queue_); 141 142 // Notify subclass that an error is received. 143 OnError(result); 144 } 145 146 int BufferedSocketWriterBase::GetBufferSize() { 147 return buffer_size_; 148 } 149 150 int BufferedSocketWriterBase::GetBufferChunks() { 151 return queue_.size(); 152 } 153 154 void BufferedSocketWriterBase::Close() { 155 DCHECK(CalledOnValidThread()); 156 closed_ = true; 157 } 158 159 BufferedSocketWriterBase::~BufferedSocketWriterBase() { 160 if (destroyed_flag_) 161 *destroyed_flag_ = true; 162 163 STLDeleteElements(&queue_); 164 } 165 166 base::Closure BufferedSocketWriterBase::PopQueue() { 167 base::Closure result = queue_.front()->done_task; 168 delete queue_.front(); 169 queue_.pop_front(); 170 return result; 171 } 172 173 BufferedSocketWriter::BufferedSocketWriter() { 174 } 175 176 void BufferedSocketWriter::GetNextPacket( 177 net::IOBuffer** buffer, int* size) { 178 if (!current_buf_.get()) { 179 if (queue_.empty()) { 180 *buffer = NULL; 181 return; // Nothing to write. 182 } 183 current_buf_ = new net::DrainableIOBuffer(queue_.front()->data.get(), 184 queue_.front()->data->size()); 185 } 186 187 *buffer = current_buf_.get(); 188 *size = current_buf_->BytesRemaining(); 189 } 190 191 base::Closure BufferedSocketWriter::AdvanceBufferPosition(int written) { 192 buffer_size_ -= written; 193 current_buf_->DidConsume(written); 194 195 if (current_buf_->BytesRemaining() == 0) { 196 current_buf_ = NULL; 197 return PopQueue(); 198 } 199 return base::Closure(); 200 } 201 202 void BufferedSocketWriter::OnError(int result) { 203 current_buf_ = NULL; 204 } 205 206 BufferedSocketWriter::~BufferedSocketWriter() { 207 } 208 209 BufferedDatagramWriter::BufferedDatagramWriter() { 210 } 211 212 void BufferedDatagramWriter::GetNextPacket( 213 net::IOBuffer** buffer, int* size) { 214 if (queue_.empty()) { 215 *buffer = NULL; 216 return; // Nothing to write. 217 } 218 *buffer = queue_.front()->data.get(); 219 *size = queue_.front()->data->size(); 220 } 221 222 base::Closure BufferedDatagramWriter::AdvanceBufferPosition(int written) { 223 DCHECK_EQ(written, queue_.front()->data->size()); 224 buffer_size_ -= queue_.front()->data->size(); 225 return PopQueue(); 226 } 227 228 void BufferedDatagramWriter::OnError(int result) { 229 // Nothing to do here. 230 } 231 232 BufferedDatagramWriter::~BufferedDatagramWriter() { 233 } 234 235 } // namespace protocol 236 } // namespace remoting 237