1 // Copyright 2015 The Android Open Source Project 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #include <arpa/inet.h> 16 #include <map> 17 #include <netdb.h> 18 #include <string> 19 #include <sys/socket.h> 20 #include <sys/types.h> 21 #include <unistd.h> 22 23 #include <base/bind.h> 24 #include <base/bind_helpers.h> 25 #include <base/files/file_util.h> 26 #include <base/message_loop/message_loop.h> 27 #include <base/strings/stringprintf.h> 28 #include <brillo/bind_lambda.h> 29 #include <brillo/streams/file_stream.h> 30 #include <brillo/streams/tls_stream.h> 31 32 #include "buffet/socket_stream.h" 33 #include "buffet/weave_error_conversion.h" 34 35 namespace buffet { 36 37 using weave::provider::Network; 38 39 namespace { 40 41 std::string GetIPAddress(const sockaddr* sa) { 42 std::string addr; 43 char str[INET6_ADDRSTRLEN] = {}; 44 switch (sa->sa_family) { 45 case AF_INET: 46 if (inet_ntop(AF_INET, 47 &(reinterpret_cast<const sockaddr_in*>(sa)->sin_addr), str, 48 sizeof(str))) { 49 addr = str; 50 } 51 break; 52 53 case AF_INET6: 54 if (inet_ntop(AF_INET6, 55 &(reinterpret_cast<const sockaddr_in6*>(sa)->sin6_addr), 56 str, sizeof(str))) { 57 addr = str; 58 } 59 break; 60 } 61 if (addr.empty()) 62 addr = base::StringPrintf("<Unknown address family: %d>", sa->sa_family); 63 return addr; 64 } 65 66 int ConnectSocket(const std::string& host, uint16_t port) { 67 std::string service = std::to_string(port); 68 addrinfo hints = {0, AF_UNSPEC, SOCK_STREAM}; 69 addrinfo* result = nullptr; 70 if (getaddrinfo(host.c_str(), service.c_str(), &hints, &result)) { 71 PLOG(WARNING) << "Failed to resolve host name: " << host; 72 return -1; 73 } 74 75 int socket_fd = -1; 76 for (const addrinfo* info = result; info != nullptr; info = info->ai_next) { 77 socket_fd = socket(info->ai_family, info->ai_socktype, info->ai_protocol); 78 if (socket_fd < 0) 79 continue; 80 81 std::string addr = GetIPAddress(info->ai_addr); 82 LOG(INFO) << "Connecting to address: " << addr; 83 if (connect(socket_fd, info->ai_addr, info->ai_addrlen) == 0) 84 break; // Success. 85 86 PLOG(WARNING) << "Failed to connect to address: " << addr; 87 close(socket_fd); 88 socket_fd = -1; 89 } 90 91 freeaddrinfo(result); 92 return socket_fd; 93 } 94 95 void OnSuccess(const Network::OpenSslSocketCallback& callback, 96 brillo::StreamPtr tls_stream) { 97 callback.Run( 98 std::unique_ptr<weave::Stream>{new SocketStream{std::move(tls_stream)}}, 99 nullptr); 100 } 101 102 void OnError(const weave::DoneCallback& callback, 103 const brillo::Error* brillo_error) { 104 weave::ErrorPtr error; 105 ConvertError(*brillo_error, &error); 106 callback.Run(std::move(error)); 107 } 108 109 } // namespace 110 111 void SocketStream::Read(void* buffer, 112 size_t size_to_read, 113 const ReadCallback& callback) { 114 brillo::ErrorPtr brillo_error; 115 if (!ptr_->ReadAsync( 116 buffer, size_to_read, 117 base::Bind([](const ReadCallback& callback, 118 size_t size) { callback.Run(size, nullptr); }, 119 callback), 120 base::Bind(&OnError, base::Bind(callback, 0)), &brillo_error)) { 121 weave::ErrorPtr error; 122 ConvertError(*brillo_error, &error); 123 base::MessageLoop::current()->PostTask( 124 FROM_HERE, base::Bind(callback, 0, base::Passed(&error))); 125 } 126 } 127 128 void SocketStream::Write(const void* buffer, 129 size_t size_to_write, 130 const WriteCallback& callback) { 131 brillo::ErrorPtr brillo_error; 132 if (!ptr_->WriteAllAsync(buffer, size_to_write, base::Bind(callback, nullptr), 133 base::Bind(&OnError, callback), &brillo_error)) { 134 weave::ErrorPtr error; 135 ConvertError(*brillo_error, &error); 136 base::MessageLoop::current()->PostTask( 137 FROM_HERE, base::Bind(callback, base::Passed(&error))); 138 } 139 } 140 141 void SocketStream::CancelPendingOperations() { 142 ptr_->CancelPendingAsyncOperations(); 143 } 144 145 std::unique_ptr<weave::Stream> SocketStream::ConnectBlocking( 146 const std::string& host, 147 uint16_t port) { 148 int socket_fd = ConnectSocket(host, port); 149 if (socket_fd <= 0) 150 return nullptr; 151 152 auto ptr_ = brillo::FileStream::FromFileDescriptor(socket_fd, true, nullptr); 153 if (ptr_) 154 return std::unique_ptr<Stream>{new SocketStream{std::move(ptr_)}}; 155 156 close(socket_fd); 157 return nullptr; 158 } 159 160 void SocketStream::TlsConnect(std::unique_ptr<Stream> socket, 161 const std::string& host, 162 const Network::OpenSslSocketCallback& callback) { 163 SocketStream* stream = static_cast<SocketStream*>(socket.get()); 164 brillo::TlsStream::Connect( 165 std::move(stream->ptr_), host, base::Bind(&OnSuccess, callback), 166 base::Bind(&OnError, base::Bind(callback, nullptr))); 167 } 168 169 } // namespace buffet 170