1 // Copyright (c) 2013 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "chrome/test/chromedriver/net/adb_client_socket.h" 6 7 #include "base/bind.h" 8 #include "base/compiler_specific.h" 9 #include "base/strings/string_number_conversions.h" 10 #include "base/strings/string_util.h" 11 #include "base/strings/stringprintf.h" 12 #include "net/base/address_list.h" 13 #include "net/base/completion_callback.h" 14 #include "net/base/net_errors.h" 15 #include "net/base/net_util.h" 16 #include "net/socket/tcp_client_socket.h" 17 18 namespace { 19 20 const int kBufferSize = 16 * 1024; 21 const int kResponseBufferSize = 16; 22 const char kOkayResponse[] = "OKAY"; 23 const char kHostTransportCommand[] = "host:transport:%s"; 24 const char kLocalAbstractCommand[] = "localabstract:%s"; 25 const char kLocalhost[] = "127.0.0.1"; 26 27 typedef base::Callback<void(int, const std::string&)> CommandCallback; 28 typedef base::Callback<void(int, net::StreamSocket*)> SocketCallback; 29 30 std::string EncodeMessage(const std::string& message) { 31 static const char kHexChars[] = "0123456789ABCDEF"; 32 33 size_t length = message.length(); 34 std::string result(4, '\0'); 35 char b = reinterpret_cast<const char*>(&length)[1]; 36 result[0] = kHexChars[(b >> 4) & 0xf]; 37 result[1] = kHexChars[b & 0xf]; 38 b = reinterpret_cast<const char*>(&length)[0]; 39 result[2] = kHexChars[(b >> 4) & 0xf]; 40 result[3] = kHexChars[b & 0xf]; 41 return result + message; 42 } 43 44 class AdbTransportSocket : public AdbClientSocket { 45 public: 46 AdbTransportSocket(int port, 47 const std::string& serial, 48 const std::string& socket_name, 49 const SocketCallback& callback) 50 : AdbClientSocket(port), 51 serial_(serial), 52 socket_name_(socket_name), 53 callback_(callback) { 54 Connect(base::Bind(&AdbTransportSocket::OnConnected, 55 base::Unretained(this))); 56 } 57 58 private: 59 ~AdbTransportSocket() {} 60 61 void OnConnected(int result) { 62 if (!CheckNetResultOrDie(result)) 63 return; 64 SendCommand(base::StringPrintf(kHostTransportCommand, serial_.c_str()), 65 true, base::Bind(&AdbTransportSocket::SendLocalAbstract, 66 base::Unretained(this))); 67 } 68 69 void SendLocalAbstract(int result, const std::string& response) { 70 if (!CheckNetResultOrDie(result)) 71 return; 72 SendCommand(base::StringPrintf(kLocalAbstractCommand, socket_name_.c_str()), 73 true, base::Bind(&AdbTransportSocket::OnSocketAvailable, 74 base::Unretained(this))); 75 } 76 77 void OnSocketAvailable(int result, const std::string& response) { 78 if (!CheckNetResultOrDie(result)) 79 return; 80 callback_.Run(net::OK, socket_.release()); 81 delete this; 82 } 83 84 bool CheckNetResultOrDie(int result) { 85 if (result >= 0) 86 return true; 87 callback_.Run(result, NULL); 88 delete this; 89 return false; 90 } 91 92 std::string serial_; 93 std::string socket_name_; 94 SocketCallback callback_; 95 }; 96 97 class HttpOverAdbSocket { 98 public: 99 HttpOverAdbSocket(int port, 100 const std::string& serial, 101 const std::string& socket_name, 102 const std::string& request, 103 const CommandCallback& callback) 104 : request_(request), 105 command_callback_(callback), 106 body_pos_(0) { 107 Connect(port, serial, socket_name); 108 } 109 110 HttpOverAdbSocket(int port, 111 const std::string& serial, 112 const std::string& socket_name, 113 const std::string& request, 114 const SocketCallback& callback) 115 : request_(request), 116 socket_callback_(callback), 117 body_pos_(0) { 118 Connect(port, serial, socket_name); 119 } 120 121 private: 122 ~HttpOverAdbSocket() { 123 } 124 125 void Connect(int port, 126 const std::string& serial, 127 const std::string& socket_name) { 128 AdbClientSocket::TransportQuery( 129 port, serial, socket_name, 130 base::Bind(&HttpOverAdbSocket::OnSocketAvailable, 131 base::Unretained(this))); 132 } 133 134 void OnSocketAvailable(int result, 135 net::StreamSocket* socket) { 136 if (!CheckNetResultOrDie(result)) 137 return; 138 139 socket_.reset(socket); 140 141 scoped_refptr<net::StringIOBuffer> request_buffer = 142 new net::StringIOBuffer(request_); 143 144 result = socket_->Write( 145 request_buffer.get(), 146 request_buffer->size(), 147 base::Bind(&HttpOverAdbSocket::ReadResponse, base::Unretained(this))); 148 if (result != net::ERR_IO_PENDING) 149 ReadResponse(result); 150 } 151 152 void ReadResponse(int result) { 153 if (!CheckNetResultOrDie(result)) 154 return; 155 156 scoped_refptr<net::IOBuffer> response_buffer = 157 new net::IOBuffer(kBufferSize); 158 159 result = socket_->Read(response_buffer.get(), 160 kBufferSize, 161 base::Bind(&HttpOverAdbSocket::OnResponseData, 162 base::Unretained(this), 163 response_buffer, 164 -1)); 165 if (result != net::ERR_IO_PENDING) 166 OnResponseData(response_buffer, -1, result); 167 } 168 169 void OnResponseData(scoped_refptr<net::IOBuffer> response_buffer, 170 int bytes_total, 171 int result) { 172 if (!CheckNetResultOrDie(result)) 173 return; 174 if (result == 0) { 175 CheckNetResultOrDie(net::ERR_CONNECTION_CLOSED); 176 return; 177 } 178 179 response_ += std::string(response_buffer->data(), result); 180 int expected_length = 0; 181 if (bytes_total < 0) { 182 size_t content_pos = response_.find("Content-Length:"); 183 if (content_pos != std::string::npos) { 184 size_t endline_pos = response_.find("\n", content_pos); 185 if (endline_pos != std::string::npos) { 186 std::string len = response_.substr(content_pos + 15, 187 endline_pos - content_pos - 15); 188 TrimWhitespace(len, TRIM_ALL, &len); 189 if (!base::StringToInt(len, &expected_length)) { 190 CheckNetResultOrDie(net::ERR_FAILED); 191 return; 192 } 193 } 194 } 195 196 body_pos_ = response_.find("\r\n\r\n"); 197 if (body_pos_ != std::string::npos) { 198 body_pos_ += 4; 199 bytes_total = body_pos_ + expected_length; 200 } 201 } 202 203 if (bytes_total == static_cast<int>(response_.length())) { 204 if (!command_callback_.is_null()) 205 command_callback_.Run(body_pos_, response_); 206 else 207 socket_callback_.Run(net::OK, socket_.release()); 208 delete this; 209 return; 210 } 211 212 result = socket_->Read(response_buffer.get(), 213 kBufferSize, 214 base::Bind(&HttpOverAdbSocket::OnResponseData, 215 base::Unretained(this), 216 response_buffer, 217 bytes_total)); 218 if (result != net::ERR_IO_PENDING) 219 OnResponseData(response_buffer, bytes_total, result); 220 } 221 222 bool CheckNetResultOrDie(int result) { 223 if (result >= 0) 224 return true; 225 if (!command_callback_.is_null()) 226 command_callback_.Run(result, std::string()); 227 else 228 socket_callback_.Run(result, NULL); 229 delete this; 230 return false; 231 } 232 233 scoped_ptr<net::StreamSocket> socket_; 234 std::string request_; 235 std::string response_; 236 CommandCallback command_callback_; 237 SocketCallback socket_callback_; 238 size_t body_pos_; 239 }; 240 241 class AdbQuerySocket : AdbClientSocket { 242 public: 243 AdbQuerySocket(int port, 244 const std::string& query, 245 const CommandCallback& callback) 246 : AdbClientSocket(port), 247 current_query_(0), 248 callback_(callback) { 249 if (Tokenize(query, "|", &queries_) == 0) { 250 CheckNetResultOrDie(net::ERR_INVALID_ARGUMENT); 251 return; 252 } 253 Connect(base::Bind(&AdbQuerySocket::SendNextQuery, 254 base::Unretained(this))); 255 } 256 257 private: 258 ~AdbQuerySocket() { 259 } 260 261 void SendNextQuery(int result) { 262 if (!CheckNetResultOrDie(result)) 263 return; 264 std::string query = queries_[current_query_]; 265 if (query.length() > 0xFFFF) { 266 CheckNetResultOrDie(net::ERR_MSG_TOO_BIG); 267 return; 268 } 269 bool is_void = current_query_ < queries_.size() - 1; 270 SendCommand(query, is_void, 271 base::Bind(&AdbQuerySocket::OnResponse, base::Unretained(this))); 272 } 273 274 void OnResponse(int result, const std::string& response) { 275 if (++current_query_ < queries_.size()) { 276 SendNextQuery(net::OK); 277 } else { 278 callback_.Run(result, response); 279 delete this; 280 } 281 } 282 283 bool CheckNetResultOrDie(int result) { 284 if (result >= 0) 285 return true; 286 callback_.Run(result, std::string()); 287 delete this; 288 return false; 289 } 290 291 std::vector<std::string> queries_; 292 size_t current_query_; 293 CommandCallback callback_; 294 }; 295 296 } // namespace 297 298 // static 299 void AdbClientSocket::AdbQuery(int port, 300 const std::string& query, 301 const CommandCallback& callback) { 302 new AdbQuerySocket(port, query, callback); 303 } 304 305 #if defined(DEBUG_DEVTOOLS) 306 static void UseTransportQueryForDesktop(const SocketCallback& callback, 307 net::StreamSocket* socket, 308 int result) { 309 callback.Run(result, socket); 310 } 311 #endif // defined(DEBUG_DEVTOOLS) 312 313 // static 314 void AdbClientSocket::TransportQuery(int port, 315 const std::string& serial, 316 const std::string& socket_name, 317 const SocketCallback& callback) { 318 #if defined(DEBUG_DEVTOOLS) 319 if (serial.empty()) { 320 // Use plain socket for remote debugging on Desktop (debugging purposes). 321 net::IPAddressNumber ip_number; 322 net::ParseIPLiteralToNumber(kLocalhost, &ip_number); 323 324 int tcp_port = 0; 325 if (!base::StringToInt(socket_name, &tcp_port)) 326 tcp_port = 9222; 327 328 net::AddressList address_list = 329 net::AddressList::CreateFromIPAddress(ip_number, tcp_port); 330 net::TCPClientSocket* socket = new net::TCPClientSocket( 331 address_list, NULL, net::NetLog::Source()); 332 socket->Connect(base::Bind(&UseTransportQueryForDesktop, callback, socket)); 333 return; 334 } 335 #endif // defined(DEBUG_DEVTOOLS) 336 new AdbTransportSocket(port, serial, socket_name, callback); 337 } 338 339 // static 340 void AdbClientSocket::HttpQuery(int port, 341 const std::string& serial, 342 const std::string& socket_name, 343 const std::string& request_path, 344 const CommandCallback& callback) { 345 new HttpOverAdbSocket(port, serial, socket_name, request_path, 346 callback); 347 } 348 349 // static 350 void AdbClientSocket::HttpQuery(int port, 351 const std::string& serial, 352 const std::string& socket_name, 353 const std::string& request_path, 354 const SocketCallback& callback) { 355 new HttpOverAdbSocket(port, serial, socket_name, request_path, 356 callback); 357 } 358 359 AdbClientSocket::AdbClientSocket(int port) 360 : host_(kLocalhost), port_(port) { 361 } 362 363 AdbClientSocket::~AdbClientSocket() { 364 } 365 366 void AdbClientSocket::Connect(const net::CompletionCallback& callback) { 367 net::IPAddressNumber ip_number; 368 if (!net::ParseIPLiteralToNumber(host_, &ip_number)) { 369 callback.Run(net::ERR_FAILED); 370 return; 371 } 372 373 net::AddressList address_list = 374 net::AddressList::CreateFromIPAddress(ip_number, port_); 375 socket_.reset(new net::TCPClientSocket(address_list, NULL, 376 net::NetLog::Source())); 377 int result = socket_->Connect(callback); 378 if (result != net::ERR_IO_PENDING) 379 callback.Run(result); 380 } 381 382 void AdbClientSocket::SendCommand(const std::string& command, 383 bool is_void, 384 const CommandCallback& callback) { 385 scoped_refptr<net::StringIOBuffer> request_buffer = 386 new net::StringIOBuffer(EncodeMessage(command)); 387 int result = socket_->Write(request_buffer.get(), 388 request_buffer->size(), 389 base::Bind(&AdbClientSocket::ReadResponse, 390 base::Unretained(this), 391 callback, 392 is_void)); 393 if (result != net::ERR_IO_PENDING) 394 ReadResponse(callback, is_void, result); 395 } 396 397 void AdbClientSocket::ReadResponse(const CommandCallback& callback, 398 bool is_void, 399 int result) { 400 if (result < 0) { 401 callback.Run(result, "IO error"); 402 return; 403 } 404 scoped_refptr<net::IOBuffer> response_buffer = 405 new net::IOBuffer(kBufferSize); 406 result = socket_->Read(response_buffer.get(), 407 kBufferSize, 408 base::Bind(&AdbClientSocket::OnResponseHeader, 409 base::Unretained(this), 410 callback, 411 is_void, 412 response_buffer)); 413 if (result != net::ERR_IO_PENDING) 414 OnResponseHeader(callback, is_void, response_buffer, result); 415 } 416 417 void AdbClientSocket::OnResponseHeader( 418 const CommandCallback& callback, 419 bool is_void, 420 scoped_refptr<net::IOBuffer> response_buffer, 421 int result) { 422 if (result <= 0) { 423 callback.Run(result == 0 ? net::ERR_CONNECTION_CLOSED : result, 424 "IO error"); 425 return; 426 } 427 428 std::string data = std::string(response_buffer->data(), result); 429 if (result < 4) { 430 callback.Run(net::ERR_FAILED, "Response is too short: " + data); 431 return; 432 } 433 434 std::string status = data.substr(0, 4); 435 if (status != kOkayResponse) { 436 callback.Run(net::ERR_FAILED, data); 437 return; 438 } 439 440 data = data.substr(4); 441 442 if (!is_void) { 443 int payload_length = 0; 444 int bytes_left = -1; 445 if (data.length() >= 4 && 446 base::HexStringToInt(data.substr(0, 4), &payload_length)) { 447 data = data.substr(4); 448 bytes_left = payload_length - result + 8; 449 } else { 450 bytes_left = -1; 451 } 452 OnResponseData(callback, data, response_buffer, bytes_left, 0); 453 } else { 454 callback.Run(net::OK, data); 455 } 456 } 457 458 void AdbClientSocket::OnResponseData( 459 const CommandCallback& callback, 460 const std::string& response, 461 scoped_refptr<net::IOBuffer> response_buffer, 462 int bytes_left, 463 int result) { 464 if (result < 0) { 465 callback.Run(result, "IO error"); 466 return; 467 } 468 469 bytes_left -= result; 470 std::string new_response = 471 response + std::string(response_buffer->data(), result); 472 if (bytes_left == 0) { 473 callback.Run(net::OK, new_response); 474 return; 475 } 476 477 // Read tail 478 result = socket_->Read(response_buffer.get(), 479 kBufferSize, 480 base::Bind(&AdbClientSocket::OnResponseData, 481 base::Unretained(this), 482 callback, 483 new_response, 484 response_buffer, 485 bytes_left)); 486 if (result > 0) 487 OnResponseData(callback, new_response, response_buffer, bytes_left, result); 488 else if (result != net::ERR_IO_PENDING) 489 callback.Run(net::OK, new_response); 490 } 491