Home | History | Annotate | Download | only in cast_channel
      1 // Copyright 2014 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "extensions/browser/api/cast_channel/cast_framer.h"
      6 
      7 #include <stdlib.h>
      8 
      9 #include "base/numerics/safe_conversions.h"
     10 #include "base/strings/string_number_conversions.h"
     11 #include "base/sys_byteorder.h"
     12 #include "extensions/common/api/cast_channel/cast_channel.pb.h"
     13 
     14 namespace extensions {
     15 namespace core_api {
     16 namespace cast_channel {
     17 MessageFramer::MessageFramer(scoped_refptr<net::GrowableIOBuffer> input_buffer)
     18     : input_buffer_(input_buffer), error_(false) {
     19   Reset();
     20 }
     21 
     22 MessageFramer::~MessageFramer() {
     23 }
     24 
     25 MessageFramer::MessageHeader::MessageHeader() : message_size(0) {
     26 }
     27 
     28 void MessageFramer::MessageHeader::SetMessageSize(size_t size) {
     29   DCHECK_LT(size, static_cast<size_t>(kuint32max));
     30   DCHECK_GT(size, 0U);
     31   message_size = size;
     32 }
     33 
     34 // TODO(mfoltz): Investigate replacing header serialization with base::Pickle,
     35 // if bit-for-bit compatible.
     36 void MessageFramer::MessageHeader::PrependToString(std::string* str) {
     37   MessageHeader output = *this;
     38   output.message_size = base::HostToNet32(message_size);
     39   size_t header_size = MessageHeader::header_size();
     40   scoped_ptr<char, base::FreeDeleter> char_array(
     41       static_cast<char*>(malloc(header_size)));
     42   memcpy(char_array.get(), &output, header_size);
     43   str->insert(0, char_array.get(), header_size);
     44 }
     45 
     46 // TODO(mfoltz): Investigate replacing header deserialization with base::Pickle,
     47 // if bit-for-bit compatible.
     48 void MessageFramer::MessageHeader::Deserialize(char* data,
     49                                                MessageHeader* header) {
     50   uint32 message_size;
     51   memcpy(&message_size, data, header_size());
     52   header->message_size =
     53       base::checked_cast<size_t>(base::NetToHost32(message_size));
     54 }
     55 
     56 // static
     57 size_t MessageFramer::MessageHeader::header_size() {
     58   return sizeof(uint32);
     59 }
     60 
     61 // static
     62 size_t MessageFramer::MessageHeader::max_message_size() {
     63   return 65535;
     64 }
     65 
     66 std::string MessageFramer::MessageHeader::ToString() {
     67   return "{message_size: " +
     68          base::UintToString(static_cast<uint32>(message_size)) + "}";
     69 }
     70 
     71 // static
     72 bool MessageFramer::Serialize(const CastMessage& message_proto,
     73                               std::string* message_data) {
     74   DCHECK(message_data);
     75   message_proto.SerializeToString(message_data);
     76   size_t message_size = message_data->size();
     77   if (message_size > MessageHeader::max_message_size()) {
     78     message_data->clear();
     79     return false;
     80   }
     81   MessageHeader header;
     82   header.SetMessageSize(message_size);
     83   header.PrependToString(message_data);
     84   return true;
     85 }
     86 
     87 size_t MessageFramer::BytesRequested() {
     88   size_t bytes_left;
     89   if (error_) {
     90     return 0;
     91   }
     92 
     93   switch (current_element_) {
     94     case HEADER:
     95       bytes_left = MessageHeader::header_size() - message_bytes_received_;
     96       DCHECK_LE(bytes_left, MessageHeader::header_size());
     97       VLOG(2) << "Bytes needed for header: " << bytes_left;
     98       return bytes_left;
     99     case BODY:
    100       bytes_left =
    101           (body_size_ + MessageHeader::header_size()) - message_bytes_received_;
    102       DCHECK_LE(
    103           bytes_left,
    104           MessageHeader::max_message_size() - MessageHeader::header_size());
    105       VLOG(2) << "Bytes needed for body: " << bytes_left;
    106       return bytes_left;
    107     default:
    108       NOTREACHED() << "Unhandled packet element type.";
    109       return 0;
    110   }
    111 }
    112 
    113 scoped_ptr<CastMessage> MessageFramer::Ingest(size_t num_bytes,
    114                                               size_t* message_length,
    115                                               ChannelError* error) {
    116   DCHECK(error);
    117   DCHECK(message_length);
    118   if (error_) {
    119     *error = CHANNEL_ERROR_INVALID_MESSAGE;
    120     return scoped_ptr<CastMessage>();
    121   }
    122 
    123   DCHECK_EQ(base::checked_cast<int32>(message_bytes_received_),
    124             input_buffer_->offset());
    125   CHECK_LE(num_bytes, BytesRequested());
    126   message_bytes_received_ += num_bytes;
    127   *error = CHANNEL_ERROR_NONE;
    128   *message_length = 0;
    129   switch (current_element_) {
    130     case HEADER:
    131       if (BytesRequested() == 0) {
    132         MessageHeader header;
    133         MessageHeader::Deserialize(input_buffer_.get()->StartOfBuffer(),
    134                                    &header);
    135         if (header.message_size > MessageHeader::max_message_size()) {
    136           VLOG(1) << "Error parsing header (message size too large).";
    137           *error = CHANNEL_ERROR_INVALID_MESSAGE;
    138           error_ = true;
    139           return scoped_ptr<CastMessage>();
    140         }
    141         current_element_ = BODY;
    142         body_size_ = header.message_size;
    143       }
    144       break;
    145     case BODY:
    146       if (BytesRequested() == 0) {
    147         scoped_ptr<CastMessage> parsed_message(new CastMessage);
    148         if (!parsed_message->ParseFromArray(
    149                 input_buffer_->StartOfBuffer() + MessageHeader::header_size(),
    150                 body_size_)) {
    151           VLOG(1) << "Error parsing packet body.";
    152           *error = CHANNEL_ERROR_INVALID_MESSAGE;
    153           error_ = true;
    154           return scoped_ptr<CastMessage>();
    155         }
    156         *message_length = body_size_;
    157         Reset();
    158         return parsed_message.Pass();
    159       }
    160       break;
    161     default:
    162       NOTREACHED() << "Unhandled packet element type.";
    163       return scoped_ptr<CastMessage>();
    164   }
    165 
    166   input_buffer_->set_offset(message_bytes_received_);
    167   return scoped_ptr<CastMessage>();
    168 }
    169 
    170 void MessageFramer::Reset() {
    171   current_element_ = HEADER;
    172   message_bytes_received_ = 0;
    173   body_size_ = 0;
    174   input_buffer_->set_offset(0);
    175 }
    176 
    177 }  // namespace cast_channel
    178 }  // namespace core_api
    179 }  // namespace extensions
    180