Home | History | Annotate | Download | only in test
      1 // Copyright 2014 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "device/test/usb_test_gadget.h"
      6 
      7 #include <string>
      8 #include <vector>
      9 
     10 #include "base/command_line.h"
     11 #include "base/compiler_specific.h"
     12 #include "base/files/file.h"
     13 #include "base/files/file_path.h"
     14 #include "base/logging.h"
     15 #include "base/macros.h"
     16 #include "base/memory/ref_counted.h"
     17 #include "base/memory/scoped_ptr.h"
     18 #include "base/path_service.h"
     19 #include "base/process/process.h"
     20 #include "base/run_loop.h"
     21 #include "base/strings/string_number_conversions.h"
     22 #include "base/strings/stringprintf.h"
     23 #include "base/strings/utf_string_conversions.h"
     24 #include "base/threading/platform_thread.h"
     25 #include "base/time/time.h"
     26 #include "device/usb/usb_device.h"
     27 #include "device/usb/usb_device_handle.h"
     28 #include "device/usb/usb_service.h"
     29 #include "net/proxy/proxy_service.h"
     30 #include "net/url_request/url_fetcher.h"
     31 #include "net/url_request/url_fetcher_delegate.h"
     32 #include "net/url_request/url_request_context.h"
     33 #include "net/url_request/url_request_context_builder.h"
     34 #include "net/url_request/url_request_context_getter.h"
     35 #include "url/gurl.h"
     36 
     37 using ::base::PlatformThread;
     38 using ::base::TimeDelta;
     39 
     40 namespace device {
     41 
     42 namespace {
     43 
     44 static const char kCommandLineSwitch[] = "enable-gadget-tests";
     45 static const int kClaimRetries = 100;  // 5 seconds
     46 static const int kDisconnectRetries = 100;  // 5 seconds
     47 static const int kRetryPeriod = 50;  // 0.05 seconds
     48 static const int kReconnectRetries = 100;  // 5 seconds
     49 static const int kUpdateRetries = 100;  // 5 seconds
     50 
     51 struct UsbTestGadgetConfiguration {
     52   UsbTestGadget::Type type;
     53   const char* http_resource;
     54   uint16 product_id;
     55 };
     56 
     57 static const struct UsbTestGadgetConfiguration kConfigurations[] = {
     58     {UsbTestGadget::DEFAULT, "/unconfigure", 0x58F0},
     59     {UsbTestGadget::KEYBOARD, "/keyboard/configure", 0x58F1},
     60     {UsbTestGadget::MOUSE, "/mouse/configure", 0x58F2},
     61     {UsbTestGadget::HID_ECHO, "/hid_echo/configure", 0x58F3},
     62     {UsbTestGadget::ECHO, "/echo/configure", 0x58F4},
     63 };
     64 
     65 class UsbTestGadgetImpl : public UsbTestGadget {
     66  public:
     67   virtual ~UsbTestGadgetImpl();
     68 
     69   virtual bool Unclaim() OVERRIDE;
     70   virtual bool Disconnect() OVERRIDE;
     71   virtual bool Reconnect() OVERRIDE;
     72   virtual bool SetType(Type type) OVERRIDE;
     73   virtual UsbDevice* GetDevice() const OVERRIDE;
     74   virtual std::string GetSerialNumber() const OVERRIDE;
     75 
     76  protected:
     77   UsbTestGadgetImpl();
     78 
     79  private:
     80   scoped_ptr<net::URLFetcher> CreateURLFetcher(
     81       const GURL& url,
     82       net::URLFetcher::RequestType request_type,
     83       net::URLFetcherDelegate* delegate);
     84   int SimplePOSTRequest(const GURL& url, const std::string& form_data);
     85   bool FindUnclaimed();
     86   bool GetVersion(std::string* version);
     87   bool Update();
     88   bool FindClaimed();
     89   bool ReadLocalVersion(std::string* version);
     90   bool ReadLocalPackage(std::string* package);
     91   bool ReadFile(const base::FilePath& file_path, std::string* content);
     92 
     93   class Delegate : public net::URLFetcherDelegate {
     94    public:
     95     Delegate() {}
     96     virtual ~Delegate() {}
     97 
     98     void WaitForCompletion() {
     99       run_loop_.Run();
    100     }
    101 
    102     virtual void OnURLFetchComplete(const net::URLFetcher* source) OVERRIDE {
    103       run_loop_.Quit();
    104     }
    105 
    106    private:
    107     base::RunLoop run_loop_;
    108 
    109     DISALLOW_COPY_AND_ASSIGN(Delegate);
    110   };
    111 
    112   scoped_refptr<UsbDevice> device_;
    113   std::string device_address_;
    114   scoped_ptr<net::URLRequestContext> request_context_;
    115   std::string session_id_;
    116   UsbService* usb_service_;
    117 
    118   friend class UsbTestGadget;
    119 
    120   DISALLOW_COPY_AND_ASSIGN(UsbTestGadgetImpl);
    121 };
    122 
    123 }  // namespace
    124 
    125 bool UsbTestGadget::IsTestEnabled() {
    126   base::CommandLine* command_line = CommandLine::ForCurrentProcess();
    127   return command_line->HasSwitch(kCommandLineSwitch);
    128 }
    129 
    130 scoped_ptr<UsbTestGadget> UsbTestGadget::Claim() {
    131   scoped_ptr<UsbTestGadgetImpl> gadget(new UsbTestGadgetImpl);
    132 
    133   int retries = kClaimRetries;
    134   while (!gadget->FindUnclaimed()) {
    135     if (--retries == 0) {
    136       LOG(ERROR) << "Failed to find an unclaimed device.";
    137       return scoped_ptr<UsbTestGadget>();
    138     }
    139     PlatformThread::Sleep(TimeDelta::FromMilliseconds(kRetryPeriod));
    140   }
    141   VLOG(1) << "It took " << (kClaimRetries - retries)
    142           << " retries to find an unclaimed device.";
    143 
    144   return gadget.PassAs<UsbTestGadget>();
    145 }
    146 
    147 UsbTestGadgetImpl::UsbTestGadgetImpl() {
    148   net::URLRequestContextBuilder context_builder;
    149   context_builder.set_proxy_service(net::ProxyService::CreateDirect());
    150   request_context_.reset(context_builder.Build());
    151 
    152   base::ProcessId process_id = base::Process::Current().pid();
    153   session_id_ = base::StringPrintf(
    154       "%s:%p", base::HexEncode(&process_id, sizeof(process_id)).c_str(), this);
    155 
    156   usb_service_ = UsbService::GetInstance(NULL);
    157 }
    158 
    159 UsbTestGadgetImpl::~UsbTestGadgetImpl() {
    160   if (!device_address_.empty()) {
    161     Unclaim();
    162   }
    163 }
    164 
    165 UsbDevice* UsbTestGadgetImpl::GetDevice() const {
    166   return device_.get();
    167 }
    168 
    169 std::string UsbTestGadgetImpl::GetSerialNumber() const {
    170   return device_address_;
    171 }
    172 
    173 scoped_ptr<net::URLFetcher> UsbTestGadgetImpl::CreateURLFetcher(
    174     const GURL& url, net::URLFetcher::RequestType request_type,
    175     net::URLFetcherDelegate* delegate) {
    176   scoped_ptr<net::URLFetcher> url_fetcher(
    177       net::URLFetcher::Create(url, request_type, delegate));
    178 
    179   url_fetcher->SetRequestContext(
    180       new net::TrivialURLRequestContextGetter(
    181           request_context_.get(),
    182           base::MessageLoop::current()->message_loop_proxy()));
    183 
    184   return url_fetcher.PassAs<net::URLFetcher>();
    185 }
    186 
    187 int UsbTestGadgetImpl::SimplePOSTRequest(const GURL& url,
    188                                          const std::string& form_data) {
    189   Delegate delegate;
    190   scoped_ptr<net::URLFetcher> url_fetcher =
    191     CreateURLFetcher(url, net::URLFetcher::POST, &delegate);
    192 
    193   url_fetcher->SetUploadData("application/x-www-form-urlencoded", form_data);
    194   url_fetcher->Start();
    195   delegate.WaitForCompletion();
    196 
    197   return url_fetcher->GetResponseCode();
    198 }
    199 
    200 bool UsbTestGadgetImpl::FindUnclaimed() {
    201   std::vector<scoped_refptr<UsbDevice> > devices;
    202   usb_service_->GetDevices(&devices);
    203 
    204   for (std::vector<scoped_refptr<UsbDevice> >::const_iterator iter =
    205          devices.begin(); iter != devices.end(); ++iter) {
    206     const scoped_refptr<UsbDevice> &device = *iter;
    207     if (device->vendor_id() == 0x18D1 && device->product_id() == 0x58F0) {
    208       base::string16 serial_utf16;
    209       if (!device->GetSerialNumber(&serial_utf16)) {
    210         continue;
    211       }
    212 
    213       const std::string serial = base::UTF16ToUTF8(serial_utf16);
    214       const GURL url("http://" + serial + "/claim");
    215       const std::string form_data = base::StringPrintf(
    216           "session_id=%s",
    217           net::EscapeUrlEncodedData(session_id_, true).c_str());
    218       const int response_code = SimplePOSTRequest(url, form_data);
    219 
    220       if (response_code == 200) {
    221         device_address_ = serial;
    222         device_ = device;
    223         break;
    224       }
    225 
    226       // The device is probably claimed by another process.
    227       if (response_code != 403) {
    228         LOG(WARNING) << "Unexpected HTTP " << response_code << " from /claim.";
    229       }
    230     }
    231   }
    232 
    233   std::string local_version;
    234   std::string version;
    235   if (!ReadLocalVersion(&local_version) ||
    236       !GetVersion(&version)) {
    237     return false;
    238   }
    239 
    240   if (version == local_version) {
    241     return true;
    242   }
    243 
    244   return Update();
    245 }
    246 
    247 bool UsbTestGadgetImpl::GetVersion(std::string* version) {
    248   Delegate delegate;
    249   const GURL url("http://" + device_address_ + "/version");
    250   scoped_ptr<net::URLFetcher> url_fetcher =
    251       CreateURLFetcher(url, net::URLFetcher::GET, &delegate);
    252 
    253   url_fetcher->Start();
    254   delegate.WaitForCompletion();
    255 
    256   const int response_code = url_fetcher->GetResponseCode();
    257   if (response_code != 200) {
    258     VLOG(2) << "Unexpected HTTP " << response_code << " from /version.";
    259     return false;
    260   }
    261 
    262   STLClearObject(version);
    263   if (!url_fetcher->GetResponseAsString(version)) {
    264     VLOG(2) << "Failed to read body from /version.";
    265     return false;
    266   }
    267   return true;
    268 }
    269 
    270 bool UsbTestGadgetImpl::Update() {
    271   std::string version;
    272   if (!ReadLocalVersion(&version)) {
    273     return false;
    274   }
    275   LOG(INFO) << "Updating " << device_address_ << " to " << version << "...";
    276 
    277   Delegate delegate;
    278   const GURL url("http://" + device_address_ + "/update");
    279   scoped_ptr<net::URLFetcher> url_fetcher =
    280       CreateURLFetcher(url, net::URLFetcher::POST, &delegate);
    281 
    282   const std::string mime_header =
    283       base::StringPrintf(
    284       "--foo\r\n"
    285       "Content-Disposition: form-data; name=\"file\"; "
    286           "filename=\"usb_gadget-%s.zip\"\r\n"
    287       "Content-Type: application/octet-stream\r\n"
    288       "\r\n", version.c_str());
    289   const std::string mime_footer("\r\n--foo--\r\n");
    290 
    291   std::string package;
    292   if (!ReadLocalPackage(&package)) {
    293     return false;
    294   }
    295 
    296   url_fetcher->SetUploadData("multipart/form-data; boundary=foo",
    297                              mime_header + package + mime_footer);
    298   url_fetcher->Start();
    299   delegate.WaitForCompletion();
    300 
    301   const int response_code = url_fetcher->GetResponseCode();
    302   if (response_code != 200) {
    303     LOG(ERROR) << "Unexpected HTTP " << response_code << " from /update.";
    304     return false;
    305   }
    306 
    307   int retries = kUpdateRetries;
    308   std::string new_version;
    309   while (!GetVersion(&new_version) || new_version != version) {
    310     if (--retries == 0) {
    311       LOG(ERROR) << "Device not responding with new version.";
    312       return false;
    313     }
    314     PlatformThread::Sleep(TimeDelta::FromMilliseconds(kRetryPeriod));
    315   }
    316   VLOG(1) << "It took " << (kUpdateRetries - retries)
    317           << " retries to see the new version.";
    318 
    319   // Release the old reference to the device and try to open a new one.
    320   device_ = NULL;
    321   retries = kReconnectRetries;
    322   while (!FindClaimed()) {
    323     if (--retries == 0) {
    324       LOG(ERROR) << "Failed to find updated device.";
    325       return false;
    326     }
    327     PlatformThread::Sleep(TimeDelta::FromMilliseconds(kRetryPeriod));
    328   }
    329   VLOG(1) << "It took " << (kReconnectRetries - retries)
    330           << " retries to find the updated device.";
    331 
    332   return true;
    333 }
    334 
    335 bool UsbTestGadgetImpl::FindClaimed() {
    336   CHECK(!device_.get());
    337 
    338   std::string expected_serial = GetSerialNumber();
    339 
    340   std::vector<scoped_refptr<UsbDevice> > devices;
    341   usb_service_->GetDevices(&devices);
    342 
    343   for (std::vector<scoped_refptr<UsbDevice> >::iterator iter =
    344          devices.begin(); iter != devices.end(); ++iter) {
    345     scoped_refptr<UsbDevice> &device = *iter;
    346 
    347     if (device->vendor_id() == 0x18D1) {
    348       const uint16 product_id = device->product_id();
    349       bool found = false;
    350       for (size_t i = 0; i < arraysize(kConfigurations); ++i) {
    351         if (product_id == kConfigurations[i].product_id) {
    352           found = true;
    353           break;
    354         }
    355       }
    356       if (!found) {
    357         continue;
    358       }
    359 
    360       base::string16 serial_utf16;
    361       if (!device->GetSerialNumber(&serial_utf16)) {
    362         continue;
    363       }
    364 
    365       std::string serial = base::UTF16ToUTF8(serial_utf16);
    366       if (serial != expected_serial) {
    367         continue;
    368       }
    369 
    370       device_ = device;
    371       return true;
    372     }
    373   }
    374 
    375   return false;
    376 }
    377 
    378 bool UsbTestGadgetImpl::ReadLocalVersion(std::string* version) {
    379   base::FilePath file_path;
    380   CHECK(PathService::Get(base::DIR_EXE, &file_path));
    381   file_path = file_path.AppendASCII("usb_gadget.zip.md5");
    382 
    383   return ReadFile(file_path, version);
    384 }
    385 
    386 bool UsbTestGadgetImpl::ReadLocalPackage(std::string* package) {
    387   base::FilePath file_path;
    388   CHECK(PathService::Get(base::DIR_EXE, &file_path));
    389   file_path = file_path.AppendASCII("usb_gadget.zip");
    390 
    391   return ReadFile(file_path, package);
    392 }
    393 
    394 bool UsbTestGadgetImpl::ReadFile(const base::FilePath& file_path,
    395                                  std::string* content) {
    396   base::File file(file_path, base::File::FLAG_OPEN | base::File::FLAG_READ);
    397   if (!file.IsValid()) {
    398     LOG(ERROR) << "Cannot open " << file_path.MaybeAsASCII() << ": "
    399                << base::File::ErrorToString(file.error_details());
    400     return false;
    401   }
    402 
    403   STLClearObject(content);
    404   int rv;
    405   do {
    406     char buf[4096];
    407     rv = file.ReadAtCurrentPos(buf, sizeof buf);
    408     if (rv == -1) {
    409       LOG(ERROR) << "Cannot read " << file_path.MaybeAsASCII() << ": "
    410                  << base::File::ErrorToString(file.error_details());
    411       return false;
    412     }
    413     content->append(buf, rv);
    414   } while (rv > 0);
    415 
    416   return true;
    417 }
    418 
    419 bool UsbTestGadgetImpl::Unclaim() {
    420   VLOG(1) << "Releasing the device at " << device_address_ << ".";
    421 
    422   const GURL url("http://" + device_address_ + "/unclaim");
    423   const int response_code = SimplePOSTRequest(url, "");
    424 
    425   if (response_code != 200) {
    426     LOG(ERROR) << "Unexpected HTTP " << response_code << " from /unclaim.";
    427     return false;
    428   }
    429   return true;
    430 }
    431 
    432 bool UsbTestGadgetImpl::SetType(Type type) {
    433   const struct UsbTestGadgetConfiguration* config = NULL;
    434   for (size_t i = 0; i < arraysize(kConfigurations); ++i) {
    435     if (kConfigurations[i].type == type) {
    436       config = &kConfigurations[i];
    437     }
    438   }
    439   CHECK(config);
    440 
    441   const GURL url("http://" + device_address_ + config->http_resource);
    442   const int response_code = SimplePOSTRequest(url, "");
    443 
    444   if (response_code != 200) {
    445     LOG(ERROR) << "Unexpected HTTP " << response_code
    446                << " from " << config->http_resource << ".";
    447     return false;
    448   }
    449 
    450   // Release the old reference to the device and try to open a new one.
    451   int retries = kReconnectRetries;
    452   while (true) {
    453     device_ = NULL;
    454     if (FindClaimed() && device_->product_id() == config->product_id) {
    455       break;
    456     }
    457     if (--retries == 0) {
    458       LOG(ERROR) << "Failed to find updated device.";
    459       return false;
    460     }
    461     PlatformThread::Sleep(TimeDelta::FromMilliseconds(kRetryPeriod));
    462   }
    463   VLOG(1) << "It took " << (kReconnectRetries - retries)
    464           << " retries to find the updated device.";
    465 
    466   return true;
    467 }
    468 
    469 bool UsbTestGadgetImpl::Disconnect() {
    470   const GURL url("http://" + device_address_ + "/disconnect");
    471   const int response_code = SimplePOSTRequest(url, "");
    472 
    473   if (response_code != 200) {
    474     LOG(ERROR) << "Unexpected HTTP " << response_code << " from /disconnect.";
    475     return false;
    476   }
    477 
    478   // Release the old reference to the device and wait until it can't be found.
    479   int retries = kDisconnectRetries;
    480   while (true) {
    481     device_ = NULL;
    482     if (!FindClaimed()) {
    483       break;
    484     }
    485     if (--retries == 0) {
    486       LOG(ERROR) << "Device did not disconnect.";
    487       return false;
    488     }
    489     PlatformThread::Sleep(TimeDelta::FromMilliseconds(kRetryPeriod));
    490   }
    491   VLOG(1) << "It took " << (kDisconnectRetries - retries)
    492           << " retries for the device to disconnect.";
    493 
    494   return true;
    495 }
    496 
    497 bool UsbTestGadgetImpl::Reconnect() {
    498   const GURL url("http://" + device_address_ + "/reconnect");
    499   const int response_code = SimplePOSTRequest(url, "");
    500 
    501   if (response_code != 200) {
    502     LOG(ERROR) << "Unexpected HTTP " << response_code << " from /reconnect.";
    503     return false;
    504   }
    505 
    506   int retries = kDisconnectRetries;
    507   while (true) {
    508     if (FindClaimed()) {
    509       break;
    510     }
    511     if (--retries == 0) {
    512       LOG(ERROR) << "Device did not reconnect.";
    513       return false;
    514     }
    515     PlatformThread::Sleep(TimeDelta::FromMilliseconds(kRetryPeriod));
    516   }
    517   VLOG(1) << "It took " << (kDisconnectRetries - retries)
    518           << " retries for the device to reconnect.";
    519 
    520   return true;
    521 }
    522 
    523 }  // namespace device
    524