Home | History | Annotate | Download | only in host
      1 // Copyright 2014 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "remoting/host/gnubby_auth_handler_posix.h"
      6 
      7 #include <unistd.h>
      8 #include <utility>
      9 
     10 #include "base/bind.h"
     11 #include "base/files/file_util.h"
     12 #include "base/json/json_reader.h"
     13 #include "base/json/json_writer.h"
     14 #include "base/lazy_instance.h"
     15 #include "base/stl_util.h"
     16 #include "base/values.h"
     17 #include "net/socket/unix_domain_listen_socket_posix.h"
     18 #include "remoting/base/logging.h"
     19 #include "remoting/host/gnubby_socket.h"
     20 #include "remoting/proto/control.pb.h"
     21 #include "remoting/protocol/client_stub.h"
     22 
     23 namespace remoting {
     24 
     25 namespace {
     26 
     27 const char kConnectionId[] = "connectionId";
     28 const char kControlMessage[] = "control";
     29 const char kControlOption[] = "option";
     30 const char kDataMessage[] = "data";
     31 const char kDataPayload[] = "data";
     32 const char kErrorMessage[] = "error";
     33 const char kGnubbyAuthMessage[] = "gnubby-auth";
     34 const char kGnubbyAuthV1[] = "auth-v1";
     35 const char kMessageType[] = "type";
     36 
     37 // The name of the socket to listen for gnubby requests on.
     38 base::LazyInstance<base::FilePath>::Leaky g_gnubby_socket_name =
     39     LAZY_INSTANCE_INITIALIZER;
     40 
     41 // STL predicate to match by a StreamListenSocket pointer.
     42 class CompareSocket {
     43  public:
     44   explicit CompareSocket(net::StreamListenSocket* socket) : socket_(socket) {}
     45 
     46   bool operator()(const std::pair<int, GnubbySocket*> element) const {
     47     return element.second->IsSocket(socket_);
     48   }
     49 
     50  private:
     51   net::StreamListenSocket* socket_;
     52 };
     53 
     54 // Socket authentication function that only allows connections from callers with
     55 // the current uid.
     56 bool MatchUid(const net::UnixDomainServerSocket::Credentials& credentials) {
     57   bool allowed = credentials.user_id == getuid();
     58   if (!allowed)
     59     HOST_LOG << "Refused socket connection from uid " << credentials.user_id;
     60   return allowed;
     61 }
     62 
     63 // Returns the command code (the first byte of the data) if it exists, or -1 if
     64 // the data is empty.
     65 unsigned int GetCommandCode(const std::string& data) {
     66   return data.empty() ? -1 : static_cast<unsigned int>(data[0]);
     67 }
     68 
     69 // Creates a string of byte data from a ListValue of numbers. Returns true if
     70 // all of the list elements are numbers.
     71 bool ConvertListValueToString(base::ListValue* bytes, std::string* out) {
     72   out->clear();
     73 
     74   unsigned int byte_count = bytes->GetSize();
     75   if (byte_count != 0) {
     76     out->reserve(byte_count);
     77     for (unsigned int i = 0; i < byte_count; i++) {
     78       int value;
     79       if (!bytes->GetInteger(i, &value))
     80         return false;
     81       out->push_back(static_cast<char>(value));
     82     }
     83   }
     84   return true;
     85 }
     86 
     87 }  // namespace
     88 
     89 GnubbyAuthHandlerPosix::GnubbyAuthHandlerPosix(
     90     protocol::ClientStub* client_stub)
     91     : client_stub_(client_stub), last_connection_id_(0) {
     92   DCHECK(client_stub_);
     93 }
     94 
     95 GnubbyAuthHandlerPosix::~GnubbyAuthHandlerPosix() {
     96   STLDeleteValues(&active_sockets_);
     97 }
     98 
     99 // static
    100 scoped_ptr<GnubbyAuthHandler> GnubbyAuthHandler::Create(
    101     protocol::ClientStub* client_stub) {
    102   return scoped_ptr<GnubbyAuthHandler>(new GnubbyAuthHandlerPosix(client_stub));
    103 }
    104 
    105 // static
    106 void GnubbyAuthHandler::SetGnubbySocketName(
    107     const base::FilePath& gnubby_socket_name) {
    108   g_gnubby_socket_name.Get() = gnubby_socket_name;
    109 }
    110 
    111 void GnubbyAuthHandlerPosix::DeliverClientMessage(const std::string& message) {
    112   DCHECK(CalledOnValidThread());
    113 
    114   scoped_ptr<base::Value> value(base::JSONReader::Read(message));
    115   base::DictionaryValue* client_message;
    116   if (value && value->GetAsDictionary(&client_message)) {
    117     std::string type;
    118     if (!client_message->GetString(kMessageType, &type)) {
    119       LOG(ERROR) << "Invalid gnubby-auth message";
    120       return;
    121     }
    122 
    123     if (type == kControlMessage) {
    124       std::string option;
    125       if (client_message->GetString(kControlOption, &option) &&
    126           option == kGnubbyAuthV1) {
    127         CreateAuthorizationSocket();
    128       } else {
    129         LOG(ERROR) << "Invalid gnubby-auth control option";
    130       }
    131     } else if (type == kDataMessage) {
    132       ActiveSockets::iterator iter = GetSocketForMessage(client_message);
    133       if (iter != active_sockets_.end()) {
    134         base::ListValue* bytes;
    135         std::string response;
    136         if (client_message->GetList(kDataPayload, &bytes) &&
    137             ConvertListValueToString(bytes, &response)) {
    138           HOST_LOG << "Sending gnubby response: " << GetCommandCode(response);
    139           iter->second->SendResponse(response);
    140         } else {
    141           LOG(ERROR) << "Invalid gnubby data";
    142           SendErrorAndCloseActiveSocket(iter);
    143         }
    144       } else {
    145         LOG(ERROR) << "Unknown gnubby-auth data connection";
    146       }
    147     } else if (type == kErrorMessage) {
    148       ActiveSockets::iterator iter = GetSocketForMessage(client_message);
    149       if (iter != active_sockets_.end()) {
    150         HOST_LOG << "Sending gnubby error";
    151         SendErrorAndCloseActiveSocket(iter);
    152       } else {
    153         LOG(ERROR) << "Unknown gnubby-auth error connection";
    154       }
    155     } else {
    156       LOG(ERROR) << "Unknown gnubby-auth message type: " << type;
    157     }
    158   }
    159 }
    160 
    161 void GnubbyAuthHandlerPosix::DeliverHostDataMessage(
    162     int connection_id,
    163     const std::string& data) const {
    164   DCHECK(CalledOnValidThread());
    165 
    166   base::DictionaryValue request;
    167   request.SetString(kMessageType, kDataMessage);
    168   request.SetInteger(kConnectionId, connection_id);
    169 
    170   base::ListValue* bytes = new base::ListValue();
    171   for (std::string::const_iterator i = data.begin(); i != data.end(); ++i) {
    172     bytes->AppendInteger(static_cast<unsigned char>(*i));
    173   }
    174   request.Set(kDataPayload, bytes);
    175 
    176   std::string request_json;
    177   if (!base::JSONWriter::Write(&request, &request_json)) {
    178     LOG(ERROR) << "Failed to create request json";
    179     return;
    180   }
    181 
    182   protocol::ExtensionMessage message;
    183   message.set_type(kGnubbyAuthMessage);
    184   message.set_data(request_json);
    185 
    186   client_stub_->DeliverHostMessage(message);
    187 }
    188 
    189 bool GnubbyAuthHandlerPosix::HasActiveSocketForTesting(
    190     net::StreamListenSocket* socket) const {
    191   return std::find_if(active_sockets_.begin(),
    192                       active_sockets_.end(),
    193                       CompareSocket(socket)) != active_sockets_.end();
    194 }
    195 
    196 int GnubbyAuthHandlerPosix::GetConnectionIdForTesting(
    197     net::StreamListenSocket* socket) const {
    198   ActiveSockets::const_iterator iter = std::find_if(
    199       active_sockets_.begin(), active_sockets_.end(), CompareSocket(socket));
    200   return iter->first;
    201 }
    202 
    203 GnubbySocket* GnubbyAuthHandlerPosix::GetGnubbySocketForTesting(
    204     net::StreamListenSocket* socket) const {
    205   ActiveSockets::const_iterator iter = std::find_if(
    206       active_sockets_.begin(), active_sockets_.end(), CompareSocket(socket));
    207   return iter->second;
    208 }
    209 
    210 void GnubbyAuthHandlerPosix::DidAccept(
    211     net::StreamListenSocket* server,
    212     scoped_ptr<net::StreamListenSocket> socket) {
    213   DCHECK(CalledOnValidThread());
    214 
    215   int connection_id = ++last_connection_id_;
    216   active_sockets_[connection_id] =
    217       new GnubbySocket(socket.Pass(),
    218                        base::Bind(&GnubbyAuthHandlerPosix::RequestTimedOut,
    219                                   base::Unretained(this),
    220                                   connection_id));
    221 }
    222 
    223 void GnubbyAuthHandlerPosix::DidRead(net::StreamListenSocket* socket,
    224                                      const char* data,
    225                                      int len) {
    226   DCHECK(CalledOnValidThread());
    227 
    228   ActiveSockets::iterator iter = std::find_if(
    229       active_sockets_.begin(), active_sockets_.end(), CompareSocket(socket));
    230   if (iter != active_sockets_.end()) {
    231     GnubbySocket* gnubby_socket = iter->second;
    232     gnubby_socket->AddRequestData(data, len);
    233     if (gnubby_socket->IsRequestTooLarge()) {
    234       SendErrorAndCloseActiveSocket(iter);
    235     } else if (gnubby_socket->IsRequestComplete()) {
    236       std::string request_data;
    237       gnubby_socket->GetAndClearRequestData(&request_data);
    238       ProcessGnubbyRequest(iter->first, request_data);
    239     }
    240   } else {
    241     LOG(ERROR) << "Received data for unknown connection";
    242   }
    243 }
    244 
    245 void GnubbyAuthHandlerPosix::DidClose(net::StreamListenSocket* socket) {
    246   DCHECK(CalledOnValidThread());
    247 
    248   ActiveSockets::iterator iter = std::find_if(
    249       active_sockets_.begin(), active_sockets_.end(), CompareSocket(socket));
    250   if (iter != active_sockets_.end()) {
    251     delete iter->second;
    252     active_sockets_.erase(iter);
    253   }
    254 }
    255 
    256 void GnubbyAuthHandlerPosix::CreateAuthorizationSocket() {
    257   DCHECK(CalledOnValidThread());
    258 
    259   if (!g_gnubby_socket_name.Get().empty()) {
    260     // If the file already exists, a socket in use error is returned.
    261     base::DeleteFile(g_gnubby_socket_name.Get(), false);
    262 
    263     HOST_LOG << "Listening for gnubby requests on "
    264              << g_gnubby_socket_name.Get().value();
    265 
    266     auth_socket_ = net::deprecated::UnixDomainListenSocket::CreateAndListen(
    267         g_gnubby_socket_name.Get().value(), this, base::Bind(MatchUid));
    268     if (!auth_socket_.get()) {
    269       LOG(ERROR) << "Failed to open socket for gnubby requests";
    270     }
    271   } else {
    272     HOST_LOG << "No gnubby socket name specified";
    273   }
    274 }
    275 
    276 void GnubbyAuthHandlerPosix::ProcessGnubbyRequest(
    277     int connection_id,
    278     const std::string& request_data) {
    279   HOST_LOG << "Received gnubby request: " << GetCommandCode(request_data);
    280   DeliverHostDataMessage(connection_id, request_data);
    281 }
    282 
    283 GnubbyAuthHandlerPosix::ActiveSockets::iterator
    284 GnubbyAuthHandlerPosix::GetSocketForMessage(base::DictionaryValue* message) {
    285   int connection_id;
    286   if (message->GetInteger(kConnectionId, &connection_id)) {
    287     return active_sockets_.find(connection_id);
    288   }
    289   return active_sockets_.end();
    290 }
    291 
    292 void GnubbyAuthHandlerPosix::SendErrorAndCloseActiveSocket(
    293     const ActiveSockets::iterator& iter) {
    294   iter->second->SendSshError();
    295 
    296   delete iter->second;
    297   active_sockets_.erase(iter);
    298 }
    299 
    300 void GnubbyAuthHandlerPosix::RequestTimedOut(int connection_id) {
    301   HOST_LOG << "Gnubby request timed out";
    302   ActiveSockets::iterator iter = active_sockets_.find(connection_id);
    303   if (iter != active_sockets_.end())
    304     SendErrorAndCloseActiveSocket(iter);
    305 }
    306 
    307 }  // namespace remoting
    308