Home | History | Annotate | Download | only in engine
      1 // Copyright 2013 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "google_apis/gcm/engine/connection_handler_impl.h"
      6 
      7 #include "base/message_loop/message_loop.h"
      8 #include "google/protobuf/io/coded_stream.h"
      9 #include "google_apis/gcm/base/mcs_util.h"
     10 #include "google_apis/gcm/base/socket_stream.h"
     11 #include "google_apis/gcm/protocol/mcs.pb.h"
     12 #include "net/base/net_errors.h"
     13 #include "net/socket/stream_socket.h"
     14 
     15 using namespace google::protobuf::io;
     16 
     17 namespace gcm {
     18 
     19 namespace {
     20 
     21 // # of bytes a MCS version packet consumes.
     22 const int kVersionPacketLen = 1;
     23 // # of bytes a tag packet consumes.
     24 const int kTagPacketLen = 1;
     25 // Max # of bytes a length packet consumes.
     26 const int kSizePacketLenMin = 1;
     27 const int kSizePacketLenMax = 2;
     28 
     29 // The current MCS protocol version.
     30 const int kMCSVersion = 41;
     31 
     32 }  // namespace
     33 
     34 ConnectionHandlerImpl::ConnectionHandlerImpl(
     35     base::TimeDelta read_timeout,
     36     const ProtoReceivedCallback& read_callback,
     37     const ProtoSentCallback& write_callback,
     38     const ConnectionChangedCallback& connection_callback)
     39     : read_timeout_(read_timeout),
     40       socket_(NULL),
     41       handshake_complete_(false),
     42       message_tag_(0),
     43       message_size_(0),
     44       read_callback_(read_callback),
     45       write_callback_(write_callback),
     46       connection_callback_(connection_callback),
     47       weak_ptr_factory_(this) {
     48 }
     49 
     50 ConnectionHandlerImpl::~ConnectionHandlerImpl() {
     51 }
     52 
     53 void ConnectionHandlerImpl::Init(
     54     const mcs_proto::LoginRequest& login_request,
     55     net::StreamSocket* socket) {
     56   DCHECK(!read_callback_.is_null());
     57   DCHECK(!write_callback_.is_null());
     58   DCHECK(!connection_callback_.is_null());
     59 
     60   // Invalidate any previously outstanding reads.
     61   weak_ptr_factory_.InvalidateWeakPtrs();
     62 
     63   handshake_complete_ = false;
     64   message_tag_ = 0;
     65   message_size_ = 0;
     66   socket_ = socket;
     67   input_stream_.reset(new SocketInputStream(socket_));
     68   output_stream_.reset(new SocketOutputStream(socket_));
     69 
     70   Login(login_request);
     71 }
     72 
     73 void ConnectionHandlerImpl::Reset() {
     74   CloseConnection();
     75 }
     76 
     77 bool ConnectionHandlerImpl::CanSendMessage() const {
     78   return handshake_complete_ && output_stream_.get() &&
     79       output_stream_->GetState() == SocketOutputStream::EMPTY;
     80 }
     81 
     82 void ConnectionHandlerImpl::SendMessage(
     83     const google::protobuf::MessageLite& message) {
     84   DCHECK_EQ(output_stream_->GetState(), SocketOutputStream::EMPTY);
     85   DCHECK(handshake_complete_);
     86 
     87   {
     88     CodedOutputStream coded_output_stream(output_stream_.get());
     89     DVLOG(1) << "Writing proto of size " << message.ByteSize();
     90     int tag = GetMCSProtoTag(message);
     91     DCHECK_NE(tag, -1);
     92     coded_output_stream.WriteRaw(&tag, 1);
     93     coded_output_stream.WriteVarint32(message.ByteSize());
     94     message.SerializeToCodedStream(&coded_output_stream);
     95   }
     96 
     97   if (output_stream_->Flush(
     98           base::Bind(&ConnectionHandlerImpl::OnMessageSent,
     99                      weak_ptr_factory_.GetWeakPtr())) != net::ERR_IO_PENDING) {
    100     OnMessageSent();
    101   }
    102 }
    103 
    104 void ConnectionHandlerImpl::Login(
    105     const google::protobuf::MessageLite& login_request) {
    106   DCHECK_EQ(output_stream_->GetState(), SocketOutputStream::EMPTY);
    107 
    108   const char version_byte[1] = {kMCSVersion};
    109   const char login_request_tag[1] = {kLoginRequestTag};
    110   {
    111     CodedOutputStream coded_output_stream(output_stream_.get());
    112     coded_output_stream.WriteRaw(version_byte, 1);
    113     coded_output_stream.WriteRaw(login_request_tag, 1);
    114     coded_output_stream.WriteVarint32(login_request.ByteSize());
    115     login_request.SerializeToCodedStream(&coded_output_stream);
    116   }
    117 
    118   if (output_stream_->Flush(
    119           base::Bind(&ConnectionHandlerImpl::OnMessageSent,
    120                      weak_ptr_factory_.GetWeakPtr())) != net::ERR_IO_PENDING) {
    121     base::MessageLoop::current()->PostTask(
    122         FROM_HERE,
    123         base::Bind(&ConnectionHandlerImpl::OnMessageSent,
    124                    weak_ptr_factory_.GetWeakPtr()));
    125   }
    126 
    127   read_timeout_timer_.Start(FROM_HERE,
    128                             read_timeout_,
    129                             base::Bind(&ConnectionHandlerImpl::OnTimeout,
    130                                        weak_ptr_factory_.GetWeakPtr()));
    131   WaitForData(MCS_VERSION_TAG_AND_SIZE);
    132 }
    133 
    134 void ConnectionHandlerImpl::OnMessageSent() {
    135   if (!output_stream_.get()) {
    136     // The connection has already been closed. Just return.
    137     DCHECK(!input_stream_.get());
    138     DCHECK(!read_timeout_timer_.IsRunning());
    139     return;
    140   }
    141 
    142   if (output_stream_->GetState() != SocketOutputStream::EMPTY) {
    143     int last_error = output_stream_->last_error();
    144     CloseConnection();
    145     // If the socket stream had an error, plumb it up, else plumb up FAILED.
    146     if (last_error == net::OK)
    147       last_error = net::ERR_FAILED;
    148     connection_callback_.Run(last_error);
    149     return;
    150   }
    151 
    152   write_callback_.Run();
    153 }
    154 
    155 void ConnectionHandlerImpl::GetNextMessage() {
    156   DCHECK(SocketInputStream::EMPTY == input_stream_->GetState() ||
    157          SocketInputStream::READY == input_stream_->GetState());
    158   message_tag_ = 0;
    159   message_size_ = 0;
    160 
    161   WaitForData(MCS_TAG_AND_SIZE);
    162 }
    163 
    164 void ConnectionHandlerImpl::WaitForData(ProcessingState state) {
    165   DVLOG(1) << "Waiting for MCS data: state == " << state;
    166 
    167   if (!input_stream_) {
    168     // The connection has already been closed. Just return.
    169     DCHECK(!output_stream_.get());
    170     DCHECK(!read_timeout_timer_.IsRunning());
    171     return;
    172   }
    173 
    174   if (input_stream_->GetState() != SocketInputStream::EMPTY &&
    175       input_stream_->GetState() != SocketInputStream::READY) {
    176     // An error occurred.
    177     int last_error = output_stream_->last_error();
    178     CloseConnection();
    179     // If the socket stream had an error, plumb it up, else plumb up FAILED.
    180     if (last_error == net::OK)
    181       last_error = net::ERR_FAILED;
    182     connection_callback_.Run(last_error);
    183     return;
    184   }
    185 
    186   // Used to determine whether a Socket::Read is necessary.
    187   size_t min_bytes_needed = 0;
    188   // Used to limit the size of the Socket::Read.
    189   size_t max_bytes_needed = 0;
    190 
    191   switch(state) {
    192     case MCS_VERSION_TAG_AND_SIZE:
    193       min_bytes_needed = kVersionPacketLen + kTagPacketLen + kSizePacketLenMin;
    194       max_bytes_needed = kVersionPacketLen + kTagPacketLen + kSizePacketLenMax;
    195       break;
    196     case MCS_TAG_AND_SIZE:
    197       min_bytes_needed = kTagPacketLen + kSizePacketLenMin;
    198       max_bytes_needed = kTagPacketLen + kSizePacketLenMax;
    199       break;
    200     case MCS_FULL_SIZE:
    201       // If in this state, the minimum size packet length must already have been
    202       // insufficient, so set both to the max length.
    203       min_bytes_needed = kSizePacketLenMax;
    204       max_bytes_needed = kSizePacketLenMax;
    205       break;
    206     case MCS_PROTO_BYTES:
    207       read_timeout_timer_.Reset();
    208       // No variability in the message size, set both to the same.
    209       min_bytes_needed = message_size_;
    210       max_bytes_needed = message_size_;
    211       break;
    212     default:
    213       NOTREACHED();
    214   }
    215   DCHECK_GE(max_bytes_needed, min_bytes_needed);
    216 
    217   size_t unread_byte_count = input_stream_->UnreadByteCount();
    218   if (min_bytes_needed > unread_byte_count &&
    219       input_stream_->Refresh(
    220           base::Bind(&ConnectionHandlerImpl::WaitForData,
    221                      weak_ptr_factory_.GetWeakPtr(),
    222                      state),
    223           max_bytes_needed - unread_byte_count) == net::ERR_IO_PENDING) {
    224     return;
    225   }
    226 
    227   // Check for refresh errors.
    228   if (input_stream_->GetState() != SocketInputStream::READY) {
    229     // An error occurred.
    230     int last_error = input_stream_->last_error();
    231     CloseConnection();
    232     // If the socket stream had an error, plumb it up, else plumb up FAILED.
    233     if (last_error == net::OK)
    234       last_error = net::ERR_FAILED;
    235     connection_callback_.Run(last_error);
    236     return;
    237   }
    238 
    239   // Check whether read is complete, or needs to be continued (
    240   // SocketInputStream::Refresh can finish without reading all the data).
    241   if (input_stream_->UnreadByteCount() < min_bytes_needed) {
    242     DVLOG(1) << "Socket read finished prematurely. Waiting for "
    243              << min_bytes_needed - input_stream_->UnreadByteCount()
    244              << " more bytes.";
    245     base::MessageLoop::current()->PostTask(
    246         FROM_HERE,
    247         base::Bind(&ConnectionHandlerImpl::WaitForData,
    248                    weak_ptr_factory_.GetWeakPtr(),
    249                    MCS_PROTO_BYTES));
    250     return;
    251   }
    252 
    253   // Received enough bytes, process them.
    254   DVLOG(1) << "Processing MCS data: state == " << state;
    255   switch(state) {
    256     case MCS_VERSION_TAG_AND_SIZE:
    257       OnGotVersion();
    258       break;
    259     case MCS_TAG_AND_SIZE:
    260       OnGotMessageTag();
    261       break;
    262     case MCS_FULL_SIZE:
    263       OnGotMessageSize();
    264       break;
    265     case MCS_PROTO_BYTES:
    266       OnGotMessageBytes();
    267       break;
    268     default:
    269       NOTREACHED();
    270   }
    271 }
    272 
    273 void ConnectionHandlerImpl::OnGotVersion() {
    274   uint8 version = 0;
    275   {
    276     CodedInputStream coded_input_stream(input_stream_.get());
    277     coded_input_stream.ReadRaw(&version, 1);
    278   }
    279   // TODO(zea): remove this when the server is ready.
    280   if (version < kMCSVersion && version != 38) {
    281     LOG(ERROR) << "Invalid GCM version response: " << static_cast<int>(version);
    282     connection_callback_.Run(net::ERR_FAILED);
    283     return;
    284   }
    285 
    286   input_stream_->RebuildBuffer();
    287 
    288   // Process the LoginResponse message tag.
    289   OnGotMessageTag();
    290 }
    291 
    292 void ConnectionHandlerImpl::OnGotMessageTag() {
    293   if (input_stream_->GetState() != SocketInputStream::READY) {
    294     LOG(ERROR) << "Failed to receive protobuf tag.";
    295     read_callback_.Run(scoped_ptr<google::protobuf::MessageLite>());
    296     return;
    297   }
    298 
    299   {
    300     CodedInputStream coded_input_stream(input_stream_.get());
    301     coded_input_stream.ReadRaw(&message_tag_, 1);
    302   }
    303 
    304   DVLOG(1) << "Received proto of type "
    305            << static_cast<unsigned int>(message_tag_);
    306 
    307   if (!read_timeout_timer_.IsRunning()) {
    308     read_timeout_timer_.Start(FROM_HERE,
    309                               read_timeout_,
    310                               base::Bind(&ConnectionHandlerImpl::OnTimeout,
    311                                          weak_ptr_factory_.GetWeakPtr()));
    312   }
    313   OnGotMessageSize();
    314 }
    315 
    316 void ConnectionHandlerImpl::OnGotMessageSize() {
    317   if (input_stream_->GetState() != SocketInputStream::READY) {
    318     LOG(ERROR) << "Failed to receive message size.";
    319     read_callback_.Run(scoped_ptr<google::protobuf::MessageLite>());
    320     return;
    321   }
    322 
    323   bool need_another_byte = false;
    324   int prev_byte_count = input_stream_->ByteCount();
    325   {
    326     CodedInputStream coded_input_stream(input_stream_.get());
    327     if (!coded_input_stream.ReadVarint32(&message_size_))
    328       need_another_byte = true;
    329   }
    330 
    331   if (need_another_byte) {
    332     DVLOG(1) << "Expecting another message size byte.";
    333     if (prev_byte_count >= kSizePacketLenMax) {
    334       // Already had enough bytes, something else went wrong.
    335       LOG(ERROR) << "Failed to process message size.";
    336       read_callback_.Run(scoped_ptr<google::protobuf::MessageLite>());
    337       return;
    338     }
    339     // Back up by the amount read (should always be 1 byte).
    340     int bytes_read = prev_byte_count - input_stream_->ByteCount();
    341     DCHECK_EQ(bytes_read, 1);
    342     input_stream_->BackUp(bytes_read);
    343     WaitForData(MCS_FULL_SIZE);
    344     return;
    345   }
    346 
    347   DVLOG(1) << "Proto size: " << message_size_;
    348 
    349   if (message_size_ > 0)
    350     WaitForData(MCS_PROTO_BYTES);
    351   else
    352     OnGotMessageBytes();
    353 }
    354 
    355 void ConnectionHandlerImpl::OnGotMessageBytes() {
    356   read_timeout_timer_.Stop();
    357   scoped_ptr<google::protobuf::MessageLite> protobuf(
    358       BuildProtobufFromTag(message_tag_));
    359   // Messages with no content are valid; just use the default protobuf for
    360   // that tag.
    361   if (protobuf.get() && message_size_ == 0) {
    362     base::MessageLoop::current()->PostTask(
    363         FROM_HERE,
    364         base::Bind(&ConnectionHandlerImpl::GetNextMessage,
    365                    weak_ptr_factory_.GetWeakPtr()));
    366     read_callback_.Run(protobuf.Pass());
    367     return;
    368   }
    369 
    370   if (!protobuf.get() ||
    371       input_stream_->GetState() != SocketInputStream::READY) {
    372     LOG(ERROR) << "Failed to extract protobuf bytes of type "
    373                << static_cast<unsigned int>(message_tag_);
    374     // Reset the connection.
    375     connection_callback_.Run(net::ERR_FAILED);
    376     return;
    377   }
    378 
    379   {
    380     CodedInputStream coded_input_stream(input_stream_.get());
    381     if (!protobuf->ParsePartialFromCodedStream(&coded_input_stream)) {
    382       LOG(ERROR) << "Unable to parse GCM message of type "
    383                  << static_cast<unsigned int>(message_tag_);
    384       // Reset the connection.
    385       connection_callback_.Run(net::ERR_FAILED);
    386       return;
    387     }
    388   }
    389 
    390   input_stream_->RebuildBuffer();
    391   base::MessageLoop::current()->PostTask(
    392       FROM_HERE,
    393       base::Bind(&ConnectionHandlerImpl::GetNextMessage,
    394                  weak_ptr_factory_.GetWeakPtr()));
    395   if (message_tag_ == kLoginResponseTag) {
    396     if (handshake_complete_) {
    397       LOG(ERROR) << "Unexpected login response.";
    398     } else {
    399       handshake_complete_ = true;
    400       DVLOG(1) << "GCM Handshake complete.";
    401       connection_callback_.Run(net::OK);
    402     }
    403   }
    404   read_callback_.Run(protobuf.Pass());
    405 }
    406 
    407 void ConnectionHandlerImpl::OnTimeout() {
    408   LOG(ERROR) << "Timed out waiting for GCM Protocol buffer.";
    409   CloseConnection();
    410   connection_callback_.Run(net::ERR_TIMED_OUT);
    411 }
    412 
    413 void ConnectionHandlerImpl::CloseConnection() {
    414   DVLOG(1) << "Closing connection.";
    415   read_timeout_timer_.Stop();
    416   if (socket_)
    417     socket_->Disconnect();
    418   socket_ = NULL;
    419   handshake_complete_ = false;
    420   message_tag_ = 0;
    421   message_size_ = 0;
    422   input_stream_.reset();
    423   output_stream_.reset();
    424   weak_ptr_factory_.InvalidateWeakPtrs();
    425 }
    426 
    427 }  // namespace gcm
    428