1 /* Copyright (c) 2014, Google Inc. 2 * 3 * Permission to use, copy, modify, and/or distribute this software for any 4 * purpose with or without fee is hereby granted, provided that the above 5 * copyright notice and this permission notice appear in all copies. 6 * 7 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 8 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 9 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY 10 * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 11 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION 12 * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN 13 * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */ 14 15 #include <openssl/base.h> 16 17 #include <string> 18 #include <vector> 19 20 #include <errno.h> 21 #include <stddef.h> 22 #include <stdlib.h> 23 #include <string.h> 24 #include <sys/types.h> 25 26 #if !defined(OPENSSL_WINDOWS) 27 #include <arpa/inet.h> 28 #include <fcntl.h> 29 #include <netdb.h> 30 #include <netinet/in.h> 31 #include <sys/select.h> 32 #include <sys/socket.h> 33 #include <unistd.h> 34 #else 35 #include <io.h> 36 #pragma warning(push, 3) 37 #include <winsock2.h> 38 #include <ws2tcpip.h> 39 #pragma warning(pop) 40 41 typedef int ssize_t; 42 #pragma comment(lib, "Ws2_32.lib") 43 #endif 44 45 #include <openssl/err.h> 46 #include <openssl/ssl.h> 47 48 #include "internal.h" 49 50 51 #if !defined(OPENSSL_WINDOWS) 52 static int closesocket(int sock) { 53 return close(sock); 54 } 55 #endif 56 57 bool InitSocketLibrary() { 58 #if defined(OPENSSL_WINDOWS) 59 WSADATA wsaData; 60 int err = WSAStartup(MAKEWORD(2, 2), &wsaData); 61 if (err != 0) { 62 fprintf(stderr, "WSAStartup failed with error %d\n", err); 63 return false; 64 } 65 #endif 66 return true; 67 } 68 69 // Connect sets |*out_sock| to be a socket connected to the destination given 70 // in |hostname_and_port|, which should be of the form "www.example.com:123". 71 // It returns true on success and false otherwise. 72 bool Connect(int *out_sock, const std::string &hostname_and_port) { 73 const size_t colon_offset = hostname_and_port.find_last_of(':'); 74 std::string hostname, port; 75 76 if (colon_offset == std::string::npos) { 77 hostname = hostname_and_port; 78 port = "443"; 79 } else { 80 hostname = hostname_and_port.substr(0, colon_offset); 81 port = hostname_and_port.substr(colon_offset + 1); 82 } 83 84 struct addrinfo hint, *result; 85 memset(&hint, 0, sizeof(hint)); 86 hint.ai_family = AF_UNSPEC; 87 hint.ai_socktype = SOCK_STREAM; 88 89 int ret = getaddrinfo(hostname.c_str(), port.c_str(), &hint, &result); 90 if (ret != 0) { 91 fprintf(stderr, "getaddrinfo returned: %s\n", gai_strerror(ret)); 92 return false; 93 } 94 95 bool ok = false; 96 char buf[256]; 97 98 *out_sock = 99 socket(result->ai_family, result->ai_socktype, result->ai_protocol); 100 if (*out_sock < 0) { 101 perror("socket"); 102 goto out; 103 } 104 105 switch (result->ai_family) { 106 case AF_INET: { 107 struct sockaddr_in *sin = 108 reinterpret_cast<struct sockaddr_in *>(result->ai_addr); 109 fprintf(stderr, "Connecting to %s:%d\n", 110 inet_ntop(result->ai_family, &sin->sin_addr, buf, sizeof(buf)), 111 ntohs(sin->sin_port)); 112 break; 113 } 114 case AF_INET6: { 115 struct sockaddr_in6 *sin6 = 116 reinterpret_cast<struct sockaddr_in6 *>(result->ai_addr); 117 fprintf(stderr, "Connecting to [%s]:%d\n", 118 inet_ntop(result->ai_family, &sin6->sin6_addr, buf, sizeof(buf)), 119 ntohs(sin6->sin6_port)); 120 break; 121 } 122 } 123 124 if (connect(*out_sock, result->ai_addr, result->ai_addrlen) != 0) { 125 perror("connect"); 126 goto out; 127 } 128 ok = true; 129 130 out: 131 freeaddrinfo(result); 132 return ok; 133 } 134 135 bool Accept(int *out_sock, const std::string &port) { 136 struct sockaddr_in6 addr, cli_addr; 137 socklen_t cli_addr_len = sizeof(cli_addr); 138 memset(&addr, 0, sizeof(addr)); 139 140 addr.sin6_family = AF_INET6; 141 addr.sin6_addr = in6addr_any; 142 addr.sin6_port = htons(atoi(port.c_str())); 143 144 bool ok = false; 145 int server_sock = -1; 146 147 server_sock = 148 socket(addr.sin6_family, SOCK_STREAM, 0); 149 if (server_sock < 0) { 150 perror("socket"); 151 goto out; 152 } 153 154 if (bind(server_sock, (struct sockaddr*)&addr, sizeof(addr)) != 0) { 155 perror("connect"); 156 goto out; 157 } 158 listen(server_sock, 1); 159 *out_sock = accept(server_sock, (struct sockaddr*)&cli_addr, &cli_addr_len); 160 161 ok = true; 162 163 out: 164 closesocket(server_sock); 165 return ok; 166 } 167 168 void PrintConnectionInfo(const SSL *ssl) { 169 const SSL_CIPHER *cipher = SSL_get_current_cipher(ssl); 170 171 fprintf(stderr, " Version: %s\n", SSL_get_version(ssl)); 172 fprintf(stderr, " Resumed session: %s\n", 173 SSL_session_reused(ssl) ? "yes" : "no"); 174 fprintf(stderr, " Cipher: %s\n", SSL_CIPHER_get_name(cipher)); 175 if (SSL_CIPHER_is_ECDHE(cipher)) { 176 fprintf(stderr, " ECDHE curve: %s\n", 177 SSL_get_curve_name( 178 SSL_SESSION_get_key_exchange_info(SSL_get_session(ssl)))); 179 } 180 fprintf(stderr, " Secure renegotiation: %s\n", 181 SSL_get_secure_renegotiation_support(ssl) ? "yes" : "no"); 182 183 const uint8_t *next_proto; 184 unsigned next_proto_len; 185 SSL_get0_next_proto_negotiated(ssl, &next_proto, &next_proto_len); 186 fprintf(stderr, " Next protocol negotiated: %.*s\n", next_proto_len, 187 next_proto); 188 189 const uint8_t *alpn; 190 unsigned alpn_len; 191 SSL_get0_alpn_selected(ssl, &alpn, &alpn_len); 192 fprintf(stderr, " ALPN protocol: %.*s\n", alpn_len, alpn); 193 } 194 195 bool SocketSetNonBlocking(int sock, bool is_non_blocking) { 196 bool ok; 197 198 #if defined(OPENSSL_WINDOWS) 199 u_long arg = is_non_blocking; 200 ok = 0 == ioctlsocket(sock, FIONBIO, &arg); 201 #else 202 int flags = fcntl(sock, F_GETFL, 0); 203 if (flags < 0) { 204 return false; 205 } 206 if (is_non_blocking) { 207 flags |= O_NONBLOCK; 208 } else { 209 flags &= ~O_NONBLOCK; 210 } 211 ok = 0 == fcntl(sock, F_SETFL, flags); 212 #endif 213 if (!ok) { 214 fprintf(stderr, "Failed to set socket non-blocking.\n"); 215 } 216 return ok; 217 } 218 219 // PrintErrorCallback is a callback function from OpenSSL's 220 // |ERR_print_errors_cb| that writes errors to a given |FILE*|. 221 int PrintErrorCallback(const char *str, size_t len, void *ctx) { 222 fwrite(str, len, 1, reinterpret_cast<FILE*>(ctx)); 223 return 1; 224 } 225 226 bool TransferData(SSL *ssl, int sock) { 227 bool stdin_open = true; 228 229 fd_set read_fds; 230 FD_ZERO(&read_fds); 231 232 if (!SocketSetNonBlocking(sock, true)) { 233 return false; 234 } 235 236 for (;;) { 237 if (stdin_open) { 238 FD_SET(0, &read_fds); 239 } 240 FD_SET(sock, &read_fds); 241 242 int ret = select(sock + 1, &read_fds, NULL, NULL, NULL); 243 if (ret <= 0) { 244 perror("select"); 245 return false; 246 } 247 248 if (FD_ISSET(0, &read_fds)) { 249 uint8_t buffer[512]; 250 ssize_t n; 251 252 do { 253 n = read(0, buffer, sizeof(buffer)); 254 } while (n == -1 && errno == EINTR); 255 256 if (n == 0) { 257 FD_CLR(0, &read_fds); 258 stdin_open = false; 259 #if !defined(OPENSSL_WINDOWS) 260 shutdown(sock, SHUT_WR); 261 #else 262 shutdown(sock, SD_SEND); 263 #endif 264 continue; 265 } else if (n < 0) { 266 perror("read from stdin"); 267 return false; 268 } 269 270 if (!SocketSetNonBlocking(sock, false)) { 271 return false; 272 } 273 int ssl_ret = SSL_write(ssl, buffer, n); 274 if (!SocketSetNonBlocking(sock, true)) { 275 return false; 276 } 277 278 if (ssl_ret <= 0) { 279 int ssl_err = SSL_get_error(ssl, ssl_ret); 280 fprintf(stderr, "Error while writing: %d\n", ssl_err); 281 ERR_print_errors_cb(PrintErrorCallback, stderr); 282 return false; 283 } else if (ssl_ret != n) { 284 fprintf(stderr, "Short write from SSL_write.\n"); 285 return false; 286 } 287 } 288 289 if (FD_ISSET(sock, &read_fds)) { 290 uint8_t buffer[512]; 291 int ssl_ret = SSL_read(ssl, buffer, sizeof(buffer)); 292 293 if (ssl_ret < 0) { 294 int ssl_err = SSL_get_error(ssl, ssl_ret); 295 if (ssl_err == SSL_ERROR_WANT_READ) { 296 continue; 297 } 298 fprintf(stderr, "Error while reading: %d\n", ssl_err); 299 ERR_print_errors_cb(PrintErrorCallback, stderr); 300 return false; 301 } else if (ssl_ret == 0) { 302 return true; 303 } 304 305 ssize_t n; 306 do { 307 n = write(1, buffer, ssl_ret); 308 } while (n == -1 && errno == EINTR); 309 310 if (n != ssl_ret) { 311 fprintf(stderr, "Short write to stderr.\n"); 312 return false; 313 } 314 } 315 } 316 } 317