1 // Copyright (c) 2013 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "chrome/test/chromedriver/net/adb_client_socket.h" 6 7 #include "base/bind.h" 8 #include "base/compiler_specific.h" 9 #include "base/strings/string_number_conversions.h" 10 #include "base/strings/string_util.h" 11 #include "base/strings/stringprintf.h" 12 #include "net/base/address_list.h" 13 #include "net/base/completion_callback.h" 14 #include "net/base/net_errors.h" 15 #include "net/base/net_util.h" 16 #include "net/socket/tcp_client_socket.h" 17 18 namespace { 19 20 const int kBufferSize = 16 * 1024; 21 const char kOkayResponse[] = "OKAY"; 22 const char kHostTransportCommand[] = "host:transport:%s"; 23 const char kLocalAbstractCommand[] = "localabstract:%s"; 24 const char kLocalhost[] = "127.0.0.1"; 25 26 typedef base::Callback<void(int, const std::string&)> CommandCallback; 27 typedef base::Callback<void(int, net::StreamSocket*)> SocketCallback; 28 29 std::string EncodeMessage(const std::string& message) { 30 static const char kHexChars[] = "0123456789ABCDEF"; 31 32 size_t length = message.length(); 33 std::string result(4, '\0'); 34 char b = reinterpret_cast<const char*>(&length)[1]; 35 result[0] = kHexChars[(b >> 4) & 0xf]; 36 result[1] = kHexChars[b & 0xf]; 37 b = reinterpret_cast<const char*>(&length)[0]; 38 result[2] = kHexChars[(b >> 4) & 0xf]; 39 result[3] = kHexChars[b & 0xf]; 40 return result + message; 41 } 42 43 class AdbTransportSocket : public AdbClientSocket { 44 public: 45 AdbTransportSocket(int port, 46 const std::string& serial, 47 const std::string& socket_name, 48 const SocketCallback& callback) 49 : AdbClientSocket(port), 50 serial_(serial), 51 socket_name_(socket_name), 52 callback_(callback) { 53 Connect(base::Bind(&AdbTransportSocket::OnConnected, 54 base::Unretained(this))); 55 } 56 57 private: 58 ~AdbTransportSocket() {} 59 60 void OnConnected(int result) { 61 if (!CheckNetResultOrDie(result)) 62 return; 63 SendCommand(base::StringPrintf(kHostTransportCommand, serial_.c_str()), 64 true, base::Bind(&AdbTransportSocket::SendLocalAbstract, 65 base::Unretained(this))); 66 } 67 68 void SendLocalAbstract(int result, const std::string& response) { 69 if (!CheckNetResultOrDie(result)) 70 return; 71 SendCommand(base::StringPrintf(kLocalAbstractCommand, socket_name_.c_str()), 72 true, base::Bind(&AdbTransportSocket::OnSocketAvailable, 73 base::Unretained(this))); 74 } 75 76 void OnSocketAvailable(int result, const std::string& response) { 77 if (!CheckNetResultOrDie(result)) 78 return; 79 callback_.Run(net::OK, socket_.release()); 80 delete this; 81 } 82 83 bool CheckNetResultOrDie(int result) { 84 if (result >= 0) 85 return true; 86 callback_.Run(result, NULL); 87 delete this; 88 return false; 89 } 90 91 std::string serial_; 92 std::string socket_name_; 93 SocketCallback callback_; 94 }; 95 96 class HttpOverAdbSocket { 97 public: 98 HttpOverAdbSocket(int port, 99 const std::string& serial, 100 const std::string& socket_name, 101 const std::string& request, 102 const CommandCallback& callback) 103 : request_(request), 104 command_callback_(callback), 105 body_pos_(0) { 106 Connect(port, serial, socket_name); 107 } 108 109 HttpOverAdbSocket(int port, 110 const std::string& serial, 111 const std::string& socket_name, 112 const std::string& request, 113 const SocketCallback& callback) 114 : request_(request), 115 socket_callback_(callback), 116 body_pos_(0) { 117 Connect(port, serial, socket_name); 118 } 119 120 private: 121 ~HttpOverAdbSocket() { 122 } 123 124 void Connect(int port, 125 const std::string& serial, 126 const std::string& socket_name) { 127 AdbClientSocket::TransportQuery( 128 port, serial, socket_name, 129 base::Bind(&HttpOverAdbSocket::OnSocketAvailable, 130 base::Unretained(this))); 131 } 132 133 void OnSocketAvailable(int result, 134 net::StreamSocket* socket) { 135 if (!CheckNetResultOrDie(result)) 136 return; 137 138 socket_.reset(socket); 139 140 scoped_refptr<net::StringIOBuffer> request_buffer = 141 new net::StringIOBuffer(request_); 142 143 result = socket_->Write( 144 request_buffer.get(), 145 request_buffer->size(), 146 base::Bind(&HttpOverAdbSocket::ReadResponse, base::Unretained(this))); 147 if (result != net::ERR_IO_PENDING) 148 ReadResponse(result); 149 } 150 151 void ReadResponse(int result) { 152 if (!CheckNetResultOrDie(result)) 153 return; 154 155 scoped_refptr<net::IOBuffer> response_buffer = 156 new net::IOBuffer(kBufferSize); 157 158 result = socket_->Read(response_buffer.get(), 159 kBufferSize, 160 base::Bind(&HttpOverAdbSocket::OnResponseData, 161 base::Unretained(this), 162 response_buffer, 163 -1)); 164 if (result != net::ERR_IO_PENDING) 165 OnResponseData(response_buffer, -1, result); 166 } 167 168 void OnResponseData(scoped_refptr<net::IOBuffer> response_buffer, 169 int bytes_total, 170 int result) { 171 if (!CheckNetResultOrDie(result)) 172 return; 173 if (result == 0) { 174 CheckNetResultOrDie(net::ERR_CONNECTION_CLOSED); 175 return; 176 } 177 178 response_ += std::string(response_buffer->data(), result); 179 int expected_length = 0; 180 if (bytes_total < 0) { 181 size_t content_pos = response_.find("Content-Length:"); 182 if (content_pos != std::string::npos) { 183 size_t endline_pos = response_.find("\n", content_pos); 184 if (endline_pos != std::string::npos) { 185 std::string len = response_.substr(content_pos + 15, 186 endline_pos - content_pos - 15); 187 TrimWhitespace(len, TRIM_ALL, &len); 188 if (!base::StringToInt(len, &expected_length)) { 189 CheckNetResultOrDie(net::ERR_FAILED); 190 return; 191 } 192 } 193 } 194 195 body_pos_ = response_.find("\r\n\r\n"); 196 if (body_pos_ != std::string::npos) { 197 body_pos_ += 4; 198 bytes_total = body_pos_ + expected_length; 199 } 200 } 201 202 if (bytes_total == static_cast<int>(response_.length())) { 203 if (!command_callback_.is_null()) 204 command_callback_.Run(body_pos_, response_); 205 else 206 socket_callback_.Run(net::OK, socket_.release()); 207 delete this; 208 return; 209 } 210 211 result = socket_->Read(response_buffer.get(), 212 kBufferSize, 213 base::Bind(&HttpOverAdbSocket::OnResponseData, 214 base::Unretained(this), 215 response_buffer, 216 bytes_total)); 217 if (result != net::ERR_IO_PENDING) 218 OnResponseData(response_buffer, bytes_total, result); 219 } 220 221 bool CheckNetResultOrDie(int result) { 222 if (result >= 0) 223 return true; 224 if (!command_callback_.is_null()) 225 command_callback_.Run(result, std::string()); 226 else 227 socket_callback_.Run(result, NULL); 228 delete this; 229 return false; 230 } 231 232 scoped_ptr<net::StreamSocket> socket_; 233 std::string request_; 234 std::string response_; 235 CommandCallback command_callback_; 236 SocketCallback socket_callback_; 237 size_t body_pos_; 238 }; 239 240 class AdbQuerySocket : AdbClientSocket { 241 public: 242 AdbQuerySocket(int port, 243 const std::string& query, 244 const CommandCallback& callback) 245 : AdbClientSocket(port), 246 current_query_(0), 247 callback_(callback) { 248 if (Tokenize(query, "|", &queries_) == 0) { 249 CheckNetResultOrDie(net::ERR_INVALID_ARGUMENT); 250 return; 251 } 252 Connect(base::Bind(&AdbQuerySocket::SendNextQuery, 253 base::Unretained(this))); 254 } 255 256 private: 257 ~AdbQuerySocket() { 258 } 259 260 void SendNextQuery(int result) { 261 if (!CheckNetResultOrDie(result)) 262 return; 263 std::string query = queries_[current_query_]; 264 if (query.length() > 0xFFFF) { 265 CheckNetResultOrDie(net::ERR_MSG_TOO_BIG); 266 return; 267 } 268 bool is_void = current_query_ < queries_.size() - 1; 269 SendCommand(query, is_void, 270 base::Bind(&AdbQuerySocket::OnResponse, base::Unretained(this))); 271 } 272 273 void OnResponse(int result, const std::string& response) { 274 if (++current_query_ < queries_.size()) { 275 SendNextQuery(net::OK); 276 } else { 277 callback_.Run(result, response); 278 delete this; 279 } 280 } 281 282 bool CheckNetResultOrDie(int result) { 283 if (result >= 0) 284 return true; 285 callback_.Run(result, std::string()); 286 delete this; 287 return false; 288 } 289 290 std::vector<std::string> queries_; 291 size_t current_query_; 292 CommandCallback callback_; 293 }; 294 295 } // namespace 296 297 // static 298 void AdbClientSocket::AdbQuery(int port, 299 const std::string& query, 300 const CommandCallback& callback) { 301 new AdbQuerySocket(port, query, callback); 302 } 303 304 #if defined(DEBUG_DEVTOOLS) 305 static void UseTransportQueryForDesktop(const SocketCallback& callback, 306 net::StreamSocket* socket, 307 int result) { 308 callback.Run(result, socket); 309 } 310 #endif // defined(DEBUG_DEVTOOLS) 311 312 // static 313 void AdbClientSocket::TransportQuery(int port, 314 const std::string& serial, 315 const std::string& socket_name, 316 const SocketCallback& callback) { 317 #if defined(DEBUG_DEVTOOLS) 318 if (serial.empty()) { 319 // Use plain socket for remote debugging on Desktop (debugging purposes). 320 net::IPAddressNumber ip_number; 321 net::ParseIPLiteralToNumber(kLocalhost, &ip_number); 322 323 int tcp_port = 0; 324 if (!base::StringToInt(socket_name, &tcp_port)) 325 tcp_port = 9222; 326 327 net::AddressList address_list = 328 net::AddressList::CreateFromIPAddress(ip_number, tcp_port); 329 net::TCPClientSocket* socket = new net::TCPClientSocket( 330 address_list, NULL, net::NetLog::Source()); 331 socket->Connect(base::Bind(&UseTransportQueryForDesktop, callback, socket)); 332 return; 333 } 334 #endif // defined(DEBUG_DEVTOOLS) 335 new AdbTransportSocket(port, serial, socket_name, callback); 336 } 337 338 // static 339 void AdbClientSocket::HttpQuery(int port, 340 const std::string& serial, 341 const std::string& socket_name, 342 const std::string& request_path, 343 const CommandCallback& callback) { 344 new HttpOverAdbSocket(port, serial, socket_name, request_path, 345 callback); 346 } 347 348 // static 349 void AdbClientSocket::HttpQuery(int port, 350 const std::string& serial, 351 const std::string& socket_name, 352 const std::string& request_path, 353 const SocketCallback& callback) { 354 new HttpOverAdbSocket(port, serial, socket_name, request_path, 355 callback); 356 } 357 358 AdbClientSocket::AdbClientSocket(int port) 359 : host_(kLocalhost), port_(port) { 360 } 361 362 AdbClientSocket::~AdbClientSocket() { 363 } 364 365 void AdbClientSocket::Connect(const net::CompletionCallback& callback) { 366 net::IPAddressNumber ip_number; 367 if (!net::ParseIPLiteralToNumber(host_, &ip_number)) { 368 callback.Run(net::ERR_FAILED); 369 return; 370 } 371 372 net::AddressList address_list = 373 net::AddressList::CreateFromIPAddress(ip_number, port_); 374 socket_.reset(new net::TCPClientSocket(address_list, NULL, 375 net::NetLog::Source())); 376 int result = socket_->Connect(callback); 377 if (result != net::ERR_IO_PENDING) 378 callback.Run(result); 379 } 380 381 void AdbClientSocket::SendCommand(const std::string& command, 382 bool is_void, 383 const CommandCallback& callback) { 384 scoped_refptr<net::StringIOBuffer> request_buffer = 385 new net::StringIOBuffer(EncodeMessage(command)); 386 int result = socket_->Write(request_buffer.get(), 387 request_buffer->size(), 388 base::Bind(&AdbClientSocket::ReadResponse, 389 base::Unretained(this), 390 callback, 391 is_void)); 392 if (result != net::ERR_IO_PENDING) 393 ReadResponse(callback, is_void, result); 394 } 395 396 void AdbClientSocket::ReadResponse(const CommandCallback& callback, 397 bool is_void, 398 int result) { 399 if (result < 0) { 400 callback.Run(result, "IO error"); 401 return; 402 } 403 scoped_refptr<net::IOBuffer> response_buffer = 404 new net::IOBuffer(kBufferSize); 405 result = socket_->Read(response_buffer.get(), 406 kBufferSize, 407 base::Bind(&AdbClientSocket::OnResponseHeader, 408 base::Unretained(this), 409 callback, 410 is_void, 411 response_buffer)); 412 if (result != net::ERR_IO_PENDING) 413 OnResponseHeader(callback, is_void, response_buffer, result); 414 } 415 416 void AdbClientSocket::OnResponseHeader( 417 const CommandCallback& callback, 418 bool is_void, 419 scoped_refptr<net::IOBuffer> response_buffer, 420 int result) { 421 if (result <= 0) { 422 callback.Run(result == 0 ? net::ERR_CONNECTION_CLOSED : result, 423 "IO error"); 424 return; 425 } 426 427 std::string data = std::string(response_buffer->data(), result); 428 if (result < 4) { 429 callback.Run(net::ERR_FAILED, "Response is too short: " + data); 430 return; 431 } 432 433 std::string status = data.substr(0, 4); 434 if (status != kOkayResponse) { 435 callback.Run(net::ERR_FAILED, data); 436 return; 437 } 438 439 data = data.substr(4); 440 441 if (!is_void) { 442 int payload_length = 0; 443 int bytes_left = -1; 444 if (data.length() >= 4 && 445 base::HexStringToInt(data.substr(0, 4), &payload_length)) { 446 data = data.substr(4); 447 bytes_left = payload_length - result + 8; 448 } else { 449 bytes_left = -1; 450 } 451 OnResponseData(callback, data, response_buffer, bytes_left, 0); 452 } else { 453 callback.Run(net::OK, data); 454 } 455 } 456 457 void AdbClientSocket::OnResponseData( 458 const CommandCallback& callback, 459 const std::string& response, 460 scoped_refptr<net::IOBuffer> response_buffer, 461 int bytes_left, 462 int result) { 463 if (result < 0) { 464 callback.Run(result, "IO error"); 465 return; 466 } 467 468 bytes_left -= result; 469 std::string new_response = 470 response + std::string(response_buffer->data(), result); 471 if (bytes_left == 0) { 472 callback.Run(net::OK, new_response); 473 return; 474 } 475 476 // Read tail 477 result = socket_->Read(response_buffer.get(), 478 kBufferSize, 479 base::Bind(&AdbClientSocket::OnResponseData, 480 base::Unretained(this), 481 callback, 482 new_response, 483 response_buffer, 484 bytes_left)); 485 if (result > 0) 486 OnResponseData(callback, new_response, response_buffer, bytes_left, result); 487 else if (result != net::ERR_IO_PENDING) 488 callback.Run(net::OK, new_response); 489 } 490