1 // Copyright 2013 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "google_apis/gcm/engine/connection_handler_impl.h" 6 7 #include "base/message_loop/message_loop.h" 8 #include "google/protobuf/io/coded_stream.h" 9 #include "google_apis/gcm/base/mcs_util.h" 10 #include "google_apis/gcm/base/socket_stream.h" 11 #include "google_apis/gcm/protocol/mcs.pb.h" 12 #include "net/base/net_errors.h" 13 #include "net/socket/stream_socket.h" 14 15 using namespace google::protobuf::io; 16 17 namespace gcm { 18 19 namespace { 20 21 // # of bytes a MCS version packet consumes. 22 const int kVersionPacketLen = 1; 23 // # of bytes a tag packet consumes. 24 const int kTagPacketLen = 1; 25 // Max # of bytes a length packet consumes. 26 const int kSizePacketLenMin = 1; 27 const int kSizePacketLenMax = 2; 28 29 // The current MCS protocol version. 30 const int kMCSVersion = 41; 31 32 } // namespace 33 34 ConnectionHandlerImpl::ConnectionHandlerImpl( 35 base::TimeDelta read_timeout, 36 const ProtoReceivedCallback& read_callback, 37 const ProtoSentCallback& write_callback, 38 const ConnectionChangedCallback& connection_callback) 39 : read_timeout_(read_timeout), 40 socket_(NULL), 41 handshake_complete_(false), 42 message_tag_(0), 43 message_size_(0), 44 read_callback_(read_callback), 45 write_callback_(write_callback), 46 connection_callback_(connection_callback), 47 weak_ptr_factory_(this) { 48 } 49 50 ConnectionHandlerImpl::~ConnectionHandlerImpl() { 51 } 52 53 void ConnectionHandlerImpl::Init( 54 const mcs_proto::LoginRequest& login_request, 55 net::StreamSocket* socket) { 56 DCHECK(!read_callback_.is_null()); 57 DCHECK(!write_callback_.is_null()); 58 DCHECK(!connection_callback_.is_null()); 59 60 // Invalidate any previously outstanding reads. 61 weak_ptr_factory_.InvalidateWeakPtrs(); 62 63 handshake_complete_ = false; 64 message_tag_ = 0; 65 message_size_ = 0; 66 socket_ = socket; 67 input_stream_.reset(new SocketInputStream(socket_)); 68 output_stream_.reset(new SocketOutputStream(socket_)); 69 70 Login(login_request); 71 } 72 73 void ConnectionHandlerImpl::Reset() { 74 CloseConnection(); 75 } 76 77 bool ConnectionHandlerImpl::CanSendMessage() const { 78 return handshake_complete_ && output_stream_.get() && 79 output_stream_->GetState() == SocketOutputStream::EMPTY; 80 } 81 82 void ConnectionHandlerImpl::SendMessage( 83 const google::protobuf::MessageLite& message) { 84 DCHECK_EQ(output_stream_->GetState(), SocketOutputStream::EMPTY); 85 DCHECK(handshake_complete_); 86 87 { 88 CodedOutputStream coded_output_stream(output_stream_.get()); 89 DVLOG(1) << "Writing proto of size " << message.ByteSize(); 90 int tag = GetMCSProtoTag(message); 91 DCHECK_NE(tag, -1); 92 coded_output_stream.WriteRaw(&tag, 1); 93 coded_output_stream.WriteVarint32(message.ByteSize()); 94 message.SerializeToCodedStream(&coded_output_stream); 95 } 96 97 if (output_stream_->Flush( 98 base::Bind(&ConnectionHandlerImpl::OnMessageSent, 99 weak_ptr_factory_.GetWeakPtr())) != net::ERR_IO_PENDING) { 100 OnMessageSent(); 101 } 102 } 103 104 void ConnectionHandlerImpl::Login( 105 const google::protobuf::MessageLite& login_request) { 106 DCHECK_EQ(output_stream_->GetState(), SocketOutputStream::EMPTY); 107 108 const char version_byte[1] = {kMCSVersion}; 109 const char login_request_tag[1] = {kLoginRequestTag}; 110 { 111 CodedOutputStream coded_output_stream(output_stream_.get()); 112 coded_output_stream.WriteRaw(version_byte, 1); 113 coded_output_stream.WriteRaw(login_request_tag, 1); 114 coded_output_stream.WriteVarint32(login_request.ByteSize()); 115 login_request.SerializeToCodedStream(&coded_output_stream); 116 } 117 118 if (output_stream_->Flush( 119 base::Bind(&ConnectionHandlerImpl::OnMessageSent, 120 weak_ptr_factory_.GetWeakPtr())) != net::ERR_IO_PENDING) { 121 base::MessageLoop::current()->PostTask( 122 FROM_HERE, 123 base::Bind(&ConnectionHandlerImpl::OnMessageSent, 124 weak_ptr_factory_.GetWeakPtr())); 125 } 126 127 read_timeout_timer_.Start(FROM_HERE, 128 read_timeout_, 129 base::Bind(&ConnectionHandlerImpl::OnTimeout, 130 weak_ptr_factory_.GetWeakPtr())); 131 WaitForData(MCS_VERSION_TAG_AND_SIZE); 132 } 133 134 void ConnectionHandlerImpl::OnMessageSent() { 135 if (!output_stream_.get()) { 136 // The connection has already been closed. Just return. 137 DCHECK(!input_stream_.get()); 138 DCHECK(!read_timeout_timer_.IsRunning()); 139 return; 140 } 141 142 if (output_stream_->GetState() != SocketOutputStream::EMPTY) { 143 int last_error = output_stream_->last_error(); 144 CloseConnection(); 145 // If the socket stream had an error, plumb it up, else plumb up FAILED. 146 if (last_error == net::OK) 147 last_error = net::ERR_FAILED; 148 connection_callback_.Run(last_error); 149 return; 150 } 151 152 write_callback_.Run(); 153 } 154 155 void ConnectionHandlerImpl::GetNextMessage() { 156 DCHECK(SocketInputStream::EMPTY == input_stream_->GetState() || 157 SocketInputStream::READY == input_stream_->GetState()); 158 message_tag_ = 0; 159 message_size_ = 0; 160 161 WaitForData(MCS_TAG_AND_SIZE); 162 } 163 164 void ConnectionHandlerImpl::WaitForData(ProcessingState state) { 165 DVLOG(1) << "Waiting for MCS data: state == " << state; 166 167 if (!input_stream_) { 168 // The connection has already been closed. Just return. 169 DCHECK(!output_stream_.get()); 170 DCHECK(!read_timeout_timer_.IsRunning()); 171 return; 172 } 173 174 if (input_stream_->GetState() != SocketInputStream::EMPTY && 175 input_stream_->GetState() != SocketInputStream::READY) { 176 // An error occurred. 177 int last_error = output_stream_->last_error(); 178 CloseConnection(); 179 // If the socket stream had an error, plumb it up, else plumb up FAILED. 180 if (last_error == net::OK) 181 last_error = net::ERR_FAILED; 182 connection_callback_.Run(last_error); 183 return; 184 } 185 186 // Used to determine whether a Socket::Read is necessary. 187 size_t min_bytes_needed = 0; 188 // Used to limit the size of the Socket::Read. 189 size_t max_bytes_needed = 0; 190 191 switch(state) { 192 case MCS_VERSION_TAG_AND_SIZE: 193 min_bytes_needed = kVersionPacketLen + kTagPacketLen + kSizePacketLenMin; 194 max_bytes_needed = kVersionPacketLen + kTagPacketLen + kSizePacketLenMax; 195 break; 196 case MCS_TAG_AND_SIZE: 197 min_bytes_needed = kTagPacketLen + kSizePacketLenMin; 198 max_bytes_needed = kTagPacketLen + kSizePacketLenMax; 199 break; 200 case MCS_FULL_SIZE: 201 // If in this state, the minimum size packet length must already have been 202 // insufficient, so set both to the max length. 203 min_bytes_needed = kSizePacketLenMax; 204 max_bytes_needed = kSizePacketLenMax; 205 break; 206 case MCS_PROTO_BYTES: 207 read_timeout_timer_.Reset(); 208 // No variability in the message size, set both to the same. 209 min_bytes_needed = message_size_; 210 max_bytes_needed = message_size_; 211 break; 212 default: 213 NOTREACHED(); 214 } 215 DCHECK_GE(max_bytes_needed, min_bytes_needed); 216 217 size_t unread_byte_count = input_stream_->UnreadByteCount(); 218 if (min_bytes_needed > unread_byte_count && 219 input_stream_->Refresh( 220 base::Bind(&ConnectionHandlerImpl::WaitForData, 221 weak_ptr_factory_.GetWeakPtr(), 222 state), 223 max_bytes_needed - unread_byte_count) == net::ERR_IO_PENDING) { 224 return; 225 } 226 227 // Check for refresh errors. 228 if (input_stream_->GetState() != SocketInputStream::READY) { 229 // An error occurred. 230 int last_error = input_stream_->last_error(); 231 CloseConnection(); 232 // If the socket stream had an error, plumb it up, else plumb up FAILED. 233 if (last_error == net::OK) 234 last_error = net::ERR_FAILED; 235 connection_callback_.Run(last_error); 236 return; 237 } 238 239 // Check whether read is complete, or needs to be continued ( 240 // SocketInputStream::Refresh can finish without reading all the data). 241 if (input_stream_->UnreadByteCount() < min_bytes_needed) { 242 DVLOG(1) << "Socket read finished prematurely. Waiting for " 243 << min_bytes_needed - input_stream_->UnreadByteCount() 244 << " more bytes."; 245 base::MessageLoop::current()->PostTask( 246 FROM_HERE, 247 base::Bind(&ConnectionHandlerImpl::WaitForData, 248 weak_ptr_factory_.GetWeakPtr(), 249 MCS_PROTO_BYTES)); 250 return; 251 } 252 253 // Received enough bytes, process them. 254 DVLOG(1) << "Processing MCS data: state == " << state; 255 switch(state) { 256 case MCS_VERSION_TAG_AND_SIZE: 257 OnGotVersion(); 258 break; 259 case MCS_TAG_AND_SIZE: 260 OnGotMessageTag(); 261 break; 262 case MCS_FULL_SIZE: 263 OnGotMessageSize(); 264 break; 265 case MCS_PROTO_BYTES: 266 OnGotMessageBytes(); 267 break; 268 default: 269 NOTREACHED(); 270 } 271 } 272 273 void ConnectionHandlerImpl::OnGotVersion() { 274 uint8 version = 0; 275 { 276 CodedInputStream coded_input_stream(input_stream_.get()); 277 coded_input_stream.ReadRaw(&version, 1); 278 } 279 // TODO(zea): remove this when the server is ready. 280 if (version < kMCSVersion && version != 38) { 281 LOG(ERROR) << "Invalid GCM version response: " << static_cast<int>(version); 282 connection_callback_.Run(net::ERR_FAILED); 283 return; 284 } 285 286 input_stream_->RebuildBuffer(); 287 288 // Process the LoginResponse message tag. 289 OnGotMessageTag(); 290 } 291 292 void ConnectionHandlerImpl::OnGotMessageTag() { 293 if (input_stream_->GetState() != SocketInputStream::READY) { 294 LOG(ERROR) << "Failed to receive protobuf tag."; 295 read_callback_.Run(scoped_ptr<google::protobuf::MessageLite>()); 296 return; 297 } 298 299 { 300 CodedInputStream coded_input_stream(input_stream_.get()); 301 coded_input_stream.ReadRaw(&message_tag_, 1); 302 } 303 304 DVLOG(1) << "Received proto of type " 305 << static_cast<unsigned int>(message_tag_); 306 307 if (!read_timeout_timer_.IsRunning()) { 308 read_timeout_timer_.Start(FROM_HERE, 309 read_timeout_, 310 base::Bind(&ConnectionHandlerImpl::OnTimeout, 311 weak_ptr_factory_.GetWeakPtr())); 312 } 313 OnGotMessageSize(); 314 } 315 316 void ConnectionHandlerImpl::OnGotMessageSize() { 317 if (input_stream_->GetState() != SocketInputStream::READY) { 318 LOG(ERROR) << "Failed to receive message size."; 319 read_callback_.Run(scoped_ptr<google::protobuf::MessageLite>()); 320 return; 321 } 322 323 bool need_another_byte = false; 324 int prev_byte_count = input_stream_->ByteCount(); 325 { 326 CodedInputStream coded_input_stream(input_stream_.get()); 327 if (!coded_input_stream.ReadVarint32(&message_size_)) 328 need_another_byte = true; 329 } 330 331 if (need_another_byte) { 332 DVLOG(1) << "Expecting another message size byte."; 333 if (prev_byte_count >= kSizePacketLenMax) { 334 // Already had enough bytes, something else went wrong. 335 LOG(ERROR) << "Failed to process message size."; 336 read_callback_.Run(scoped_ptr<google::protobuf::MessageLite>()); 337 return; 338 } 339 // Back up by the amount read (should always be 1 byte). 340 int bytes_read = prev_byte_count - input_stream_->ByteCount(); 341 DCHECK_EQ(bytes_read, 1); 342 input_stream_->BackUp(bytes_read); 343 WaitForData(MCS_FULL_SIZE); 344 return; 345 } 346 347 DVLOG(1) << "Proto size: " << message_size_; 348 349 if (message_size_ > 0) 350 WaitForData(MCS_PROTO_BYTES); 351 else 352 OnGotMessageBytes(); 353 } 354 355 void ConnectionHandlerImpl::OnGotMessageBytes() { 356 read_timeout_timer_.Stop(); 357 scoped_ptr<google::protobuf::MessageLite> protobuf( 358 BuildProtobufFromTag(message_tag_)); 359 // Messages with no content are valid; just use the default protobuf for 360 // that tag. 361 if (protobuf.get() && message_size_ == 0) { 362 base::MessageLoop::current()->PostTask( 363 FROM_HERE, 364 base::Bind(&ConnectionHandlerImpl::GetNextMessage, 365 weak_ptr_factory_.GetWeakPtr())); 366 read_callback_.Run(protobuf.Pass()); 367 return; 368 } 369 370 if (!protobuf.get() || 371 input_stream_->GetState() != SocketInputStream::READY) { 372 LOG(ERROR) << "Failed to extract protobuf bytes of type " 373 << static_cast<unsigned int>(message_tag_); 374 // Reset the connection. 375 connection_callback_.Run(net::ERR_FAILED); 376 return; 377 } 378 379 { 380 CodedInputStream coded_input_stream(input_stream_.get()); 381 if (!protobuf->ParsePartialFromCodedStream(&coded_input_stream)) { 382 LOG(ERROR) << "Unable to parse GCM message of type " 383 << static_cast<unsigned int>(message_tag_); 384 // Reset the connection. 385 connection_callback_.Run(net::ERR_FAILED); 386 return; 387 } 388 } 389 390 input_stream_->RebuildBuffer(); 391 base::MessageLoop::current()->PostTask( 392 FROM_HERE, 393 base::Bind(&ConnectionHandlerImpl::GetNextMessage, 394 weak_ptr_factory_.GetWeakPtr())); 395 if (message_tag_ == kLoginResponseTag) { 396 if (handshake_complete_) { 397 LOG(ERROR) << "Unexpected login response."; 398 } else { 399 handshake_complete_ = true; 400 DVLOG(1) << "GCM Handshake complete."; 401 connection_callback_.Run(net::OK); 402 } 403 } 404 read_callback_.Run(protobuf.Pass()); 405 } 406 407 void ConnectionHandlerImpl::OnTimeout() { 408 LOG(ERROR) << "Timed out waiting for GCM Protocol buffer."; 409 CloseConnection(); 410 connection_callback_.Run(net::ERR_TIMED_OUT); 411 } 412 413 void ConnectionHandlerImpl::CloseConnection() { 414 DVLOG(1) << "Closing connection."; 415 read_timeout_timer_.Stop(); 416 if (socket_) 417 socket_->Disconnect(); 418 socket_ = NULL; 419 handshake_complete_ = false; 420 message_tag_ = 0; 421 message_size_ = 0; 422 input_stream_.reset(); 423 output_stream_.reset(); 424 weak_ptr_factory_.InvalidateWeakPtrs(); 425 } 426 427 } // namespace gcm 428