Home | History | Annotate | Download | only in net
      1 // Copyright (c) 2013 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "chrome/test/chromedriver/net/port_server.h"
      6 
      7 #include "base/bind.h"
      8 #include "base/bind_helpers.h"
      9 #include "base/logging.h"
     10 #include "base/process/process_handle.h"
     11 #include "base/rand_util.h"
     12 #include "base/strings/string_number_conversions.h"
     13 #include "base/sync_socket.h"
     14 #include "chrome/test/chromedriver/chrome/status.h"
     15 #include "net/base/net_errors.h"
     16 #include "net/base/net_log.h"
     17 #include "net/base/net_util.h"
     18 #include "net/base/sys_addrinfo.h"
     19 #include "net/socket/tcp_server_socket.h"
     20 
     21 #if defined(OS_LINUX)
     22 #include <sys/socket.h>
     23 #include <sys/un.h>
     24 #endif
     25 
     26 PortReservation::PortReservation(const base::Closure& on_free_func, int port)
     27     : on_free_func_(on_free_func), port_(port) {}
     28 
     29 PortReservation::~PortReservation() {
     30   if (!on_free_func_.is_null())
     31     on_free_func_.Run();
     32 }
     33 
     34 void PortReservation::Leak() {
     35   LOG(ERROR) << "Port leaked: " << port_;
     36   on_free_func_.Reset();
     37 }
     38 
     39 PortServer::PortServer(const std::string& path) : path_(path) {
     40   CHECK(path_.size() && path_[0] == 0)
     41       << "path must be for Linux abstract namespace";
     42 }
     43 
     44 PortServer::~PortServer() {}
     45 
     46 Status PortServer::ReservePort(int* port,
     47                                scoped_ptr<PortReservation>* reservation) {
     48   int port_to_use = 0;
     49   {
     50     base::AutoLock lock(free_lock_);
     51     if (free_.size()) {
     52       port_to_use = free_.front();
     53       free_.pop_front();
     54     }
     55   }
     56   if (!port_to_use) {
     57     Status status = RequestPort(&port_to_use);
     58     if (status.IsError())
     59       return status;
     60   }
     61   *port = port_to_use;
     62   reservation->reset(new PortReservation(
     63       base::Bind(&PortServer::ReleasePort, base::Unretained(this), port_to_use),
     64       port_to_use));
     65   return Status(kOk);
     66 }
     67 
     68 Status PortServer::RequestPort(int* port) {
     69   // The client sends its PID + \n, and the server responds with a port + \n,
     70   // which is valid for the lifetime of the referred process.
     71 #if defined(OS_LINUX)
     72   int sock_fd = socket(AF_UNIX, SOCK_STREAM, 0);
     73   if (sock_fd < 0)
     74     return Status(kUnknownError, "unable to create socket");
     75   base::SyncSocket sock(sock_fd);
     76   struct timeval tv;
     77   tv.tv_sec = 10;
     78   tv.tv_usec = 0;
     79   if (setsockopt(sock_fd,
     80                  SOL_SOCKET,
     81                  SO_RCVTIMEO,
     82                  reinterpret_cast<char*>(&tv),
     83                  sizeof(tv)) < 0 ||
     84       setsockopt(sock_fd,
     85                  SOL_SOCKET,
     86                  SO_SNDTIMEO,
     87                  reinterpret_cast<char*>(&tv),
     88                  sizeof(tv)) < 0) {
     89     return Status(kUnknownError, "unable to set socket timeout");
     90   }
     91 
     92   struct sockaddr_un addr;
     93   memset(&addr, 0, sizeof(addr));
     94   addr.sun_family = AF_UNIX;
     95   memcpy(addr.sun_path, &path_[0], path_.length());
     96   if (connect(sock.handle(),
     97               reinterpret_cast<struct sockaddr*>(&addr),
     98               sizeof(sa_family_t) + path_.length())) {
     99     return Status(kUnknownError, "unable to connect");
    100   }
    101 
    102   int proc_id = static_cast<int>(base::GetCurrentProcId());
    103   std::string request = base::IntToString(proc_id);
    104   request += "\n";
    105   VLOG(0) << "PORTSERVER REQUEST " << request;
    106   if (sock.Send(request.c_str(), request.length()) != request.length())
    107     return Status(kUnknownError, "failed to send portserver request");
    108 
    109   std::string response;
    110   do {
    111     char c = 0;
    112     size_t rv = sock.Receive(&c, 1);
    113     if (!rv)
    114       break;
    115     response.push_back(c);
    116   } while (sock.Peek());
    117   if (response.empty())
    118     return Status(kUnknownError, "failed to receive portserver response");
    119   VLOG(0) << "PORTSERVER RESPONSE " << response;
    120 
    121   int new_port = 0;
    122   if (*response.rbegin() != '\n' ||
    123       !base::StringToInt(response.substr(0, response.length() - 1), &new_port))
    124     return Status(kUnknownError, "failed to parse portserver response");
    125   *port = new_port;
    126   return Status(kOk);
    127 #else
    128   return Status(kUnknownError, "not implemented for this platform");
    129 #endif
    130 }
    131 
    132 void PortServer::ReleasePort(int port) {
    133   base::AutoLock lock(free_lock_);
    134   free_.push_back(port);
    135 }
    136 
    137 PortManager::PortManager(int min_port, int max_port)
    138     : min_port_(min_port), max_port_(max_port) {
    139   CHECK_GE(max_port_, min_port_);
    140 }
    141 
    142 PortManager::~PortManager() {}
    143 
    144 Status PortManager::ReservePort(int* port,
    145                                 scoped_ptr<PortReservation>* reservation) {
    146   base::AutoLock lock(taken_lock_);
    147 
    148   int start = base::RandInt(min_port_, max_port_);
    149   bool wrapped = false;
    150   for (int try_port = start; try_port != start || !wrapped; ++try_port) {
    151     if (try_port > max_port_) {
    152       wrapped = true;
    153       if (min_port_ == max_port_)
    154         break;
    155       try_port = min_port_;
    156     }
    157     if (taken_.count(try_port))
    158       continue;
    159 
    160     char parts[] = {127, 0, 0, 1};
    161     net::IPAddressNumber address(parts, parts + arraysize(parts));
    162     net::NetLog::Source source;
    163     net::TCPServerSocket sock(NULL, source);
    164     if (sock.Listen(net::IPEndPoint(address, try_port), 1) != net::OK)
    165       continue;
    166 
    167     taken_.insert(try_port);
    168     *port = try_port;
    169     reservation->reset(new PortReservation(
    170         base::Bind(&PortManager::ReleasePort, base::Unretained(this), try_port),
    171         try_port));
    172     return Status(kOk);
    173   }
    174   return Status(kUnknownError, "unable to find open port");
    175 }
    176 
    177 void PortManager::ReleasePort(int port) {
    178   base::AutoLock lock(taken_lock_);
    179   taken_.erase(port);
    180 }
    181