Home | History | Annotate | Download | only in shill
      1 //
      2 // Copyright (C) 2012 The Android Open Source Project
      3 //
      4 // Licensed under the Apache License, Version 2.0 (the "License");
      5 // you may not use this file except in compliance with the License.
      6 // You may obtain a copy of the License at
      7 //
      8 //      http://www.apache.org/licenses/LICENSE-2.0
      9 //
     10 // Unless required by applicable law or agreed to in writing, software
     11 // distributed under the License is distributed on an "AS IS" BASIS,
     12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13 // See the License for the specific language governing permissions and
     14 // limitations under the License.
     15 //
     16 
     17 #include "shill/connectivity_trial.h"
     18 
     19 #include <string>
     20 
     21 #include <base/bind.h>
     22 #include <base/strings/pattern.h>
     23 #include <base/strings/string_number_conversions.h>
     24 #include <base/strings/string_util.h>
     25 #include <base/strings/stringprintf.h>
     26 #if defined(__ANDROID__)
     27 #include <dbus/service_constants.h>
     28 #else
     29 #include <chromeos/dbus/service_constants.h>
     30 #endif  // __ANDROID__
     31 
     32 #include "shill/async_connection.h"
     33 #include "shill/connection.h"
     34 #include "shill/dns_client.h"
     35 #include "shill/event_dispatcher.h"
     36 #include "shill/http_request.h"
     37 #include "shill/http_url.h"
     38 #include "shill/logging.h"
     39 #include "shill/net/ip_address.h"
     40 #include "shill/net/sockets.h"
     41 
     42 using base::Bind;
     43 using base::Callback;
     44 using base::StringPrintf;
     45 using std::string;
     46 
     47 namespace shill {
     48 
     49 namespace Logging {
     50 static auto kModuleLogScope = ScopeLogger::kPortal;
     51 static string ObjectID(Connection* c) { return c->interface_name(); }
     52 }
     53 
     54 const char ConnectivityTrial::kDefaultURL[] =
     55     "http://www.gstatic.com/generate_204";
     56 const char ConnectivityTrial::kResponseExpected[] = "HTTP/?.? 204";
     57 
     58 ConnectivityTrial::ConnectivityTrial(
     59     ConnectionRefPtr connection,
     60     EventDispatcher* dispatcher,
     61     int trial_timeout_seconds,
     62     const Callback<void(Result)>& callback)
     63     : connection_(connection),
     64       dispatcher_(dispatcher),
     65       trial_timeout_seconds_(trial_timeout_seconds),
     66       trial_callback_(callback),
     67       weak_ptr_factory_(this),
     68       request_read_callback_(
     69           Bind(&ConnectivityTrial::RequestReadCallback,
     70                weak_ptr_factory_.GetWeakPtr())),
     71       request_result_callback_(
     72           Bind(&ConnectivityTrial::RequestResultCallback,
     73                weak_ptr_factory_.GetWeakPtr())),
     74       is_active_(false) { }
     75 
     76 ConnectivityTrial::~ConnectivityTrial() {
     77   Stop();
     78 }
     79 
     80 bool ConnectivityTrial::Retry(int start_delay_milliseconds) {
     81   SLOG(connection_.get(), 3) << "In " << __func__;
     82   if (request_.get())
     83     CleanupTrial(false);
     84   else
     85     return false;
     86   StartTrialAfterDelay(start_delay_milliseconds);
     87   return true;
     88 }
     89 
     90 bool ConnectivityTrial::Start(const string& url_string,
     91                               int start_delay_milliseconds) {
     92   SLOG(connection_.get(), 3) << "In " << __func__;
     93 
     94   if (!url_.ParseFromString(url_string)) {
     95     LOG(ERROR) << "Failed to parse URL string: " << url_string;
     96     return false;
     97   }
     98   if (request_.get()) {
     99     CleanupTrial(false);
    100   } else {
    101     request_.reset(new HTTPRequest(connection_, dispatcher_, &sockets_));
    102   }
    103   StartTrialAfterDelay(start_delay_milliseconds);
    104   return true;
    105 }
    106 
    107 void ConnectivityTrial::Stop() {
    108   SLOG(connection_.get(), 3) << "In " << __func__;
    109 
    110   if (!request_.get()) {
    111     return;
    112   }
    113 
    114   CleanupTrial(true);
    115 }
    116 
    117 void ConnectivityTrial::StartTrialAfterDelay(int start_delay_milliseconds) {
    118   SLOG(connection_.get(), 4) << "In " << __func__
    119                              << " delay = " << start_delay_milliseconds
    120                              << "ms.";
    121   trial_.Reset(Bind(&ConnectivityTrial::StartTrialTask,
    122                     weak_ptr_factory_.GetWeakPtr()));
    123   dispatcher_->PostDelayedTask(trial_.callback(), start_delay_milliseconds);
    124 }
    125 
    126 void ConnectivityTrial::StartTrialTask() {
    127   HTTPRequest::Result result =
    128       request_->Start(url_, request_read_callback_, request_result_callback_);
    129   if (result != HTTPRequest::kResultInProgress) {
    130     CompleteTrial(ConnectivityTrial::GetPortalResultForRequestResult(result));
    131     return;
    132   }
    133   is_active_ = true;
    134 
    135   trial_timeout_.Reset(Bind(&ConnectivityTrial::TimeoutTrialTask,
    136                             weak_ptr_factory_.GetWeakPtr()));
    137   dispatcher_->PostDelayedTask(trial_timeout_.callback(),
    138                                trial_timeout_seconds_ * 1000);
    139 }
    140 
    141 bool ConnectivityTrial::IsActive() {
    142   return is_active_;
    143 }
    144 
    145 void ConnectivityTrial::RequestReadCallback(const ByteString& response_data) {
    146   const string response_expected(kResponseExpected);
    147   bool expected_length_received = false;
    148   int compare_length = 0;
    149   if (response_data.GetLength() < response_expected.length()) {
    150     // There isn't enough data yet for a final decision, but we can still
    151     // test to see if the partial string matches so far.
    152     expected_length_received = false;
    153     compare_length = response_data.GetLength();
    154   } else {
    155     expected_length_received = true;
    156     compare_length = response_expected.length();
    157   }
    158 
    159   if (base::MatchPattern(
    160           string(reinterpret_cast<const char*>(response_data.GetConstData()),
    161                  compare_length),
    162           response_expected.substr(0, compare_length))) {
    163     if (expected_length_received) {
    164       CompleteTrial(Result(kPhaseContent, kStatusSuccess));
    165     }
    166     // Otherwise, we wait for more data from the server.
    167   } else {
    168     CompleteTrial(Result(kPhaseContent, kStatusFailure));
    169   }
    170 }
    171 
    172 void ConnectivityTrial::RequestResultCallback(
    173     HTTPRequest::Result result, const ByteString& /*response_data*/) {
    174   CompleteTrial(GetPortalResultForRequestResult(result));
    175 }
    176 
    177 void ConnectivityTrial::CompleteTrial(Result result) {
    178   SLOG(connection_.get(), 3)
    179       << StringPrintf("Connectivity Trial completed with phase==%s, status==%s",
    180                       PhaseToString(result.phase).c_str(),
    181                       StatusToString(result.status).c_str());
    182   CleanupTrial(false);
    183   trial_callback_.Run(result);
    184 }
    185 
    186 void ConnectivityTrial::CleanupTrial(bool reset_request) {
    187   trial_timeout_.Cancel();
    188 
    189   if (request_.get())
    190     request_->Stop();
    191 
    192   is_active_ = false;
    193 
    194   if (!reset_request || !request_.get())
    195     return;
    196 
    197   request_.reset();
    198 }
    199 
    200 void ConnectivityTrial::TimeoutTrialTask() {
    201   LOG(ERROR) << "Connectivity Trial - Request timed out";
    202   if (request_->response_data().GetLength()) {
    203     CompleteTrial(ConnectivityTrial::Result(ConnectivityTrial::kPhaseContent,
    204                                             ConnectivityTrial::kStatusTimeout));
    205   } else {
    206     CompleteTrial(ConnectivityTrial::Result(ConnectivityTrial::kPhaseUnknown,
    207                                             ConnectivityTrial::kStatusTimeout));
    208   }
    209 }
    210 
    211 // statiic
    212 const string ConnectivityTrial::PhaseToString(Phase phase) {
    213   switch (phase) {
    214     case kPhaseConnection:
    215       return kPortalDetectionPhaseConnection;
    216     case kPhaseDNS:
    217       return kPortalDetectionPhaseDns;
    218     case kPhaseHTTP:
    219       return kPortalDetectionPhaseHttp;
    220     case kPhaseContent:
    221       return kPortalDetectionPhaseContent;
    222     case kPhaseUnknown:
    223     default:
    224       return kPortalDetectionPhaseUnknown;
    225   }
    226 }
    227 
    228 // static
    229 const string ConnectivityTrial::StatusToString(Status status) {
    230   switch (status) {
    231     case kStatusSuccess:
    232       return kPortalDetectionStatusSuccess;
    233     case kStatusTimeout:
    234       return kPortalDetectionStatusTimeout;
    235     case kStatusFailure:
    236     default:
    237       return kPortalDetectionStatusFailure;
    238   }
    239 }
    240 
    241 ConnectivityTrial::Result ConnectivityTrial::GetPortalResultForRequestResult(
    242     HTTPRequest::Result result) {
    243   switch (result) {
    244     case HTTPRequest::kResultSuccess:
    245       // The request completed without receiving the expected payload.
    246       return Result(kPhaseContent, kStatusFailure);
    247     case HTTPRequest::kResultDNSFailure:
    248       return Result(kPhaseDNS, kStatusFailure);
    249     case HTTPRequest::kResultDNSTimeout:
    250       return Result(kPhaseDNS, kStatusTimeout);
    251     case HTTPRequest::kResultConnectionFailure:
    252       return Result(kPhaseConnection, kStatusFailure);
    253     case HTTPRequest::kResultConnectionTimeout:
    254       return Result(kPhaseConnection, kStatusTimeout);
    255     case HTTPRequest::kResultRequestFailure:
    256     case HTTPRequest::kResultResponseFailure:
    257       return Result(kPhaseHTTP, kStatusFailure);
    258     case HTTPRequest::kResultRequestTimeout:
    259     case HTTPRequest::kResultResponseTimeout:
    260       return Result(kPhaseHTTP, kStatusTimeout);
    261     case HTTPRequest::kResultUnknown:
    262     default:
    263       return Result(kPhaseUnknown, kStatusFailure);
    264   }
    265 }
    266 
    267 }  // namespace shill
    268