Home | History | Annotate | Download | only in forwarder
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include <errno.h>
      6 #include <fcntl.h>
      7 #include <netinet/in.h>
      8 #include <netinet/tcp.h>
      9 #include <pthread.h>
     10 #include <signal.h>
     11 #include <stdio.h>
     12 #include <stdlib.h>
     13 #include <string.h>
     14 #include <sys/select.h>
     15 #include <sys/socket.h>
     16 #include <sys/wait.h>
     17 #include <unistd.h>
     18 
     19 #include "base/command_line.h"
     20 #include "base/logging.h"
     21 #include "base/posix/eintr_wrapper.h"
     22 #include "tools/android/common/adb_connection.h"
     23 #include "tools/android/common/daemon.h"
     24 #include "tools/android/common/net.h"
     25 
     26 namespace {
     27 
     28 const pthread_t kInvalidThread = static_cast<pthread_t>(-1);
     29 volatile bool g_killed = false;
     30 
     31 void CloseSocket(int fd) {
     32   if (fd >= 0) {
     33     int old_errno = errno;
     34     close(fd);
     35     errno = old_errno;
     36   }
     37 }
     38 
     39 class Buffer {
     40  public:
     41   Buffer()
     42       : bytes_read_(0),
     43         write_offset_(0) {
     44   }
     45 
     46   bool CanRead() {
     47     return bytes_read_ == 0;
     48   }
     49 
     50   bool CanWrite() {
     51     return write_offset_ < bytes_read_;
     52   }
     53 
     54   int Read(int fd) {
     55     int ret = -1;
     56     if (CanRead()) {
     57       ret = HANDLE_EINTR(read(fd, buffer_, kBufferSize));
     58       if (ret > 0)
     59         bytes_read_ = ret;
     60     }
     61     return ret;
     62   }
     63 
     64   int Write(int fd) {
     65     int ret = -1;
     66     if (CanWrite()) {
     67       ret = HANDLE_EINTR(write(fd, buffer_ + write_offset_,
     68                                bytes_read_ - write_offset_));
     69       if (ret > 0) {
     70         write_offset_ += ret;
     71         if (write_offset_ == bytes_read_) {
     72           write_offset_ = 0;
     73           bytes_read_ = 0;
     74         }
     75       }
     76     }
     77     return ret;
     78   }
     79 
     80  private:
     81   // A big buffer to let our file-over-http bridge work more like real file.
     82   static const int kBufferSize = 1024 * 128;
     83   int bytes_read_;
     84   int write_offset_;
     85   char buffer_[kBufferSize];
     86 
     87   DISALLOW_COPY_AND_ASSIGN(Buffer);
     88 };
     89 
     90 class Server;
     91 
     92 struct ForwarderThreadInfo {
     93   ForwarderThreadInfo(Server* a_server, int a_forwarder_index)
     94       : server(a_server),
     95         forwarder_index(a_forwarder_index) {
     96   }
     97   Server* server;
     98   int forwarder_index;
     99 };
    100 
// Bookkeeping for one forwarded connection; instances live in the slots of
// Server::forwarders_.
struct ForwarderInfo {
  // Time the connection was set up. 0 marks the slot as unused.
  time_t start_time;
  // Connection accepted on the server's listening socket.
  int socket1;
  // Last time socket1 was reported readable by select().
  time_t socket1_last_byte_time;
  // Total number of bytes read from socket1.
  size_t socket1_bytes;
  // Connection returned by tools::ConnectAdbHostSocket().
  int socket2;
  // Last time socket2 was reported readable by select().
  time_t socket2_last_byte_time;
  // Total number of bytes read from socket2.
  size_t socket2_bytes;
};
    110 
    111 class Server {
    112  public:
    113   Server()
    114       : thread_(kInvalidThread),
    115         socket_(-1) {
    116     memset(forward_to_, 0, sizeof(forward_to_));
    117     memset(&forwarders_, 0, sizeof(forwarders_));
    118   }
    119 
    120   int GetFreeForwarderIndex() {
    121     for (int i = 0; i < kMaxForwarders; i++) {
    122       if (forwarders_[i].start_time == 0)
    123         return i;
    124     }
    125     return -1;
    126   }
    127 
    128   void DisposeForwarderInfo(int index) {
    129     forwarders_[index].start_time = 0;
    130   }
    131 
    132   ForwarderInfo* GetForwarderInfo(int index) {
    133     return &forwarders_[index];
    134   }
    135 
    136   void DumpInformation() {
    137     LOG(INFO) << "Server information: " << forward_to_;
    138     LOG(INFO) << "No.: age up(bytes,idle) down(bytes,idle)";
    139     int count = 0;
    140     time_t now = time(NULL);
    141     for (int i = 0; i < kMaxForwarders; i++) {
    142       const ForwarderInfo& info = forwarders_[i];
    143       if (info.start_time) {
    144         count++;
    145         LOG(INFO) << count << ": " << now - info.start_time << " up("
    146                   << info.socket1_bytes << ","
    147                   << now - info.socket1_last_byte_time << " down("
    148                   << info.socket2_bytes << ","
    149                   << now - info.socket2_last_byte_time << ")";
    150       }
    151     }
    152   }
    153 
    154   void Shutdown() {
    155     if (socket_ >= 0)
    156       shutdown(socket_, SHUT_RDWR);
    157   }
    158 
    159   bool InitSocket(const char* arg);
    160 
    161   void StartThread() {
    162     pthread_create(&thread_, NULL, ServerThread, this);
    163   }
    164 
    165   void JoinThread() {
    166     if (thread_ != kInvalidThread)
    167       pthread_join(thread_, NULL);
    168   }
    169 
    170  private:
    171   static void* ServerThread(void* arg);
    172 
    173   // There are 3 kinds of threads that will access the array:
    174   // 1. Server thread will get a free ForwarderInfo and initialize it;
    175   // 2. Forwarder threads will dispose the ForwarderInfo when it finishes;
    176   // 3. Main thread will iterate and print the forwarders.
    177   // Using an array is not optimal, but can avoid locks or other complex
    178   // inter-thread communication.
    179   static const int kMaxForwarders = 512;
    180   ForwarderInfo forwarders_[kMaxForwarders];
    181 
    182   pthread_t thread_;
    183   int socket_;
    184   char forward_to_[40];
    185 
    186   DISALLOW_COPY_AND_ASSIGN(Server);
    187 };
    188 
    189 // Forwards all outputs from one socket to another socket.
    190 void* ForwarderThread(void* arg) {
    191   ForwarderThreadInfo* thread_info =
    192       reinterpret_cast<ForwarderThreadInfo*>(arg);
    193   Server* server = thread_info->server;
    194   int index = thread_info->forwarder_index;
    195   delete thread_info;
    196   ForwarderInfo* info = server->GetForwarderInfo(index);
    197   int socket1 = info->socket1;
    198   int socket2 = info->socket2;
    199   int nfds = socket1 > socket2 ? socket1 + 1 : socket2 + 1;
    200   fd_set read_fds;
    201   fd_set write_fds;
    202   Buffer buffer1;
    203   Buffer buffer2;
    204 
    205   while (!g_killed) {
    206     FD_ZERO(&read_fds);
    207     if (buffer1.CanRead())
    208       FD_SET(socket1, &read_fds);
    209     if (buffer2.CanRead())
    210       FD_SET(socket2, &read_fds);
    211 
    212     FD_ZERO(&write_fds);
    213     if (buffer1.CanWrite())
    214       FD_SET(socket2, &write_fds);
    215     if (buffer2.CanWrite())
    216       FD_SET(socket1, &write_fds);
    217 
    218     if (HANDLE_EINTR(select(nfds, &read_fds, &write_fds, NULL, NULL)) <= 0) {
    219       LOG(ERROR) << "Select error: " << strerror(errno);
    220       break;
    221     }
    222 
    223     int now = time(NULL);
    224     if (FD_ISSET(socket1, &read_fds)) {
    225       info->socket1_last_byte_time = now;
    226       int bytes = buffer1.Read(socket1);
    227       if (bytes <= 0)
    228         break;
    229       info->socket1_bytes += bytes;
    230     }
    231     if (FD_ISSET(socket2, &read_fds)) {
    232       info->socket2_last_byte_time = now;
    233       int bytes = buffer2.Read(socket2);
    234       if (bytes <= 0)
    235         break;
    236       info->socket2_bytes += bytes;
    237     }
    238     if (FD_ISSET(socket1, &write_fds)) {
    239       if (buffer2.Write(socket1) <= 0)
    240         break;
    241     }
    242     if (FD_ISSET(socket2, &write_fds)) {
    243       if (buffer1.Write(socket2) <= 0)
    244         break;
    245     }
    246   }
    247 
    248   CloseSocket(socket1);
    249   CloseSocket(socket2);
    250   server->DisposeForwarderInfo(index);
    251   return NULL;
    252 }
    253 
    254 // Listens to a server socket. On incoming request, forward it to the host.
    255 // static
    256 void* Server::ServerThread(void* arg) {
    257   Server* server = reinterpret_cast<Server*>(arg);
    258   while (!g_killed) {
    259     int forwarder_index = server->GetFreeForwarderIndex();
    260     if (forwarder_index < 0) {
    261       LOG(ERROR) << "Too many forwarders";
    262       continue;
    263     }
    264 
    265     struct sockaddr_in addr;
    266     socklen_t addr_len = sizeof(addr);
    267     int socket = HANDLE_EINTR(accept(server->socket_,
    268                                      reinterpret_cast<sockaddr*>(&addr),
    269                                      &addr_len));
    270     if (socket < 0) {
    271       LOG(ERROR) << "Failed to accept: " << strerror(errno);
    272       break;
    273     }
    274     tools::DisableNagle(socket);
    275 
    276     int host_socket = tools::ConnectAdbHostSocket(server->forward_to_);
    277     if (host_socket >= 0) {
    278       // Set NONBLOCK flag because we use select().
    279       fcntl(socket, F_SETFL, fcntl(socket, F_GETFL) | O_NONBLOCK);
    280       fcntl(host_socket, F_SETFL, fcntl(host_socket, F_GETFL) | O_NONBLOCK);
    281 
    282       ForwarderInfo* forwarder_info = server->GetForwarderInfo(forwarder_index);
    283       time_t now = time(NULL);
    284       forwarder_info->start_time = now;
    285       forwarder_info->socket1 = socket;
    286       forwarder_info->socket1_last_byte_time = now;
    287       forwarder_info->socket1_bytes = 0;
    288       forwarder_info->socket2 = host_socket;
    289       forwarder_info->socket2_last_byte_time = now;
    290       forwarder_info->socket2_bytes = 0;
    291 
    292       pthread_t thread;
    293       pthread_create(&thread, NULL, ForwarderThread,
    294                      new ForwarderThreadInfo(server, forwarder_index));
    295     } else {
    296       // Close the unused client socket which is failed to connect to host.
    297       CloseSocket(socket);
    298     }
    299   }
    300 
    301   CloseSocket(server->socket_);
    302   server->socket_ = -1;
    303   return NULL;
    304 }
    305 
    306 // Format of arg: <Device port>[:<Forward to port>:<Forward to address>]
    307 bool Server::InitSocket(const char* arg) {
    308   char* endptr;
    309   int local_port = static_cast<int>(strtol(arg, &endptr, 10));
    310   if (local_port < 0)
    311     return false;
    312 
    313   if (*endptr != ':') {
    314     snprintf(forward_to_, sizeof(forward_to_), "%d:127.0.0.1", local_port);
    315   } else {
    316     strncpy(forward_to_, endptr + 1, sizeof(forward_to_) - 1);
    317   }
    318 
    319   socket_ = socket(AF_INET, SOCK_STREAM, 0);
    320   if (socket_ < 0) {
    321     perror("server socket");
    322     return false;
    323   }
    324   tools::DisableNagle(socket_);
    325 
    326   sockaddr_in addr;
    327   memset(&addr, 0, sizeof(addr));
    328   addr.sin_family = AF_INET;
    329   addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
    330   addr.sin_port = htons(local_port);
    331   int reuse_addr = 1;
    332   setsockopt(socket_, SOL_SOCKET, SO_REUSEADDR,
    333              &reuse_addr, sizeof(reuse_addr));
    334   tools::DeferAccept(socket_);
    335   if (HANDLE_EINTR(bind(socket_, reinterpret_cast<sockaddr*>(&addr),
    336                         sizeof(addr))) < 0 ||
    337       HANDLE_EINTR(listen(socket_, 5)) < 0) {
    338     perror("server bind");
    339     CloseSocket(socket_);
    340     socket_ = -1;
    341     return false;
    342   }
    343 
    344   if (local_port == 0) {
    345     socklen_t addrlen = sizeof(addr);
    346     if (getsockname(socket_, reinterpret_cast<sockaddr*>(&addr), &addrlen)
    347         != 0) {
    348       perror("get listen address");
    349       CloseSocket(socket_);
    350       socket_ = -1;
    351       return false;
    352     }
    353     local_port = ntohs(addr.sin_port);
    354   }
    355 
    356   printf("Forwarding device port %d to host %s\n", local_port, forward_to_);
    357   return true;
    358 }
    359 
// Number of leading entries of g_servers whose InitSocket() succeeded; main()
// resets it to 0 before freeing the array.
int g_server_count = 0;
// Array of servers allocated in main(); also read by the signal handlers.
Server* g_servers = NULL;
    362 
    363 void KillHandler(int unused) {
    364   g_killed = true;
    365   for (int i = 0; i < g_server_count; i++)
    366     g_servers[i].Shutdown();
    367 }
    368 
    369 void DumpInformation(int unused) {
    370   for (int i = 0; i < g_server_count; i++)
    371     g_servers[i].DumpInformation();
    372 }
    373 
    374 }  // namespace
    375 
    376 int main(int argc, char** argv) {
    377   printf("Android device to host TCP forwarder\n");
    378   printf("Like 'adb forward' but in the reverse direction\n");
    379 
    380   CommandLine command_line(argc, argv);
    381   CommandLine::StringVector server_args = command_line.GetArgs();
    382   if (tools::HasHelpSwitch(command_line) || server_args.empty()) {
    383     tools::ShowHelp(
    384         argv[0],
    385         "<Device port>[:<Forward to port>:<Forward to address>] ...",
    386         "  <Forward to port> default is <Device port>\n"
    387         "  <Forward to address> default is 127.0.0.1\n"
    388         "If <Device port> is 0, a port will by dynamically allocated.\n");
    389     return 0;
    390   }
    391 
    392   g_servers = new Server[server_args.size()];
    393   g_server_count = 0;
    394   int failed_count = 0;
    395   for (size_t i = 0; i < server_args.size(); i++) {
    396     if (!g_servers[g_server_count].InitSocket(server_args[i].c_str())) {
    397       printf("Couldn't start forwarder server for port spec: %s\n",
    398              server_args[i].c_str());
    399       ++failed_count;
    400     } else {
    401       ++g_server_count;
    402     }
    403   }
    404 
    405   if (g_server_count == 0) {
    406     printf("No forwarder servers could be started. Exiting.\n");
    407     delete [] g_servers;
    408     return failed_count;
    409   }
    410 
    411   if (!tools::HasNoSpawnDaemonSwitch(command_line))
    412     tools::SpawnDaemon(failed_count);
    413 
    414   signal(SIGTERM, KillHandler);
    415   signal(SIGUSR2, DumpInformation);
    416 
    417   for (int i = 0; i < g_server_count; i++)
    418     g_servers[i].StartThread();
    419   for (int i = 0; i < g_server_count; i++)
    420     g_servers[i].JoinThread();
    421   g_server_count = 0;
    422   delete [] g_servers;
    423 
    424   return 0;
    425 }
    426 
    427