1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "remoting/protocol/message_reader.h" 6 7 #include "base/bind.h" 8 #include "base/callback.h" 9 #include "base/compiler_specific.h" 10 #include "base/location.h" 11 #include "base/thread_task_runner_handle.h" 12 #include "base/single_thread_task_runner.h" 13 #include "net/base/io_buffer.h" 14 #include "net/base/net_errors.h" 15 #include "net/socket/socket.h" 16 #include "remoting/base/compound_buffer.h" 17 #include "remoting/proto/internal.pb.h" 18 19 namespace remoting { 20 namespace protocol { 21 22 static const int kReadBufferSize = 4096; 23 24 MessageReader::MessageReader() 25 : socket_(NULL), 26 read_pending_(false), 27 pending_messages_(0), 28 closed_(false), 29 weak_factory_(this) { 30 } 31 32 void MessageReader::Init(net::Socket* socket, 33 const MessageReceivedCallback& callback) { 34 DCHECK(CalledOnValidThread()); 35 message_received_callback_ = callback; 36 DCHECK(socket); 37 socket_ = socket; 38 DoRead(); 39 } 40 41 MessageReader::~MessageReader() { 42 } 43 44 void MessageReader::DoRead() { 45 DCHECK(CalledOnValidThread()); 46 // Don't try to read again if there is another read pending or we 47 // have messages that we haven't finished processing yet. 48 while (!closed_ && !read_pending_ && pending_messages_ == 0) { 49 read_buffer_ = new net::IOBuffer(kReadBufferSize); 50 int result = socket_->Read( 51 read_buffer_.get(), 52 kReadBufferSize, 53 base::Bind(&MessageReader::OnRead, weak_factory_.GetWeakPtr())); 54 HandleReadResult(result); 55 } 56 } 57 58 void MessageReader::OnRead(int result) { 59 DCHECK(CalledOnValidThread()); 60 DCHECK(read_pending_); 61 read_pending_ = false; 62 63 if (!closed_) { 64 HandleReadResult(result); 65 DoRead(); 66 } 67 } 68 69 void MessageReader::HandleReadResult(int result) { 70 DCHECK(CalledOnValidThread()); 71 if (closed_) 72 return; 73 74 if (result > 0) { 75 OnDataReceived(read_buffer_.get(), result); 76 } else if (result == net::ERR_IO_PENDING) { 77 read_pending_ = true; 78 } else { 79 if (result != net::ERR_CONNECTION_CLOSED) { 80 LOG(ERROR) << "Read() returned error " << result; 81 } 82 // Stop reading after any error. 83 closed_ = true; 84 } 85 } 86 87 void MessageReader::OnDataReceived(net::IOBuffer* data, int data_size) { 88 DCHECK(CalledOnValidThread()); 89 message_decoder_.AddData(data, data_size); 90 91 // Get list of all new messages first, and then call the callback 92 // for all of them. 93 while (true) { 94 CompoundBuffer* buffer = message_decoder_.GetNextMessage(); 95 if (!buffer) 96 break; 97 pending_messages_++; 98 base::ThreadTaskRunnerHandle::Get()->PostTask( 99 FROM_HERE, 100 base::Bind(&MessageReader::RunCallback, 101 weak_factory_.GetWeakPtr(), 102 base::Passed(scoped_ptr<CompoundBuffer>(buffer)))); 103 } 104 } 105 106 void MessageReader::RunCallback(scoped_ptr<CompoundBuffer> message) { 107 message_received_callback_.Run( 108 message.Pass(), base::Bind(&MessageReader::OnMessageDone, 109 weak_factory_.GetWeakPtr())); 110 } 111 112 void MessageReader::OnMessageDone() { 113 DCHECK(CalledOnValidThread()); 114 pending_messages_--; 115 DCHECK_GE(pending_messages_, 0); 116 117 // Start next read if necessary. 118 DoRead(); 119 } 120 121 } // namespace protocol 122 } // namespace remoting 123