Home | History | Annotate | Download | only in setup
      1 // Copyright 2013 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "remoting/host/setup/native_messaging_host.h"
      6 
      7 #include <string>
      8 
      9 #include "base/basictypes.h"
     10 #include "base/bind.h"
     11 #include "base/callback.h"
     12 #include "base/json/json_string_value_serializer.h"
     13 #include "base/location.h"
     14 #include "base/message_loop/message_loop.h"
     15 #include "base/run_loop.h"
     16 #include "base/strings/stringize_macros.h"
     17 #include "base/values.h"
     18 #include "net/base/net_util.h"
     19 #include "remoting/base/rsa_key_pair.h"
     20 #include "remoting/host/host_exit_codes.h"
     21 #include "remoting/host/pairing_registry_delegate.h"
     22 #include "remoting/host/pairing_registry_delegate.h"
     23 #include "remoting/host/pin_hash.h"
     24 #include "remoting/protocol/pairing_registry.h"
     25 
     26 #if defined(OS_POSIX)
     27 #include <unistd.h>
     28 #endif
     29 
     30 namespace {
     31 
     32 // Features supported in addition to the base protocol.
     33 const char* kSupportedFeatures[] = {
     34   "pairingRegistry",
     35 };
     36 
     37 // Helper to extract the "config" part of a message as a DictionaryValue.
     38 // Returns NULL on failure, and logs an error message.
     39 scoped_ptr<base::DictionaryValue> ConfigDictionaryFromMessage(
     40     const base::DictionaryValue& message) {
     41   scoped_ptr<base::DictionaryValue> result;
     42   const base::DictionaryValue* config_dict;
     43   if (message.GetDictionary("config", &config_dict)) {
     44     result.reset(config_dict->DeepCopy());
     45   } else {
     46     LOG(ERROR) << "'config' dictionary not found";
     47   }
     48   return result.Pass();
     49 }
     50 
     51 }  // namespace
     52 
     53 namespace remoting {
     54 
     55 NativeMessagingHost::NativeMessagingHost(
     56     scoped_ptr<DaemonController> daemon_controller,
     57     scoped_refptr<protocol::PairingRegistry> pairing_registry,
     58     base::PlatformFile input,
     59     base::PlatformFile output,
     60     scoped_refptr<base::SingleThreadTaskRunner> caller_task_runner,
     61     const base::Closure& quit_closure)
     62     : caller_task_runner_(caller_task_runner),
     63       quit_closure_(quit_closure),
     64       native_messaging_reader_(input),
     65       native_messaging_writer_(output),
     66       daemon_controller_(daemon_controller.Pass()),
     67       pairing_registry_(pairing_registry),
     68       weak_factory_(this) {
     69   weak_ptr_ = weak_factory_.GetWeakPtr();
     70 }
     71 
     72 NativeMessagingHost::~NativeMessagingHost() {}
     73 
     74 void NativeMessagingHost::Start() {
     75   DCHECK(caller_task_runner_->BelongsToCurrentThread());
     76 
     77   native_messaging_reader_.Start(
     78       base::Bind(&NativeMessagingHost::ProcessMessage, weak_ptr_),
     79       base::Bind(&NativeMessagingHost::Shutdown, weak_ptr_));
     80 }
     81 
     82 void NativeMessagingHost::Shutdown() {
     83   DCHECK(caller_task_runner_->BelongsToCurrentThread());
     84   if (!quit_closure_.is_null()) {
     85     caller_task_runner_->PostTask(FROM_HERE, quit_closure_);
     86     quit_closure_.Reset();
     87   }
     88 }
     89 
     90 void NativeMessagingHost::ProcessMessage(scoped_ptr<base::Value> message) {
     91   DCHECK(caller_task_runner_->BelongsToCurrentThread());
     92 
     93   // Don't process any more messages if Shutdown() has been called.
     94   if (quit_closure_.is_null())
     95     return;
     96 
     97   const base::DictionaryValue* message_dict;
     98   if (!message->GetAsDictionary(&message_dict)) {
     99     LOG(ERROR) << "Expected DictionaryValue";
    100     Shutdown();
    101     return;
    102   }
    103 
    104   scoped_ptr<base::DictionaryValue> response_dict(new base::DictionaryValue());
    105 
    106   // If the client supplies an ID, it will expect it in the response. This
    107   // might be a string or a number, so cope with both.
    108   const base::Value* id;
    109   if (message_dict->Get("id", &id))
    110     response_dict->Set("id", id->DeepCopy());
    111 
    112   std::string type;
    113   if (!message_dict->GetString("type", &type)) {
    114     LOG(ERROR) << "'type' not found";
    115     Shutdown();
    116     return;
    117   }
    118 
    119   response_dict->SetString("type", type + "Response");
    120 
    121   bool success = false;
    122   if (type == "hello") {
    123     success = ProcessHello(*message_dict, response_dict.Pass());
    124   } else if (type == "clearPairedClients") {
    125     success = ProcessClearPairedClients(*message_dict, response_dict.Pass());
    126   } else if (type == "deletePairedClient") {
    127     success = ProcessDeletePairedClient(*message_dict, response_dict.Pass());
    128   } else if (type == "getHostName") {
    129     success = ProcessGetHostName(*message_dict, response_dict.Pass());
    130   } else if (type == "getPinHash") {
    131     success = ProcessGetPinHash(*message_dict, response_dict.Pass());
    132   } else if (type == "generateKeyPair") {
    133     success = ProcessGenerateKeyPair(*message_dict, response_dict.Pass());
    134   } else if (type == "updateDaemonConfig") {
    135     success = ProcessUpdateDaemonConfig(*message_dict, response_dict.Pass());
    136   } else if (type == "getDaemonConfig") {
    137     success = ProcessGetDaemonConfig(*message_dict, response_dict.Pass());
    138   } else if (type == "getPairedClients") {
    139     success = ProcessGetPairedClients(*message_dict, response_dict.Pass());
    140   } else if (type == "getUsageStatsConsent") {
    141     success = ProcessGetUsageStatsConsent(*message_dict, response_dict.Pass());
    142   } else if (type == "startDaemon") {
    143     success = ProcessStartDaemon(*message_dict, response_dict.Pass());
    144   } else if (type == "stopDaemon") {
    145     success = ProcessStopDaemon(*message_dict, response_dict.Pass());
    146   } else if (type == "getDaemonState") {
    147     success = ProcessGetDaemonState(*message_dict, response_dict.Pass());
    148   } else {
    149     LOG(ERROR) << "Unsupported request type: " << type;
    150   }
    151 
    152   if (!success)
    153     Shutdown();
    154 }
    155 
    156 bool NativeMessagingHost::ProcessHello(
    157     const base::DictionaryValue& message,
    158     scoped_ptr<base::DictionaryValue> response) {
    159   response->SetString("version", STRINGIZE(VERSION));
    160   scoped_ptr<base::ListValue> supported_features_list(new base::ListValue());
    161   supported_features_list->AppendStrings(std::vector<std::string>(
    162       kSupportedFeatures, kSupportedFeatures + arraysize(kSupportedFeatures)));
    163   response->Set("supportedFeatures", supported_features_list.release());
    164   SendResponse(response.Pass());
    165   return true;
    166 }
    167 
    168 bool NativeMessagingHost::ProcessClearPairedClients(
    169     const base::DictionaryValue& message,
    170     scoped_ptr<base::DictionaryValue> response) {
    171   if (pairing_registry_) {
    172     pairing_registry_->ClearAllPairings(
    173         base::Bind(&NativeMessagingHost::SendBooleanResult, weak_ptr_,
    174                    base::Passed(&response)));
    175   } else {
    176     SendBooleanResult(response.Pass(), false);
    177   }
    178   return true;
    179 }
    180 
    181 bool NativeMessagingHost::ProcessDeletePairedClient(
    182     const base::DictionaryValue& message,
    183     scoped_ptr<base::DictionaryValue> response) {
    184   std::string client_id;
    185   if (!message.GetString(protocol::PairingRegistry::kClientIdKey, &client_id)) {
    186     LOG(ERROR) << "'" << protocol::PairingRegistry::kClientIdKey
    187                << "' string not found.";
    188     return false;
    189   }
    190 
    191   if (pairing_registry_) {
    192     pairing_registry_->DeletePairing(
    193         client_id, base::Bind(&NativeMessagingHost::SendBooleanResult,
    194                               weak_ptr_, base::Passed(&response)));
    195   } else {
    196     SendBooleanResult(response.Pass(), false);
    197   }
    198   return true;
    199 }
    200 
    201 bool NativeMessagingHost::ProcessGetHostName(
    202     const base::DictionaryValue& message,
    203     scoped_ptr<base::DictionaryValue> response) {
    204   response->SetString("hostname", net::GetHostName());
    205   SendResponse(response.Pass());
    206   return true;
    207 }
    208 
    209 bool NativeMessagingHost::ProcessGetPinHash(
    210     const base::DictionaryValue& message,
    211     scoped_ptr<base::DictionaryValue> response) {
    212   std::string host_id;
    213   if (!message.GetString("hostId", &host_id)) {
    214     LOG(ERROR) << "'hostId' not found: " << message;
    215     return false;
    216   }
    217   std::string pin;
    218   if (!message.GetString("pin", &pin)) {
    219     LOG(ERROR) << "'pin' not found: " << message;
    220     return false;
    221   }
    222   response->SetString("hash", remoting::MakeHostPinHash(host_id, pin));
    223   SendResponse(response.Pass());
    224   return true;
    225 }
    226 
    227 bool NativeMessagingHost::ProcessGenerateKeyPair(
    228     const base::DictionaryValue& message,
    229     scoped_ptr<base::DictionaryValue> response) {
    230   scoped_refptr<RsaKeyPair> key_pair = RsaKeyPair::Generate();
    231   response->SetString("privateKey", key_pair->ToString());
    232   response->SetString("publicKey", key_pair->GetPublicKey());
    233   SendResponse(response.Pass());
    234   return true;
    235 }
    236 
    237 bool NativeMessagingHost::ProcessUpdateDaemonConfig(
    238     const base::DictionaryValue& message,
    239     scoped_ptr<base::DictionaryValue> response) {
    240   scoped_ptr<base::DictionaryValue> config_dict =
    241       ConfigDictionaryFromMessage(message);
    242   if (!config_dict)
    243     return false;
    244 
    245   // base::Unretained() is safe because this object owns |daemon_controller_|
    246   // which owns the thread that will run the callback.
    247   daemon_controller_->UpdateConfig(
    248       config_dict.Pass(),
    249       base::Bind(&NativeMessagingHost::SendAsyncResult, base::Unretained(this),
    250                  base::Passed(&response)));
    251   return true;
    252 }
    253 
    254 bool NativeMessagingHost::ProcessGetDaemonConfig(
    255     const base::DictionaryValue& message,
    256     scoped_ptr<base::DictionaryValue> response) {
    257   daemon_controller_->GetConfig(
    258       base::Bind(&NativeMessagingHost::SendConfigResponse,
    259                  base::Unretained(this), base::Passed(&response)));
    260   return true;
    261 }
    262 
    263 bool NativeMessagingHost::ProcessGetPairedClients(
    264     const base::DictionaryValue& message,
    265     scoped_ptr<base::DictionaryValue> response) {
    266   if (pairing_registry_) {
    267     pairing_registry_->GetAllPairings(
    268         base::Bind(&NativeMessagingHost::SendPairedClientsResponse, weak_ptr_,
    269                    base::Passed(&response)));
    270   } else {
    271     scoped_ptr<base::ListValue> no_paired_clients(new base::ListValue);
    272     SendPairedClientsResponse(response.Pass(), no_paired_clients.Pass());
    273   }
    274   return true;
    275 }
    276 
    277 bool NativeMessagingHost::ProcessGetUsageStatsConsent(
    278     const base::DictionaryValue& message,
    279     scoped_ptr<base::DictionaryValue> response) {
    280   daemon_controller_->GetUsageStatsConsent(
    281       base::Bind(&NativeMessagingHost::SendUsageStatsConsentResponse,
    282                  base::Unretained(this), base::Passed(&response)));
    283   return true;
    284 }
    285 
    286 bool NativeMessagingHost::ProcessStartDaemon(
    287     const base::DictionaryValue& message,
    288     scoped_ptr<base::DictionaryValue> response) {
    289   bool consent;
    290   if (!message.GetBoolean("consent", &consent)) {
    291     LOG(ERROR) << "'consent' not found.";
    292     return false;
    293   }
    294 
    295   scoped_ptr<base::DictionaryValue> config_dict =
    296       ConfigDictionaryFromMessage(message);
    297   if (!config_dict)
    298     return false;
    299 
    300   daemon_controller_->SetConfigAndStart(
    301       config_dict.Pass(), consent,
    302       base::Bind(&NativeMessagingHost::SendAsyncResult, base::Unretained(this),
    303                  base::Passed(&response)));
    304   return true;
    305 }
    306 
    307 bool NativeMessagingHost::ProcessStopDaemon(
    308     const base::DictionaryValue& message,
    309     scoped_ptr<base::DictionaryValue> response) {
    310   daemon_controller_->Stop(
    311       base::Bind(&NativeMessagingHost::SendAsyncResult, base::Unretained(this),
    312                  base::Passed(&response)));
    313   return true;
    314 }
    315 
    316 bool NativeMessagingHost::ProcessGetDaemonState(
    317     const base::DictionaryValue& message,
    318     scoped_ptr<base::DictionaryValue> response) {
    319   DaemonController::State state = daemon_controller_->GetState();
    320   switch (state) {
    321     case DaemonController::STATE_NOT_IMPLEMENTED:
    322       response->SetString("state", "NOT_IMPLEMENTED");
    323       break;
    324     case DaemonController::STATE_NOT_INSTALLED:
    325       response->SetString("state", "NOT_INSTALLED");
    326       break;
    327     case DaemonController::STATE_INSTALLING:
    328       response->SetString("state", "INSTALLING");
    329       break;
    330     case DaemonController::STATE_STOPPED:
    331       response->SetString("state", "STOPPED");
    332       break;
    333     case DaemonController::STATE_STARTING:
    334       response->SetString("state", "STARTING");
    335       break;
    336     case DaemonController::STATE_STARTED:
    337       response->SetString("state", "STARTED");
    338       break;
    339     case DaemonController::STATE_STOPPING:
    340       response->SetString("state", "STOPPING");
    341       break;
    342     case DaemonController::STATE_UNKNOWN:
    343       response->SetString("state", "UNKNOWN");
    344       break;
    345   }
    346   SendResponse(response.Pass());
    347   return true;
    348 }
    349 
    350 void NativeMessagingHost::SendResponse(
    351     scoped_ptr<base::DictionaryValue> response) {
    352   if (!caller_task_runner_->BelongsToCurrentThread()) {
    353     caller_task_runner_->PostTask(
    354         FROM_HERE, base::Bind(&NativeMessagingHost::SendResponse, weak_ptr_,
    355                               base::Passed(&response)));
    356     return;
    357   }
    358 
    359   if (!native_messaging_writer_.WriteMessage(*response))
    360     Shutdown();
    361 }
    362 
    363 void NativeMessagingHost::SendConfigResponse(
    364     scoped_ptr<base::DictionaryValue> response,
    365     scoped_ptr<base::DictionaryValue> config) {
    366   if (config) {
    367     response->Set("config", config.release());
    368   } else {
    369     response->Set("config", Value::CreateNullValue());
    370   }
    371   SendResponse(response.Pass());
    372 }
    373 
    374 void NativeMessagingHost::SendPairedClientsResponse(
    375     scoped_ptr<base::DictionaryValue> response,
    376     scoped_ptr<base::ListValue> pairings) {
    377   response->Set("pairedClients", pairings.release());
    378   SendResponse(response.Pass());
    379 }
    380 
    381 void NativeMessagingHost::SendUsageStatsConsentResponse(
    382     scoped_ptr<base::DictionaryValue> response,
    383     bool supported,
    384     bool allowed,
    385     bool set_by_policy) {
    386   response->SetBoolean("supported", supported);
    387   response->SetBoolean("allowed", allowed);
    388   response->SetBoolean("setByPolicy", set_by_policy);
    389   SendResponse(response.Pass());
    390 }
    391 
    392 void NativeMessagingHost::SendAsyncResult(
    393     scoped_ptr<base::DictionaryValue> response,
    394     DaemonController::AsyncResult result) {
    395   switch (result) {
    396     case DaemonController::RESULT_OK:
    397       response->SetString("result", "OK");
    398       break;
    399     case DaemonController::RESULT_FAILED:
    400       response->SetString("result", "FAILED");
    401       break;
    402     case DaemonController::RESULT_CANCELLED:
    403       response->SetString("result", "CANCELLED");
    404       break;
    405     case DaemonController::RESULT_FAILED_DIRECTORY:
    406       response->SetString("result", "FAILED_DIRECTORY");
    407       break;
    408   }
    409   SendResponse(response.Pass());
    410 }
    411 
    412 void NativeMessagingHost::SendBooleanResult(
    413     scoped_ptr<base::DictionaryValue> response, bool result) {
    414   response->SetBoolean("result", result);
    415   SendResponse(response.Pass());
    416 }
    417 
    418 int NativeMessagingHostMain() {
    419 #if defined(OS_WIN)
    420   base::PlatformFile read_file = GetStdHandle(STD_INPUT_HANDLE);
    421   base::PlatformFile write_file = GetStdHandle(STD_OUTPUT_HANDLE);
    422 #elif defined(OS_POSIX)
    423   base::PlatformFile read_file = STDIN_FILENO;
    424   base::PlatformFile write_file = STDOUT_FILENO;
    425 #else
    426 #error Not implemented.
    427 #endif
    428 
    429   base::MessageLoop message_loop(base::MessageLoop::TYPE_IO);
    430   base::RunLoop run_loop;
    431   scoped_refptr<protocol::PairingRegistry> pairing_registry =
    432       CreatePairingRegistry(message_loop.message_loop_proxy());
    433   remoting::NativeMessagingHost host(remoting::DaemonController::Create(),
    434                                      pairing_registry,
    435                                      read_file, write_file,
    436                                      message_loop.message_loop_proxy(),
    437                                      run_loop.QuitClosure());
    438   host.Start();
    439   run_loop.Run();
    440   return kSuccessExitCode;
    441 }
    442 
    443 }  // namespace remoting
    444