Home | History | Annotate | Download | only in cast_channel
      1 // Copyright 2014 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "extensions/browser/api/cast_channel/cast_transport.h"
      6 
      7 #include <string>
      8 
      9 #include "base/bind.h"
     10 #include "base/format_macros.h"
     11 #include "base/numerics/safe_conversions.h"
     12 #include "base/strings/stringprintf.h"
     13 #include "extensions/browser/api/cast_channel/cast_framer.h"
     14 #include "extensions/browser/api/cast_channel/cast_message_util.h"
     15 #include "extensions/browser/api/cast_channel/logger.h"
     16 #include "extensions/browser/api/cast_channel/logger_util.h"
     17 #include "extensions/common/api/cast_channel/cast_channel.pb.h"
     18 #include "net/base/net_errors.h"
     19 
     20 #define VLOG_WITH_CONNECTION(level)                       \
     21   VLOG(level) << "[" << socket_->ip_endpoint().ToString() \
     22               << ", auth=" << socket_->channel_auth() << "] "
     23 
     24 namespace extensions {
     25 namespace core_api {
     26 namespace cast_channel {
     27 
     28 CastTransport::CastTransport(CastSocketInterface* socket,
     29                              Delegate* read_delegate,
     30                              scoped_refptr<Logger> logger)
     31     : socket_(socket),
     32       read_delegate_(read_delegate),
     33       write_state_(WRITE_STATE_NONE),
     34       read_state_(READ_STATE_NONE),
     35       logger_(logger) {
     36   DCHECK(socket);
     37   DCHECK(read_delegate);
     38 
     39   // Buffer is reused across messages to minimize unnecessary buffer
     40   // [re]allocations.
     41   read_buffer_ = new net::GrowableIOBuffer();
     42   read_buffer_->SetCapacity(MessageFramer::MessageHeader::max_message_size());
     43   framer_.reset(new MessageFramer(read_buffer_));
     44 }
     45 
     46 CastTransport::~CastTransport() {
     47   DCHECK(thread_checker_.CalledOnValidThread());
     48   FlushWriteQueue();
     49 }
     50 
     51 // static
     52 proto::ReadState CastTransport::ReadStateToProto(
     53     CastTransport::ReadState state) {
     54   switch (state) {
     55     case CastTransport::READ_STATE_NONE:
     56       return proto::READ_STATE_NONE;
     57     case CastTransport::READ_STATE_READ:
     58       return proto::READ_STATE_READ;
     59     case CastTransport::READ_STATE_READ_COMPLETE:
     60       return proto::READ_STATE_READ_COMPLETE;
     61     case CastTransport::READ_STATE_DO_CALLBACK:
     62       return proto::READ_STATE_DO_CALLBACK;
     63     case CastTransport::READ_STATE_ERROR:
     64       return proto::READ_STATE_ERROR;
     65     default:
     66       NOTREACHED();
     67       return proto::READ_STATE_NONE;
     68   }
     69 }
     70 
     71 // static
     72 proto::WriteState CastTransport::WriteStateToProto(
     73     CastTransport::WriteState state) {
     74   switch (state) {
     75     case CastTransport::WRITE_STATE_NONE:
     76       return proto::WRITE_STATE_NONE;
     77     case CastTransport::WRITE_STATE_WRITE:
     78       return proto::WRITE_STATE_WRITE;
     79     case CastTransport::WRITE_STATE_WRITE_COMPLETE:
     80       return proto::WRITE_STATE_WRITE_COMPLETE;
     81     case CastTransport::WRITE_STATE_DO_CALLBACK:
     82       return proto::WRITE_STATE_DO_CALLBACK;
     83     case CastTransport::WRITE_STATE_ERROR:
     84       return proto::WRITE_STATE_ERROR;
     85     default:
     86       NOTREACHED();
     87       return proto::WRITE_STATE_NONE;
     88   }
     89 }
     90 
     91 // static
     92 proto::ErrorState CastTransport::ErrorStateToProto(ChannelError state) {
     93   switch (state) {
     94     case CHANNEL_ERROR_NONE:
     95       return proto::CHANNEL_ERROR_NONE;
     96     case CHANNEL_ERROR_CHANNEL_NOT_OPEN:
     97       return proto::CHANNEL_ERROR_CHANNEL_NOT_OPEN;
     98     case CHANNEL_ERROR_AUTHENTICATION_ERROR:
     99       return proto::CHANNEL_ERROR_AUTHENTICATION_ERROR;
    100     case CHANNEL_ERROR_CONNECT_ERROR:
    101       return proto::CHANNEL_ERROR_CONNECT_ERROR;
    102     case CHANNEL_ERROR_SOCKET_ERROR:
    103       return proto::CHANNEL_ERROR_SOCKET_ERROR;
    104     case CHANNEL_ERROR_TRANSPORT_ERROR:
    105       return proto::CHANNEL_ERROR_TRANSPORT_ERROR;
    106     case CHANNEL_ERROR_INVALID_MESSAGE:
    107       return proto::CHANNEL_ERROR_INVALID_MESSAGE;
    108     case CHANNEL_ERROR_INVALID_CHANNEL_ID:
    109       return proto::CHANNEL_ERROR_INVALID_CHANNEL_ID;
    110     case CHANNEL_ERROR_CONNECT_TIMEOUT:
    111       return proto::CHANNEL_ERROR_CONNECT_TIMEOUT;
    112     case CHANNEL_ERROR_UNKNOWN:
    113       return proto::CHANNEL_ERROR_UNKNOWN;
    114     default:
    115       NOTREACHED();
    116       return proto::CHANNEL_ERROR_NONE;
    117   }
    118 }
    119 
    120 void CastTransport::FlushWriteQueue() {
    121   for (; !write_queue_.empty(); write_queue_.pop()) {
    122     net::CompletionCallback& callback = write_queue_.front().callback;
    123     callback.Run(net::ERR_FAILED);
    124     callback.Reset();
    125   }
    126 }
    127 
    128 void CastTransport::SendMessage(const CastMessage& message,
    129                                 const net::CompletionCallback& callback) {
    130   DCHECK(thread_checker_.CalledOnValidThread());
    131   std::string serialized_message;
    132   if (!MessageFramer::Serialize(message, &serialized_message)) {
    133     logger_->LogSocketEventForMessage(socket_->id(),
    134                                       proto::SEND_MESSAGE_FAILED,
    135                                       message.namespace_(),
    136                                       "Error when serializing message.");
    137     callback.Run(net::ERR_FAILED);
    138     return;
    139   }
    140   WriteRequest write_request(
    141       message.namespace_(), serialized_message, callback);
    142 
    143   write_queue_.push(write_request);
    144   logger_->LogSocketEventForMessage(
    145       socket_->id(),
    146       proto::MESSAGE_ENQUEUED,
    147       message.namespace_(),
    148       base::StringPrintf("Queue size: %" PRIuS, write_queue_.size()));
    149   if (write_state_ == WRITE_STATE_NONE) {
    150     SetWriteState(WRITE_STATE_WRITE);
    151     OnWriteResult(net::OK);
    152   }
    153 }
    154 
    155 CastTransport::WriteRequest::WriteRequest(
    156     const std::string& namespace_,
    157     const std::string& payload,
    158     const net::CompletionCallback& callback)
    159     : message_namespace(namespace_), callback(callback) {
    160   VLOG(2) << "WriteRequest size: " << payload.size();
    161   io_buffer = new net::DrainableIOBuffer(new net::StringIOBuffer(payload),
    162                                          payload.size());
    163 }
    164 
    165 CastTransport::WriteRequest::~WriteRequest() {
    166 }
    167 
    168 void CastTransport::SetReadState(ReadState read_state) {
    169   if (read_state_ != read_state) {
    170     read_state_ = read_state;
    171     logger_->LogSocketReadState(socket_->id(), ReadStateToProto(read_state_));
    172   }
    173 }
    174 
    175 void CastTransport::SetWriteState(WriteState write_state) {
    176   if (write_state_ != write_state) {
    177     write_state_ = write_state;
    178     logger_->LogSocketWriteState(socket_->id(),
    179                                  WriteStateToProto(write_state_));
    180   }
    181 }
    182 
    183 void CastTransport::SetErrorState(ChannelError error_state) {
    184   if (error_state_ != error_state) {
    185     error_state_ = error_state;
    186     logger_->LogSocketErrorState(socket_->id(),
    187                                  ErrorStateToProto(error_state_));
    188   }
    189 }
    190 
// Drives the write state machine. Called with net::OK from SendMessage() to
// start writing, and bound in DoWrite() as the completion callback for
// asynchronous socket writes, in which case |result| is the write result.
// Steps through states until a write is pending (ERR_IO_PENDING), the queue
// is drained, or no next state was set. If the loop ends with ERR_FAILED, the
// socket is closed with |error_state_| and every queued write is failed.
void CastTransport::OnWriteResult(int result) {
  DCHECK(thread_checker_.CalledOnValidThread());
  VLOG_WITH_CONNECTION(1) << "OnWriteResult queue size: "
                          << write_queue_.size();

  if (write_queue_.empty()) {
    SetWriteState(WRITE_STATE_NONE);
    return;
  }

  // Network operations can either finish synchronously or asynchronously.
  // This method executes the state machine transitions in a loop so that
  // write state transitions happen even when network operations finish
  // synchronously.
  int rv = result;
  do {
    // Consume the current state. Each Do* handler selects the next state via
    // SetWriteState(); DoWriteError() sets none, leaving WRITE_STATE_NONE.
    WriteState state = write_state_;
    write_state_ = WRITE_STATE_NONE;
    switch (state) {
      case WRITE_STATE_WRITE:
        rv = DoWrite();
        break;
      case WRITE_STATE_WRITE_COMPLETE:
        rv = DoWriteComplete(rv);
        break;
      case WRITE_STATE_DO_CALLBACK:
        rv = DoWriteCallback();
        break;
      case WRITE_STATE_ERROR:
        rv = DoWriteError(rv);
        break;
      default:
        NOTREACHED() << "BUG in write flow. Unknown state: " << state;
        break;
    }
  } while (!write_queue_.empty() && rv != net::ERR_IO_PENDING &&
           write_state_ != WRITE_STATE_NONE);

  // No handler set a next state, so the machine is at WRITE_STATE_NONE. That
  // value was assigned directly in the loop (bypassing SetWriteState()), so
  // log it explicitly here.
  if (write_state_ == WRITE_STATE_NONE) {
    logger_->LogSocketWriteState(socket_->id(),
                                 WriteStateToProto(write_state_));
  }

  // If write loop is done because the queue is empty then set write
  // state to NONE
  if (write_queue_.empty()) {
    SetWriteState(WRITE_STATE_NONE);
  }

  // Write loop is done - if the result is ERR_FAILED then close with error.
  // ERR_FAILED is only produced via DoWriteError() or a zero-byte write in
  // DoWriteComplete(), both of which first record an error state.
  if (rv == net::ERR_FAILED) {
    DCHECK_NE(CHANNEL_ERROR_NONE, error_state_);
    socket_->CloseWithError(error_state_);
    FlushWriteQueue();
  }
}
    249 
    250 int CastTransport::DoWrite() {
    251   DCHECK(!write_queue_.empty());
    252   WriteRequest& request = write_queue_.front();
    253 
    254   VLOG_WITH_CONNECTION(2) << "WriteData byte_count = "
    255                           << request.io_buffer->size() << " bytes_written "
    256                           << request.io_buffer->BytesConsumed();
    257 
    258   SetWriteState(WRITE_STATE_WRITE_COMPLETE);
    259 
    260   int rv = socket_->Write(
    261       request.io_buffer.get(),
    262       request.io_buffer->BytesRemaining(),
    263       base::Bind(&CastTransport::OnWriteResult, base::Unretained(this)));
    264   logger_->LogSocketEventWithRv(socket_->id(), proto::SOCKET_WRITE, rv);
    265 
    266   return rv;
    267 }
    268 
    269 int CastTransport::DoWriteComplete(int result) {
    270   VLOG_WITH_CONNECTION(2) << "DoWriteComplete result=" << result;
    271   DCHECK(!write_queue_.empty());
    272   if (result <= 0) {  // NOTE that 0 also indicates an error
    273     SetErrorState(CHANNEL_ERROR_TRANSPORT_ERROR);
    274     SetWriteState(WRITE_STATE_ERROR);
    275     return result == 0 ? net::ERR_FAILED : result;
    276   }
    277 
    278   // Some bytes were successfully written
    279   WriteRequest& request = write_queue_.front();
    280   scoped_refptr<net::DrainableIOBuffer> io_buffer = request.io_buffer;
    281   io_buffer->DidConsume(result);
    282   if (io_buffer->BytesRemaining() == 0) {  // Message fully sent
    283     SetWriteState(WRITE_STATE_DO_CALLBACK);
    284   } else {
    285     SetWriteState(WRITE_STATE_WRITE);
    286   }
    287 
    288   return net::OK;
    289 }
    290 
    291 int CastTransport::DoWriteCallback() {
    292   VLOG_WITH_CONNECTION(2) << "DoWriteCallback";
    293   DCHECK(!write_queue_.empty());
    294 
    295   SetWriteState(WRITE_STATE_WRITE);
    296 
    297   WriteRequest& request = write_queue_.front();
    298   int bytes_consumed = request.io_buffer->BytesConsumed();
    299   logger_->LogSocketEventForMessage(
    300       socket_->id(),
    301       proto::MESSAGE_WRITTEN,
    302       request.message_namespace,
    303       base::StringPrintf("Bytes: %d", bytes_consumed));
    304   request.callback.Run(net::OK);
    305   write_queue_.pop();
    306   return net::OK;
    307 }
    308 
    309 int CastTransport::DoWriteError(int result) {
    310   VLOG_WITH_CONNECTION(2) << "DoWriteError result=" << result;
    311   DCHECK_NE(CHANNEL_ERROR_NONE, error_state_);
    312   DCHECK_LT(result, 0);
    313   return net::ERR_FAILED;
    314 }
    315 
    316 void CastTransport::StartReadLoop() {
    317   DCHECK(thread_checker_.CalledOnValidThread());
    318   // Read loop would have already been started if read state is not NONE
    319   if (read_state_ == READ_STATE_NONE) {
    320     SetReadState(READ_STATE_READ);
    321     OnReadResult(net::OK);
    322   }
    323 }
    324 
// Drives the read state machine. Called with net::OK from StartReadLoop() and
// bound in DoRead() as the completion callback for asynchronous socket reads,
// in which case |result| is the read result. Steps through states until a
// read is pending (ERR_IO_PENDING) or no next state was set. If the loop ends
// with ERR_FAILED, the socket is closed, queued writes are failed, and the
// delegate is notified through OnError().
void CastTransport::OnReadResult(int result) {
  DCHECK(thread_checker_.CalledOnValidThread());
  // Network operations can either finish synchronously or asynchronously.
  // This method executes the state machine transitions in a loop so that
  // read state transitions happen even when network operations finish
  // synchronously.
  int rv = result;
  do {
    // Consume the current state. Each Do* handler selects the next state via
    // SetReadState(); DoReadError() sets none, leaving READ_STATE_NONE.
    ReadState state = read_state_;
    read_state_ = READ_STATE_NONE;

    switch (state) {
      case READ_STATE_READ:
        rv = DoRead();
        break;
      case READ_STATE_READ_COMPLETE:
        rv = DoReadComplete(rv);
        break;
      case READ_STATE_DO_CALLBACK:
        rv = DoReadCallback();
        break;
      case READ_STATE_ERROR:
        rv = DoReadError(rv);
        DCHECK_EQ(read_state_, READ_STATE_NONE);
        break;
      default:
        NOTREACHED() << "BUG in read flow. Unknown state: " << state;
        break;
    }
  } while (rv != net::ERR_IO_PENDING && read_state_ != READ_STATE_NONE);

  // No handler set a next state, so the machine is at READ_STATE_NONE. That
  // value was assigned directly in the loop (bypassing SetReadState()), so
  // log it explicitly here.
  if (read_state_ == READ_STATE_NONE) {
    logger_->LogSocketReadState(socket_->id(), ReadStateToProto(read_state_));
  }

  // Fatal read error: tear down the socket, fail pending writes, and report
  // the recorded error (plus the logger's recent errors) to the delegate.
  if (rv == net::ERR_FAILED) {
    DCHECK_NE(CHANNEL_ERROR_NONE, error_state_);
    socket_->CloseWithError(error_state_);
    FlushWriteQueue();
    read_delegate_->OnError(
        socket_, error_state_, logger_->GetLastErrors(socket_->id()));
  }
}
    370 
    371 int CastTransport::DoRead() {
    372   VLOG_WITH_CONNECTION(2) << "DoRead";
    373   SetReadState(READ_STATE_READ_COMPLETE);
    374 
    375   // Determine how many bytes need to be read.
    376   size_t num_bytes_to_read = framer_->BytesRequested();
    377 
    378   // Read up to num_bytes_to_read into |current_read_buffer_|.
    379   int rv = socket_->Read(
    380       read_buffer_.get(),
    381       base::checked_cast<uint32>(num_bytes_to_read),
    382       base::Bind(&CastTransport::OnReadResult, base::Unretained(this)));
    383 
    384   return rv;
    385 }
    386 
    387 int CastTransport::DoReadComplete(int result) {
    388   VLOG_WITH_CONNECTION(2) << "DoReadComplete result = " << result;
    389 
    390   if (result <= 0) {
    391     SetErrorState(CHANNEL_ERROR_TRANSPORT_ERROR);
    392     SetReadState(READ_STATE_ERROR);
    393     return result == 0 ? net::ERR_FAILED : result;
    394   }
    395 
    396   size_t message_size;
    397   DCHECK(current_message_.get() == NULL);
    398   current_message_ = framer_->Ingest(result, &message_size, &error_state_);
    399   if (current_message_.get()) {
    400     DCHECK_EQ(error_state_, CHANNEL_ERROR_NONE);
    401     DCHECK_GT(message_size, static_cast<size_t>(0));
    402     logger_->LogSocketEventForMessage(
    403         socket_->id(),
    404         proto::MESSAGE_READ,
    405         current_message_->namespace_(),
    406         base::StringPrintf("Message size: %u",
    407                            static_cast<uint32>(message_size)));
    408     SetReadState(READ_STATE_DO_CALLBACK);
    409   } else if (error_state_ != CHANNEL_ERROR_NONE) {
    410     DCHECK(current_message_.get() == NULL);
    411     SetErrorState(CHANNEL_ERROR_INVALID_MESSAGE);
    412     SetReadState(READ_STATE_ERROR);
    413   } else {
    414     DCHECK(current_message_.get() == NULL);
    415     SetReadState(READ_STATE_READ);
    416   }
    417   return net::OK;
    418 }
    419 
    420 int CastTransport::DoReadCallback() {
    421   VLOG_WITH_CONNECTION(2) << "DoReadCallback";
    422   SetReadState(READ_STATE_READ);
    423   if (!IsCastMessageValid(*current_message_)) {
    424     SetReadState(READ_STATE_ERROR);
    425     SetErrorState(CHANNEL_ERROR_INVALID_MESSAGE);
    426     return net::ERR_INVALID_RESPONSE;
    427   }
    428   logger_->LogSocketEventForMessage(socket_->id(),
    429                                     proto::NOTIFY_ON_MESSAGE,
    430                                     current_message_->namespace_(),
    431                                     std::string());
    432   read_delegate_->OnMessage(socket_, *current_message_);
    433   current_message_.reset();
    434   return net::OK;
    435 }
    436 
    437 int CastTransport::DoReadError(int result) {
    438   VLOG_WITH_CONNECTION(2) << "DoReadError";
    439   DCHECK_NE(CHANNEL_ERROR_NONE, error_state_);
    440   DCHECK_LE(result, 0);
    441   return net::ERR_FAILED;
    442 }
    443 
    444 }  // namespace cast_channel
    445 }  // namespace core_api
    446 }  // namespace extensions
    447