Home | History | Annotate | Download | only in buffet
      1 // Copyright 2015 The Android Open Source Project
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //      http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
     15 #include <arpa/inet.h>
     16 #include <map>
     17 #include <netdb.h>
     18 #include <string>
     19 #include <sys/socket.h>
     20 #include <sys/types.h>
     21 #include <unistd.h>
     22 
     23 #include <base/bind.h>
     24 #include <base/bind_helpers.h>
     25 #include <base/files/file_util.h>
     26 #include <base/message_loop/message_loop.h>
     27 #include <base/strings/stringprintf.h>
     28 #include <brillo/bind_lambda.h>
     29 #include <brillo/streams/file_stream.h>
     30 #include <brillo/streams/tls_stream.h>
     31 
     32 #include "buffet/socket_stream.h"
     33 #include "buffet/weave_error_conversion.h"
     34 
     35 namespace buffet {
     36 
     37 using weave::provider::Network;
     38 
     39 namespace {
     40 
     41 std::string GetIPAddress(const sockaddr* sa) {
     42   std::string addr;
     43   char str[INET6_ADDRSTRLEN] = {};
     44   switch (sa->sa_family) {
     45     case AF_INET:
     46       if (inet_ntop(AF_INET,
     47                     &(reinterpret_cast<const sockaddr_in*>(sa)->sin_addr), str,
     48                     sizeof(str))) {
     49         addr = str;
     50       }
     51       break;
     52 
     53     case AF_INET6:
     54       if (inet_ntop(AF_INET6,
     55                     &(reinterpret_cast<const sockaddr_in6*>(sa)->sin6_addr),
     56                     str, sizeof(str))) {
     57         addr = str;
     58       }
     59       break;
     60   }
     61   if (addr.empty())
     62     addr = base::StringPrintf("<Unknown address family: %d>", sa->sa_family);
     63   return addr;
     64 }
     65 
     66 int ConnectSocket(const std::string& host, uint16_t port) {
     67   std::string service = std::to_string(port);
     68   addrinfo hints = {0, AF_UNSPEC, SOCK_STREAM};
     69   addrinfo* result = nullptr;
     70   if (getaddrinfo(host.c_str(), service.c_str(), &hints, &result)) {
     71     PLOG(WARNING) << "Failed to resolve host name: " << host;
     72     return -1;
     73   }
     74 
     75   int socket_fd = -1;
     76   for (const addrinfo* info = result; info != nullptr; info = info->ai_next) {
     77     socket_fd = socket(info->ai_family, info->ai_socktype, info->ai_protocol);
     78     if (socket_fd < 0)
     79       continue;
     80 
     81     std::string addr = GetIPAddress(info->ai_addr);
     82     LOG(INFO) << "Connecting to address: " << addr;
     83     if (connect(socket_fd, info->ai_addr, info->ai_addrlen) == 0)
     84       break;  // Success.
     85 
     86     PLOG(WARNING) << "Failed to connect to address: " << addr;
     87     close(socket_fd);
     88     socket_fd = -1;
     89   }
     90 
     91   freeaddrinfo(result);
     92   return socket_fd;
     93 }
     94 
     95 void OnSuccess(const Network::OpenSslSocketCallback& callback,
     96                brillo::StreamPtr tls_stream) {
     97   callback.Run(
     98       std::unique_ptr<weave::Stream>{new SocketStream{std::move(tls_stream)}},
     99       nullptr);
    100 }
    101 
    102 void OnError(const weave::DoneCallback& callback,
    103              const brillo::Error* brillo_error) {
    104   weave::ErrorPtr error;
    105   ConvertError(*brillo_error, &error);
    106   callback.Run(std::move(error));
    107 }
    108 
    109 }  // namespace
    110 
    111 void SocketStream::Read(void* buffer,
    112                         size_t size_to_read,
    113                         const ReadCallback& callback) {
    114   brillo::ErrorPtr brillo_error;
    115   if (!ptr_->ReadAsync(
    116           buffer, size_to_read,
    117           base::Bind([](const ReadCallback& callback,
    118                         size_t size) { callback.Run(size, nullptr); },
    119                      callback),
    120           base::Bind(&OnError, base::Bind(callback, 0)), &brillo_error)) {
    121     weave::ErrorPtr error;
    122     ConvertError(*brillo_error, &error);
    123     base::MessageLoop::current()->PostTask(
    124         FROM_HERE, base::Bind(callback, 0, base::Passed(&error)));
    125   }
    126 }
    127 
    128 void SocketStream::Write(const void* buffer,
    129                          size_t size_to_write,
    130                          const WriteCallback& callback) {
    131   brillo::ErrorPtr brillo_error;
    132   if (!ptr_->WriteAllAsync(buffer, size_to_write, base::Bind(callback, nullptr),
    133                            base::Bind(&OnError, callback), &brillo_error)) {
    134     weave::ErrorPtr error;
    135     ConvertError(*brillo_error, &error);
    136     base::MessageLoop::current()->PostTask(
    137         FROM_HERE, base::Bind(callback, base::Passed(&error)));
    138   }
    139 }
    140 
    141 void SocketStream::CancelPendingOperations() {
    142   ptr_->CancelPendingAsyncOperations();
    143 }
    144 
    145 std::unique_ptr<weave::Stream> SocketStream::ConnectBlocking(
    146     const std::string& host,
    147     uint16_t port) {
    148   int socket_fd = ConnectSocket(host, port);
    149   if (socket_fd <= 0)
    150     return nullptr;
    151 
    152   auto ptr_ = brillo::FileStream::FromFileDescriptor(socket_fd, true, nullptr);
    153   if (ptr_)
    154     return std::unique_ptr<Stream>{new SocketStream{std::move(ptr_)}};
    155 
    156   close(socket_fd);
    157   return nullptr;
    158 }
    159 
    160 void SocketStream::TlsConnect(std::unique_ptr<Stream> socket,
    161                               const std::string& host,
    162                               const Network::OpenSslSocketCallback& callback) {
    163   SocketStream* stream = static_cast<SocketStream*>(socket.get());
    164   brillo::TlsStream::Connect(
    165       std::move(stream->ptr_), host, base::Bind(&OnSuccess, callback),
    166       base::Bind(&OnError, base::Bind(callback, nullptr)));
    167 }
    168 
    169 }  // namespace buffet
    170