1 // Copyright 2014 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "extensions/browser/api/cast_channel/cast_socket.h" 6 7 #include <stdlib.h> 8 #include <string.h> 9 10 #include "base/bind.h" 11 #include "base/callback_helpers.h" 12 #include "base/format_macros.h" 13 #include "base/lazy_instance.h" 14 #include "base/numerics/safe_conversions.h" 15 #include "base/strings/string_number_conversions.h" 16 #include "base/strings/stringprintf.h" 17 #include "base/sys_byteorder.h" 18 #include "extensions/browser/api/cast_channel/cast_auth_util.h" 19 #include "extensions/browser/api/cast_channel/cast_framer.h" 20 #include "extensions/browser/api/cast_channel/cast_message_util.h" 21 #include "extensions/browser/api/cast_channel/logger.h" 22 #include "extensions/browser/api/cast_channel/logger_util.h" 23 #include "extensions/common/api/cast_channel/cast_channel.pb.h" 24 #include "net/base/address_list.h" 25 #include "net/base/host_port_pair.h" 26 #include "net/base/net_errors.h" 27 #include "net/base/net_util.h" 28 #include "net/cert/cert_verifier.h" 29 #include "net/cert/x509_certificate.h" 30 #include "net/http/transport_security_state.h" 31 #include "net/socket/client_socket_factory.h" 32 #include "net/socket/client_socket_handle.h" 33 #include "net/socket/ssl_client_socket.h" 34 #include "net/socket/stream_socket.h" 35 #include "net/socket/tcp_client_socket.h" 36 #include "net/ssl/ssl_config_service.h" 37 #include "net/ssl/ssl_info.h" 38 39 // Assumes |ip_endpoint_| of type net::IPEndPoint and |channel_auth_| of enum 40 // type ChannelAuthType are available in the current scope. 41 #define VLOG_WITH_CONNECTION(level) VLOG(level) << "[" << \ 42 ip_endpoint_.ToString() << ", auth=" << channel_auth_ << "] " 43 44 namespace { 45 46 // The default keepalive delay. On Linux, keepalives probes will be sent after 47 // the socket is idle for this length of time, and the socket will be closed 48 // after 9 failed probes. So the total idle time before close is 10 * 49 // kTcpKeepAliveDelaySecs. 50 const int kTcpKeepAliveDelaySecs = 10; 51 } // namespace 52 53 namespace extensions { 54 55 static base::LazyInstance<BrowserContextKeyedAPIFactory< 56 ApiResourceManager<core_api::cast_channel::CastSocket> > > g_factory = 57 LAZY_INSTANCE_INITIALIZER; 58 59 // static 60 template <> 61 BrowserContextKeyedAPIFactory< 62 ApiResourceManager<core_api::cast_channel::CastSocket> >* 63 ApiResourceManager<core_api::cast_channel::CastSocket>::GetFactoryInstance() { 64 return g_factory.Pointer(); 65 } 66 67 namespace core_api { 68 namespace cast_channel { 69 CastSocket::CastSocket(const std::string& owner_extension_id, 70 const net::IPEndPoint& ip_endpoint, 71 ChannelAuthType channel_auth, 72 CastSocket::Delegate* delegate, 73 net::NetLog* net_log, 74 const base::TimeDelta& timeout, 75 const scoped_refptr<Logger>& logger) 76 : ApiResource(owner_extension_id), 77 channel_id_(0), 78 ip_endpoint_(ip_endpoint), 79 channel_auth_(channel_auth), 80 delegate_(delegate), 81 net_log_(net_log), 82 logger_(logger), 83 connect_timeout_(timeout), 84 connect_timeout_timer_(new base::OneShotTimer<CastSocket>), 85 is_canceled_(false), 86 connect_state_(proto::CONN_STATE_NONE), 87 write_state_(proto::WRITE_STATE_NONE), 88 read_state_(proto::READ_STATE_NONE), 89 error_state_(CHANNEL_ERROR_NONE), 90 ready_state_(READY_STATE_NONE) { 91 DCHECK(net_log_); 92 DCHECK(channel_auth_ == CHANNEL_AUTH_TYPE_SSL || 93 channel_auth_ == CHANNEL_AUTH_TYPE_SSL_VERIFIED); 94 net_log_source_.type = net::NetLog::SOURCE_SOCKET; 95 net_log_source_.id = net_log_->NextID(); 96 97 // Buffer is reused across messages. 98 read_buffer_ = new net::GrowableIOBuffer(); 99 read_buffer_->SetCapacity(MessageFramer::MessageHeader::max_message_size()); 100 framer_.reset(new MessageFramer(read_buffer_)); 101 } 102 103 CastSocket::~CastSocket() { 104 // Ensure that resources are freed but do not run pending callbacks to avoid 105 // any re-entrancy. 106 CloseInternal(); 107 } 108 109 ReadyState CastSocket::ready_state() const { 110 return ready_state_; 111 } 112 113 ChannelError CastSocket::error_state() const { 114 return error_state_; 115 } 116 117 scoped_ptr<net::TCPClientSocket> CastSocket::CreateTcpSocket() { 118 net::AddressList addresses(ip_endpoint_); 119 return scoped_ptr<net::TCPClientSocket>( 120 new net::TCPClientSocket(addresses, net_log_, net_log_source_)); 121 // Options cannot be set on the TCPClientSocket yet, because the 122 // underlying platform socket will not be created until Bind() 123 // or Connect() is called. 124 } 125 126 scoped_ptr<net::SSLClientSocket> CastSocket::CreateSslSocket( 127 scoped_ptr<net::StreamSocket> socket) { 128 net::SSLConfig ssl_config; 129 // If a peer cert was extracted in a previous attempt to connect, then 130 // whitelist that cert. 131 if (!peer_cert_.empty()) { 132 net::SSLConfig::CertAndStatus cert_and_status; 133 cert_and_status.cert_status = net::CERT_STATUS_AUTHORITY_INVALID; 134 cert_and_status.der_cert = peer_cert_; 135 ssl_config.allowed_bad_certs.push_back(cert_and_status); 136 logger_->LogSocketEvent(channel_id_, proto::SSL_CERT_WHITELISTED); 137 } 138 139 cert_verifier_.reset(net::CertVerifier::CreateDefault()); 140 transport_security_state_.reset(new net::TransportSecurityState); 141 net::SSLClientSocketContext context; 142 // CertVerifier and TransportSecurityState are owned by us, not the 143 // context object. 144 context.cert_verifier = cert_verifier_.get(); 145 context.transport_security_state = transport_security_state_.get(); 146 147 scoped_ptr<net::ClientSocketHandle> connection(new net::ClientSocketHandle); 148 connection->SetSocket(socket.Pass()); 149 net::HostPortPair host_and_port = net::HostPortPair::FromIPEndPoint( 150 ip_endpoint_); 151 152 return net::ClientSocketFactory::GetDefaultFactory()->CreateSSLClientSocket( 153 connection.Pass(), host_and_port, ssl_config, context); 154 } 155 156 bool CastSocket::ExtractPeerCert(std::string* cert) { 157 DCHECK(cert); 158 DCHECK(peer_cert_.empty()); 159 net::SSLInfo ssl_info; 160 if (!socket_->GetSSLInfo(&ssl_info) || !ssl_info.cert.get()) { 161 return false; 162 } 163 164 logger_->LogSocketEvent(channel_id_, proto::SSL_INFO_OBTAINED); 165 166 bool result = net::X509Certificate::GetDEREncoded( 167 ssl_info.cert->os_cert_handle(), cert); 168 if (result) { 169 VLOG_WITH_CONNECTION(1) << "Successfully extracted peer certificate: " 170 << *cert; 171 } 172 173 logger_->LogSocketEventWithRv( 174 channel_id_, proto::DER_ENCODED_CERT_OBTAIN, result ? 1 : 0); 175 return result; 176 } 177 178 bool CastSocket::VerifyChallengeReply() { 179 AuthResult result = AuthenticateChallengeReply(*challenge_reply_, peer_cert_); 180 logger_->LogSocketChallengeReplyEvent(channel_id_, result); 181 return result.success(); 182 } 183 184 void CastSocket::Connect(const net::CompletionCallback& callback) { 185 DCHECK(CalledOnValidThread()); 186 VLOG_WITH_CONNECTION(1) << "Connect readyState = " << ready_state_; 187 if (ready_state_ != READY_STATE_NONE) { 188 logger_->LogSocketEventWithDetails( 189 channel_id_, proto::CONNECT_FAILED, "ReadyState not NONE"); 190 callback.Run(net::ERR_CONNECTION_FAILED); 191 return; 192 } 193 194 connect_callback_ = callback; 195 SetReadyState(READY_STATE_CONNECTING); 196 SetConnectState(proto::CONN_STATE_TCP_CONNECT); 197 198 if (connect_timeout_.InMicroseconds() > 0) { 199 DCHECK(connect_timeout_callback_.IsCancelled()); 200 connect_timeout_callback_.Reset( 201 base::Bind(&CastSocket::OnConnectTimeout, base::Unretained(this))); 202 GetTimer()->Start(FROM_HERE, 203 connect_timeout_, 204 connect_timeout_callback_.callback()); 205 } 206 DoConnectLoop(net::OK); 207 } 208 209 void CastSocket::PostTaskToStartConnectLoop(int result) { 210 DCHECK(CalledOnValidThread()); 211 DCHECK(connect_loop_callback_.IsCancelled()); 212 connect_loop_callback_.Reset( 213 base::Bind(&CastSocket::DoConnectLoop, base::Unretained(this), result)); 214 base::MessageLoop::current()->PostTask(FROM_HERE, 215 connect_loop_callback_.callback()); 216 } 217 218 void CastSocket::OnConnectTimeout() { 219 DCHECK(CalledOnValidThread()); 220 // Stop all pending connection setup tasks and report back to the client. 221 is_canceled_ = true; 222 logger_->LogSocketEvent(channel_id_, proto::CONNECT_TIMED_OUT); 223 VLOG_WITH_CONNECTION(1) << "Timeout while establishing a connection."; 224 DoConnectCallback(net::ERR_TIMED_OUT); 225 } 226 227 // This method performs the state machine transitions for connection flow. 228 // There are two entry points to this method: 229 // 1. Connect method: this starts the flow 230 // 2. Callback from network operations that finish asynchronously 231 void CastSocket::DoConnectLoop(int result) { 232 connect_loop_callback_.Cancel(); 233 if (is_canceled_) { 234 LOG(ERROR) << "CANCELLED - Aborting DoConnectLoop."; 235 return; 236 } 237 // Network operations can either finish synchronously or asynchronously. 238 // This method executes the state machine transitions in a loop so that 239 // correct state transitions happen even when network operations finish 240 // synchronously. 241 int rv = result; 242 do { 243 proto::ConnectionState state = connect_state_; 244 // Default to CONN_STATE_NONE, which breaks the processing loop if any 245 // handler fails to transition to another state to continue processing. 246 connect_state_ = proto::CONN_STATE_NONE; 247 switch (state) { 248 case proto::CONN_STATE_TCP_CONNECT: 249 rv = DoTcpConnect(); 250 break; 251 case proto::CONN_STATE_TCP_CONNECT_COMPLETE: 252 rv = DoTcpConnectComplete(rv); 253 break; 254 case proto::CONN_STATE_SSL_CONNECT: 255 DCHECK_EQ(net::OK, rv); 256 rv = DoSslConnect(); 257 break; 258 case proto::CONN_STATE_SSL_CONNECT_COMPLETE: 259 rv = DoSslConnectComplete(rv); 260 break; 261 case proto::CONN_STATE_AUTH_CHALLENGE_SEND: 262 rv = DoAuthChallengeSend(); 263 break; 264 case proto::CONN_STATE_AUTH_CHALLENGE_SEND_COMPLETE: 265 rv = DoAuthChallengeSendComplete(rv); 266 break; 267 case proto::CONN_STATE_AUTH_CHALLENGE_REPLY_COMPLETE: 268 rv = DoAuthChallengeReplyComplete(rv); 269 break; 270 default: 271 NOTREACHED() << "BUG in connect flow. Unknown state: " << state; 272 break; 273 } 274 } while (rv != net::ERR_IO_PENDING && 275 connect_state_ != proto::CONN_STATE_NONE); 276 // Get out of the loop either when: // a. A network operation is pending, OR 277 // b. The Do* method called did not change state 278 279 // No state change occurred in do-while loop above. This means state has 280 // transitioned to NONE. 281 if (connect_state_ == proto::CONN_STATE_NONE) { 282 logger_->LogSocketConnectState(channel_id_, connect_state_); 283 } 284 285 // Connect loop is finished: if there is no pending IO invoke the callback. 286 if (rv != net::ERR_IO_PENDING) { 287 GetTimer()->Stop(); 288 DoConnectCallback(rv); 289 } 290 } 291 292 int CastSocket::DoTcpConnect() { 293 DCHECK(connect_loop_callback_.IsCancelled()); 294 VLOG_WITH_CONNECTION(1) << "DoTcpConnect"; 295 SetConnectState(proto::CONN_STATE_TCP_CONNECT_COMPLETE); 296 tcp_socket_ = CreateTcpSocket(); 297 298 int rv = tcp_socket_->Connect( 299 base::Bind(&CastSocket::DoConnectLoop, base::Unretained(this))); 300 logger_->LogSocketEventWithRv(channel_id_, proto::TCP_SOCKET_CONNECT, rv); 301 return rv; 302 } 303 304 int CastSocket::DoTcpConnectComplete(int result) { 305 VLOG_WITH_CONNECTION(1) << "DoTcpConnectComplete: " << result; 306 if (result == net::OK) { 307 // Enable TCP protocol-level keep-alive. 308 bool result = tcp_socket_->SetKeepAlive(true, kTcpKeepAliveDelaySecs); 309 LOG_IF(WARNING, !result) << "Failed to SetKeepAlive."; 310 logger_->LogSocketEventWithRv( 311 channel_id_, proto::TCP_SOCKET_SET_KEEP_ALIVE, result ? 1 : 0); 312 SetConnectState(proto::CONN_STATE_SSL_CONNECT); 313 } 314 return result; 315 } 316 317 int CastSocket::DoSslConnect() { 318 DCHECK(connect_loop_callback_.IsCancelled()); 319 VLOG_WITH_CONNECTION(1) << "DoSslConnect"; 320 SetConnectState(proto::CONN_STATE_SSL_CONNECT_COMPLETE); 321 socket_ = CreateSslSocket(tcp_socket_.PassAs<net::StreamSocket>()); 322 323 int rv = socket_->Connect( 324 base::Bind(&CastSocket::DoConnectLoop, base::Unretained(this))); 325 logger_->LogSocketEventWithRv(channel_id_, proto::SSL_SOCKET_CONNECT, rv); 326 return rv; 327 } 328 329 int CastSocket::DoSslConnectComplete(int result) { 330 VLOG_WITH_CONNECTION(1) << "DoSslConnectComplete: " << result; 331 if (result == net::ERR_CERT_AUTHORITY_INVALID && 332 peer_cert_.empty() && ExtractPeerCert(&peer_cert_)) { 333 SetConnectState(proto::CONN_STATE_TCP_CONNECT); 334 } else if (result == net::OK && 335 channel_auth_ == CHANNEL_AUTH_TYPE_SSL_VERIFIED) { 336 SetConnectState(proto::CONN_STATE_AUTH_CHALLENGE_SEND); 337 } 338 return result; 339 } 340 341 int CastSocket::DoAuthChallengeSend() { 342 VLOG_WITH_CONNECTION(1) << "DoAuthChallengeSend"; 343 SetConnectState(proto::CONN_STATE_AUTH_CHALLENGE_SEND_COMPLETE); 344 345 CastMessage challenge_message; 346 CreateAuthChallengeMessage(&challenge_message); 347 VLOG_WITH_CONNECTION(1) << "Sending challenge: " 348 << CastMessageToString(challenge_message); 349 // Post a task to send auth challenge so that DoWriteLoop is not nested inside 350 // DoConnectLoop. This is not strictly necessary but keeps the write loop 351 // code decoupled from connect loop code. 352 DCHECK(send_auth_challenge_callback_.IsCancelled()); 353 send_auth_challenge_callback_.Reset( 354 base::Bind(&CastSocket::SendCastMessageInternal, 355 base::Unretained(this), 356 challenge_message, 357 base::Bind(&CastSocket::DoAuthChallengeSendWriteComplete, 358 base::Unretained(this)))); 359 base::MessageLoop::current()->PostTask( 360 FROM_HERE, 361 send_auth_challenge_callback_.callback()); 362 // Always return IO_PENDING since the result is always asynchronous. 363 return net::ERR_IO_PENDING; 364 } 365 366 void CastSocket::DoAuthChallengeSendWriteComplete(int result) { 367 send_auth_challenge_callback_.Cancel(); 368 VLOG_WITH_CONNECTION(2) << "DoAuthChallengeSendWriteComplete: " << result; 369 DCHECK_GT(result, 0); 370 DCHECK_EQ(write_queue_.size(), 1UL); 371 PostTaskToStartConnectLoop(result); 372 } 373 374 int CastSocket::DoAuthChallengeSendComplete(int result) { 375 VLOG_WITH_CONNECTION(1) << "DoAuthChallengeSendComplete: " << result; 376 if (result < 0) { 377 return result; 378 } 379 SetConnectState(proto::CONN_STATE_AUTH_CHALLENGE_REPLY_COMPLETE); 380 381 // Post a task to start read loop so that DoReadLoop is not nested inside 382 // DoConnectLoop. This is not strictly necessary but keeps the read loop 383 // code decoupled from connect loop code. 384 PostTaskToStartReadLoop(); 385 // Always return IO_PENDING since the result is always asynchronous. 386 return net::ERR_IO_PENDING; 387 } 388 389 int CastSocket::DoAuthChallengeReplyComplete(int result) { 390 VLOG_WITH_CONNECTION(1) << "DoAuthChallengeReplyComplete: " << result; 391 if (result < 0) { 392 return result; 393 } 394 if (!VerifyChallengeReply()) { 395 return net::ERR_FAILED; 396 } 397 VLOG_WITH_CONNECTION(1) << "Auth challenge verification succeeded"; 398 return net::OK; 399 } 400 401 void CastSocket::DoConnectCallback(int result) { 402 SetReadyState((result == net::OK) ? READY_STATE_OPEN : READY_STATE_CLOSED); 403 if (result == net::OK) { 404 SetErrorState(CHANNEL_ERROR_NONE); 405 PostTaskToStartReadLoop(); 406 VLOG_WITH_CONNECTION(1) << "Calling Connect_Callback"; 407 base::ResetAndReturn(&connect_callback_).Run(result); 408 return; 409 } else if (result == net::ERR_TIMED_OUT) { 410 SetErrorState(CHANNEL_ERROR_CONNECT_TIMEOUT); 411 } else { 412 SetErrorState(CHANNEL_ERROR_CONNECT_ERROR); 413 } 414 // Calls the connect callback. 415 CloseWithError(); 416 } 417 418 void CastSocket::Close(const net::CompletionCallback& callback) { 419 CloseInternal(); 420 RunPendingCallbacksOnClose(); 421 // Run this callback last. It may delete the socket. 422 callback.Run(net::OK); 423 } 424 425 void CastSocket::CloseInternal() { 426 // TODO(mfoltz): Enforce this when CastChannelAPITest is rewritten to create 427 // and free sockets on the same thread. crbug.com/398242 428 // DCHECK(CalledOnValidThread()); 429 if (ready_state_ == READY_STATE_CLOSED) { 430 return; 431 } 432 433 VLOG_WITH_CONNECTION(1) << "Close ReadyState = " << ready_state_; 434 tcp_socket_.reset(); 435 socket_.reset(); 436 cert_verifier_.reset(); 437 transport_security_state_.reset(); 438 GetTimer()->Stop(); 439 440 // Cancel callbacks that we queued ourselves to re-enter the connect or read 441 // loops. 442 connect_loop_callback_.Cancel(); 443 send_auth_challenge_callback_.Cancel(); 444 read_loop_callback_.Cancel(); 445 connect_timeout_callback_.Cancel(); 446 SetReadyState(READY_STATE_CLOSED); 447 logger_->LogSocketEvent(channel_id_, proto::SOCKET_CLOSED); 448 } 449 450 void CastSocket::RunPendingCallbacksOnClose() { 451 DCHECK_EQ(ready_state_, READY_STATE_CLOSED); 452 if (!connect_callback_.is_null()) { 453 connect_callback_.Run(net::ERR_CONNECTION_FAILED); 454 connect_callback_.Reset(); 455 } 456 for (; !write_queue_.empty(); write_queue_.pop()) { 457 net::CompletionCallback& callback = write_queue_.front().callback; 458 callback.Run(net::ERR_FAILED); 459 callback.Reset(); 460 } 461 } 462 463 void CastSocket::SendMessage(const MessageInfo& message, 464 const net::CompletionCallback& callback) { 465 DCHECK(CalledOnValidThread()); 466 if (ready_state_ != READY_STATE_OPEN) { 467 logger_->LogSocketEventForMessage(channel_id_, 468 proto::SEND_MESSAGE_FAILED, 469 message.namespace_, 470 "Ready state not OPEN"); 471 callback.Run(net::ERR_FAILED); 472 return; 473 } 474 CastMessage message_proto; 475 if (!MessageInfoToCastMessage(message, &message_proto)) { 476 logger_->LogSocketEventForMessage(channel_id_, 477 proto::SEND_MESSAGE_FAILED, 478 message.namespace_, 479 "Failed to convert to CastMessage"); 480 callback.Run(net::ERR_FAILED); 481 return; 482 } 483 SendCastMessageInternal(message_proto, callback); 484 } 485 486 void CastSocket::SendCastMessageInternal( 487 const CastMessage& message, 488 const net::CompletionCallback& callback) { 489 WriteRequest write_request(callback); 490 if (!write_request.SetContent(message)) { 491 logger_->LogSocketEventForMessage(channel_id_, 492 proto::SEND_MESSAGE_FAILED, 493 message.namespace_(), 494 "SetContent failed"); 495 callback.Run(net::ERR_FAILED); 496 return; 497 } 498 499 write_queue_.push(write_request); 500 logger_->LogSocketEventForMessage( 501 channel_id_, 502 proto::MESSAGE_ENQUEUED, 503 message.namespace_(), 504 base::StringPrintf("Queue size: %" PRIuS, write_queue_.size())); 505 if (write_state_ == proto::WRITE_STATE_NONE) { 506 SetWriteState(proto::WRITE_STATE_WRITE); 507 DoWriteLoop(net::OK); 508 } 509 } 510 511 void CastSocket::DoWriteLoop(int result) { 512 DCHECK(CalledOnValidThread()); 513 VLOG_WITH_CONNECTION(1) << "DoWriteLoop queue size: " << write_queue_.size(); 514 515 if (write_queue_.empty()) { 516 SetWriteState(proto::WRITE_STATE_NONE); 517 return; 518 } 519 520 // Network operations can either finish synchronously or asynchronously. 521 // This method executes the state machine transitions in a loop so that 522 // write state transitions happen even when network operations finish 523 // synchronously. 524 int rv = result; 525 do { 526 proto::WriteState state = write_state_; 527 write_state_ = proto::WRITE_STATE_NONE; 528 switch (state) { 529 case proto::WRITE_STATE_WRITE: 530 rv = DoWrite(); 531 break; 532 case proto::WRITE_STATE_WRITE_COMPLETE: 533 rv = DoWriteComplete(rv); 534 break; 535 case proto::WRITE_STATE_DO_CALLBACK: 536 rv = DoWriteCallback(); 537 break; 538 case proto::WRITE_STATE_ERROR: 539 rv = DoWriteError(rv); 540 break; 541 default: 542 NOTREACHED() << "BUG in write flow. Unknown state: " << state; 543 break; 544 } 545 } while (!write_queue_.empty() && rv != net::ERR_IO_PENDING && 546 write_state_ != proto::WRITE_STATE_NONE); 547 548 // No state change occurred in do-while loop above. This means state has 549 // transitioned to NONE. 550 if (write_state_ == proto::WRITE_STATE_NONE) { 551 logger_->LogSocketWriteState(channel_id_, write_state_); 552 } 553 554 // If write loop is done because the queue is empty then set write 555 // state to NONE 556 if (write_queue_.empty()) { 557 SetWriteState(proto::WRITE_STATE_NONE); 558 } 559 560 // Write loop is done - if the result is ERR_FAILED then close with error. 561 if (rv == net::ERR_FAILED) { 562 CloseWithError(); 563 } 564 } 565 566 int CastSocket::DoWrite() { 567 DCHECK(!write_queue_.empty()); 568 WriteRequest& request = write_queue_.front(); 569 570 VLOG_WITH_CONNECTION(2) << "WriteData byte_count = " 571 << request.io_buffer->size() << " bytes_written " 572 << request.io_buffer->BytesConsumed(); 573 574 SetWriteState(proto::WRITE_STATE_WRITE_COMPLETE); 575 576 int rv = socket_->Write( 577 request.io_buffer.get(), 578 request.io_buffer->BytesRemaining(), 579 base::Bind(&CastSocket::DoWriteLoop, base::Unretained(this))); 580 logger_->LogSocketEventWithRv(channel_id_, proto::SOCKET_WRITE, rv); 581 582 return rv; 583 } 584 585 int CastSocket::DoWriteComplete(int result) { 586 DCHECK(!write_queue_.empty()); 587 if (result <= 0) { // NOTE that 0 also indicates an error 588 SetErrorState(CHANNEL_ERROR_SOCKET_ERROR); 589 SetWriteState(proto::WRITE_STATE_ERROR); 590 return result == 0 ? net::ERR_FAILED : result; 591 } 592 593 // Some bytes were successfully written 594 WriteRequest& request = write_queue_.front(); 595 scoped_refptr<net::DrainableIOBuffer> io_buffer = request.io_buffer; 596 io_buffer->DidConsume(result); 597 if (io_buffer->BytesRemaining() == 0) { // Message fully sent 598 SetWriteState(proto::WRITE_STATE_DO_CALLBACK); 599 } else { 600 SetWriteState(proto::WRITE_STATE_WRITE); 601 } 602 603 return net::OK; 604 } 605 606 int CastSocket::DoWriteCallback() { 607 DCHECK(!write_queue_.empty()); 608 609 SetWriteState(proto::WRITE_STATE_WRITE); 610 611 WriteRequest& request = write_queue_.front(); 612 int bytes_consumed = request.io_buffer->BytesConsumed(); 613 logger_->LogSocketEventForMessage( 614 channel_id_, 615 proto::MESSAGE_WRITTEN, 616 request.message_namespace, 617 base::StringPrintf("Bytes: %d", bytes_consumed)); 618 request.callback.Run(bytes_consumed); 619 write_queue_.pop(); 620 return net::OK; 621 } 622 623 int CastSocket::DoWriteError(int result) { 624 DCHECK(!write_queue_.empty()); 625 DCHECK_LT(result, 0); 626 627 // If inside connection flow, then there should be exactly one item in 628 // the write queue. 629 if (ready_state_ == READY_STATE_CONNECTING) { 630 write_queue_.pop(); 631 DCHECK(write_queue_.empty()); 632 PostTaskToStartConnectLoop(result); 633 // Connect loop will handle the error. Return net::OK so that write flow 634 // does not try to report error also. 635 return net::OK; 636 } 637 638 while (!write_queue_.empty()) { 639 WriteRequest& request = write_queue_.front(); 640 request.callback.Run(result); 641 write_queue_.pop(); 642 } 643 return net::ERR_FAILED; 644 } 645 646 void CastSocket::PostTaskToStartReadLoop() { 647 DCHECK(CalledOnValidThread()); 648 DCHECK(read_loop_callback_.IsCancelled()); 649 read_loop_callback_.Reset( 650 base::Bind(&CastSocket::StartReadLoop, base::Unretained(this))); 651 base::MessageLoop::current()->PostTask(FROM_HERE, 652 read_loop_callback_.callback()); 653 } 654 655 void CastSocket::StartReadLoop() { 656 read_loop_callback_.Cancel(); 657 // Read loop would have already been started if read state is not NONE 658 if (read_state_ == proto::READ_STATE_NONE) { 659 SetReadState(proto::READ_STATE_READ); 660 DoReadLoop(net::OK); 661 } 662 } 663 664 void CastSocket::DoReadLoop(int result) { 665 DCHECK(CalledOnValidThread()); 666 // Network operations can either finish synchronously or asynchronously. 667 // This method executes the state machine transitions in a loop so that 668 // write state transitions happen even when network operations finish 669 // synchronously. 670 int rv = result; 671 do { 672 proto::ReadState state = read_state_; 673 read_state_ = proto::READ_STATE_NONE; 674 675 switch (state) { 676 case proto::READ_STATE_READ: 677 rv = DoRead(); 678 break; 679 case proto::READ_STATE_READ_COMPLETE: 680 rv = DoReadComplete(rv); 681 break; 682 case proto::READ_STATE_DO_CALLBACK: 683 rv = DoReadCallback(); 684 break; 685 case proto::READ_STATE_ERROR: 686 rv = DoReadError(rv); 687 DCHECK_EQ(read_state_, proto::READ_STATE_NONE); 688 break; 689 default: 690 NOTREACHED() << "BUG in read flow. Unknown state: " << state; 691 break; 692 } 693 } while (rv != net::ERR_IO_PENDING && read_state_ != proto::READ_STATE_NONE); 694 695 // No state change occurred in do-while loop above. This means state has 696 // transitioned to NONE. 697 if (read_state_ == proto::READ_STATE_NONE) { 698 logger_->LogSocketReadState(channel_id_, read_state_); 699 } 700 701 if (rv == net::ERR_FAILED) { 702 if (ready_state_ == READY_STATE_CONNECTING) { 703 // Read errors during the handshake should notify the caller via the 704 // connect callback. This will also send error status via the OnError 705 // delegate. 706 PostTaskToStartConnectLoop(net::ERR_FAILED); 707 } else { 708 // Connection is already established. Close and send error status via the 709 // OnError delegate. 710 CloseWithError(); 711 } 712 } 713 } 714 715 int CastSocket::DoRead() { 716 SetReadState(proto::READ_STATE_READ_COMPLETE); 717 718 // Determine how many bytes need to be read. 719 size_t num_bytes_to_read = framer_->BytesRequested(); 720 721 // Read up to num_bytes_to_read into |current_read_buffer_|. 722 int rv = socket_->Read( 723 read_buffer_.get(), 724 base::checked_cast<uint32>(num_bytes_to_read), 725 base::Bind(&CastSocket::DoReadLoop, base::Unretained(this))); 726 logger_->LogSocketEventWithRv(channel_id_, proto::SOCKET_READ, rv); 727 728 return rv; 729 } 730 731 int CastSocket::DoReadComplete(int result) { 732 VLOG_WITH_CONNECTION(2) << "DoReadComplete result = " << result; 733 734 if (result <= 0) { // 0 means EOF: the peer closed the socket 735 VLOG_WITH_CONNECTION(1) << "Read error, peer closed the socket"; 736 SetErrorState(CHANNEL_ERROR_SOCKET_ERROR); 737 SetReadState(proto::READ_STATE_ERROR); 738 return result == 0 ? net::ERR_FAILED : result; 739 } 740 741 size_t message_size; 742 DCHECK(current_message_.get() == NULL); 743 current_message_ = framer_->Ingest(result, &message_size, &error_state_); 744 if (current_message_.get()) { 745 DCHECK_EQ(error_state_, CHANNEL_ERROR_NONE); 746 DCHECK_GT(message_size, static_cast<size_t>(0)); 747 logger_->LogSocketEventForMessage( 748 channel_id_, 749 proto::MESSAGE_READ, 750 current_message_->namespace_(), 751 base::StringPrintf("Message size: %u", 752 static_cast<uint32>(message_size))); 753 SetReadState(proto::READ_STATE_DO_CALLBACK); 754 } else if (error_state_ != CHANNEL_ERROR_NONE) { 755 DCHECK(current_message_.get() == NULL); 756 SetReadState(proto::READ_STATE_ERROR); 757 } else { 758 DCHECK(current_message_.get() == NULL); 759 SetReadState(proto::READ_STATE_READ); 760 } 761 return net::OK; 762 } 763 764 int CastSocket::DoReadCallback() { 765 SetReadState(proto::READ_STATE_READ); 766 const CastMessage& message = *current_message_; 767 if (ready_state_ == READY_STATE_CONNECTING) { 768 if (IsAuthMessage(message)) { 769 challenge_reply_.reset(new CastMessage(message)); 770 logger_->LogSocketEvent(channel_id_, proto::RECEIVED_CHALLENGE_REPLY); 771 PostTaskToStartConnectLoop(net::OK); 772 current_message_.reset(); 773 return net::OK; 774 } else { 775 SetReadState(proto::READ_STATE_ERROR); 776 SetErrorState(CHANNEL_ERROR_INVALID_MESSAGE); 777 current_message_.reset(); 778 return net::ERR_INVALID_RESPONSE; 779 } 780 } 781 782 MessageInfo message_info; 783 if (!CastMessageToMessageInfo(message, &message_info)) { 784 current_message_.reset(); 785 SetReadState(proto::READ_STATE_ERROR); 786 SetErrorState(CHANNEL_ERROR_INVALID_MESSAGE); 787 return net::ERR_INVALID_RESPONSE; 788 } 789 790 logger_->LogSocketEventForMessage(channel_id_, 791 proto::NOTIFY_ON_MESSAGE, 792 message.namespace_(), 793 std::string()); 794 delegate_->OnMessage(this, message_info); 795 current_message_.reset(); 796 797 return net::OK; 798 } 799 800 int CastSocket::DoReadError(int result) { 801 DCHECK_LE(result, 0); 802 return net::ERR_FAILED; 803 } 804 805 void CastSocket::CloseWithError() { 806 DCHECK(CalledOnValidThread()); 807 CloseInternal(); 808 RunPendingCallbacksOnClose(); 809 if (delegate_) { 810 logger_->LogSocketEvent(channel_id_, proto::NOTIFY_ON_ERROR); 811 delegate_->OnError(this, error_state_, logger_->GetLastErrors(channel_id_)); 812 } 813 } 814 815 std::string CastSocket::CastUrl() const { 816 return ((channel_auth_ == CHANNEL_AUTH_TYPE_SSL_VERIFIED) ? 817 "casts://" : "cast://") + ip_endpoint_.ToString(); 818 } 819 820 bool CastSocket::CalledOnValidThread() const { 821 return thread_checker_.CalledOnValidThread(); 822 } 823 824 base::Timer* CastSocket::GetTimer() { 825 return connect_timeout_timer_.get(); 826 } 827 828 void CastSocket::SetConnectState(proto::ConnectionState connect_state) { 829 if (connect_state_ != connect_state) { 830 connect_state_ = connect_state; 831 logger_->LogSocketConnectState(channel_id_, connect_state_); 832 } 833 } 834 835 void CastSocket::SetReadyState(ReadyState ready_state) { 836 if (ready_state_ != ready_state) { 837 ready_state_ = ready_state; 838 logger_->LogSocketReadyState(channel_id_, ReadyStateToProto(ready_state_)); 839 } 840 } 841 842 void CastSocket::SetErrorState(ChannelError error_state) { 843 if (error_state_ != error_state) { 844 error_state_ = error_state; 845 logger_->LogSocketErrorState(channel_id_, ErrorStateToProto(error_state_)); 846 } 847 } 848 849 void CastSocket::SetReadState(proto::ReadState read_state) { 850 if (read_state_ != read_state) { 851 read_state_ = read_state; 852 logger_->LogSocketReadState(channel_id_, read_state_); 853 } 854 } 855 856 void CastSocket::SetWriteState(proto::WriteState write_state) { 857 if (write_state_ != write_state) { 858 write_state_ = write_state; 859 logger_->LogSocketWriteState(channel_id_, write_state_); 860 } 861 } 862 863 CastSocket::WriteRequest::WriteRequest(const net::CompletionCallback& callback) 864 : callback(callback) { 865 } 866 867 bool CastSocket::WriteRequest::SetContent(const CastMessage& message_proto) { 868 DCHECK(!io_buffer.get()); 869 std::string message_data; 870 if (!MessageFramer::Serialize(message_proto, &message_data)) { 871 return false; 872 } 873 message_namespace = message_proto.namespace_(); 874 io_buffer = new net::DrainableIOBuffer(new net::StringIOBuffer(message_data), 875 message_data.size()); 876 return true; 877 } 878 879 CastSocket::WriteRequest::~WriteRequest() { 880 } 881 882 } // namespace cast_channel 883 } // namespace core_api 884 } // namespace extensions 885 886 #undef VLOG_WITH_CONNECTION 887