Home | History | Annotate | Download | only in tool
      1 /* Copyright (c) 2014, Google Inc.
      2  *
      3  * Permission to use, copy, modify, and/or distribute this software for any
      4  * purpose with or without fee is hereby granted, provided that the above
      5  * copyright notice and this permission notice appear in all copies.
      6  *
      7  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
      8  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
      9  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
     10  * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
     11  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
     12  * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
     13  * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */
     14 
     15 #include <openssl/base.h>
     16 
     17 #include <string>
     18 #include <vector>
     19 
     20 #include <errno.h>
     21 #include <stddef.h>
     22 #include <stdlib.h>
     23 #include <string.h>
     24 #include <sys/types.h>
     25 
     26 #if !defined(OPENSSL_WINDOWS)
     27 #include <arpa/inet.h>
     28 #include <fcntl.h>
     29 #include <netdb.h>
     30 #include <netinet/in.h>
     31 #include <sys/select.h>
     32 #include <sys/socket.h>
     33 #include <unistd.h>
     34 #else
     35 #include <io.h>
     36 #pragma warning(push, 3)
     37 #include <winsock2.h>
     38 #include <ws2tcpip.h>
     39 #pragma warning(pop)
     40 
     41 typedef int ssize_t;
     42 #pragma comment(lib, "Ws2_32.lib")
     43 #endif
     44 
     45 #include <openssl/err.h>
     46 #include <openssl/ssl.h>
     47 
     48 #include "internal.h"
     49 
     50 
     51 #if !defined(OPENSSL_WINDOWS)
     52 static int closesocket(int sock) {
     53   return close(sock);
     54 }
     55 #endif
     56 
     57 bool InitSocketLibrary() {
     58 #if defined(OPENSSL_WINDOWS)
     59   WSADATA wsaData;
     60   int err = WSAStartup(MAKEWORD(2, 2), &wsaData);
     61   if (err != 0) {
     62     fprintf(stderr, "WSAStartup failed with error %d\n", err);
     63     return false;
     64   }
     65 #endif
     66   return true;
     67 }
     68 
     69 // Connect sets |*out_sock| to be a socket connected to the destination given
     70 // in |hostname_and_port|, which should be of the form "www.example.com:123".
     71 // It returns true on success and false otherwise.
     72 bool Connect(int *out_sock, const std::string &hostname_and_port) {
     73   const size_t colon_offset = hostname_and_port.find_last_of(':');
     74   std::string hostname, port;
     75 
     76   if (colon_offset == std::string::npos) {
     77     hostname = hostname_and_port;
     78     port = "443";
     79   } else {
     80     hostname = hostname_and_port.substr(0, colon_offset);
     81     port = hostname_and_port.substr(colon_offset + 1);
     82   }
     83 
     84   struct addrinfo hint, *result;
     85   memset(&hint, 0, sizeof(hint));
     86   hint.ai_family = AF_UNSPEC;
     87   hint.ai_socktype = SOCK_STREAM;
     88 
     89   int ret = getaddrinfo(hostname.c_str(), port.c_str(), &hint, &result);
     90   if (ret != 0) {
     91     fprintf(stderr, "getaddrinfo returned: %s\n", gai_strerror(ret));
     92     return false;
     93   }
     94 
     95   bool ok = false;
     96   char buf[256];
     97 
     98   *out_sock =
     99       socket(result->ai_family, result->ai_socktype, result->ai_protocol);
    100   if (*out_sock < 0) {
    101     perror("socket");
    102     goto out;
    103   }
    104 
    105   switch (result->ai_family) {
    106     case AF_INET: {
    107       struct sockaddr_in *sin =
    108           reinterpret_cast<struct sockaddr_in *>(result->ai_addr);
    109       fprintf(stderr, "Connecting to %s:%d\n",
    110               inet_ntop(result->ai_family, &sin->sin_addr, buf, sizeof(buf)),
    111               ntohs(sin->sin_port));
    112       break;
    113     }
    114     case AF_INET6: {
    115       struct sockaddr_in6 *sin6 =
    116           reinterpret_cast<struct sockaddr_in6 *>(result->ai_addr);
    117       fprintf(stderr, "Connecting to [%s]:%d\n",
    118               inet_ntop(result->ai_family, &sin6->sin6_addr, buf, sizeof(buf)),
    119               ntohs(sin6->sin6_port));
    120       break;
    121     }
    122   }
    123 
    124   if (connect(*out_sock, result->ai_addr, result->ai_addrlen) != 0) {
    125     perror("connect");
    126     goto out;
    127   }
    128   ok = true;
    129 
    130 out:
    131   freeaddrinfo(result);
    132   return ok;
    133 }
    134 
    135 bool Accept(int *out_sock, const std::string &port) {
    136   struct sockaddr_in6 addr, cli_addr;
    137   socklen_t cli_addr_len = sizeof(cli_addr);
    138   memset(&addr, 0, sizeof(addr));
    139 
    140   addr.sin6_family = AF_INET6;
    141   addr.sin6_addr = in6addr_any;
    142   addr.sin6_port = htons(atoi(port.c_str()));
    143 
    144   bool ok = false;
    145   int server_sock = -1;
    146 
    147   server_sock =
    148       socket(addr.sin6_family, SOCK_STREAM, 0);
    149   if (server_sock < 0) {
    150     perror("socket");
    151     goto out;
    152   }
    153 
    154   if (bind(server_sock, (struct sockaddr*)&addr, sizeof(addr)) != 0) {
    155     perror("connect");
    156     goto out;
    157   }
    158   listen(server_sock, 1);
    159   *out_sock = accept(server_sock, (struct sockaddr*)&cli_addr, &cli_addr_len);
    160 
    161   ok = true;
    162 
    163 out:
    164   closesocket(server_sock);
    165   return ok;
    166 }
    167 
    168 void PrintConnectionInfo(const SSL *ssl) {
    169   const SSL_CIPHER *cipher = SSL_get_current_cipher(ssl);
    170 
    171   fprintf(stderr, "  Version: %s\n", SSL_get_version(ssl));
    172   fprintf(stderr, "  Resumed session: %s\n",
    173           SSL_session_reused(ssl) ? "yes" : "no");
    174   fprintf(stderr, "  Cipher: %s\n", SSL_CIPHER_get_name(cipher));
    175   if (SSL_CIPHER_is_ECDHE(cipher)) {
    176     fprintf(stderr, "  ECDHE curve: %s\n",
    177             SSL_get_curve_name(
    178                 SSL_SESSION_get_key_exchange_info(SSL_get_session(ssl))));
    179   }
    180   fprintf(stderr, "  Secure renegotiation: %s\n",
    181           SSL_get_secure_renegotiation_support(ssl) ? "yes" : "no");
    182 
    183   const uint8_t *next_proto;
    184   unsigned next_proto_len;
    185   SSL_get0_next_proto_negotiated(ssl, &next_proto, &next_proto_len);
    186   fprintf(stderr, "  Next protocol negotiated: %.*s\n", next_proto_len,
    187           next_proto);
    188 
    189   const uint8_t *alpn;
    190   unsigned alpn_len;
    191   SSL_get0_alpn_selected(ssl, &alpn, &alpn_len);
    192   fprintf(stderr, "  ALPN protocol: %.*s\n", alpn_len, alpn);
    193 }
    194 
    195 bool SocketSetNonBlocking(int sock, bool is_non_blocking) {
    196   bool ok;
    197 
    198 #if defined(OPENSSL_WINDOWS)
    199   u_long arg = is_non_blocking;
    200   ok = 0 == ioctlsocket(sock, FIONBIO, &arg);
    201 #else
    202   int flags = fcntl(sock, F_GETFL, 0);
    203   if (flags < 0) {
    204     return false;
    205   }
    206   if (is_non_blocking) {
    207     flags |= O_NONBLOCK;
    208   } else {
    209     flags &= ~O_NONBLOCK;
    210   }
    211   ok = 0 == fcntl(sock, F_SETFL, flags);
    212 #endif
    213   if (!ok) {
    214     fprintf(stderr, "Failed to set socket non-blocking.\n");
    215   }
    216   return ok;
    217 }
    218 
    219 // PrintErrorCallback is a callback function from OpenSSL's
    220 // |ERR_print_errors_cb| that writes errors to a given |FILE*|.
    221 int PrintErrorCallback(const char *str, size_t len, void *ctx) {
    222   fwrite(str, len, 1, reinterpret_cast<FILE*>(ctx));
    223   return 1;
    224 }
    225 
    226 bool TransferData(SSL *ssl, int sock) {
    227   bool stdin_open = true;
    228 
    229   fd_set read_fds;
    230   FD_ZERO(&read_fds);
    231 
    232   if (!SocketSetNonBlocking(sock, true)) {
    233     return false;
    234   }
    235 
    236   for (;;) {
    237     if (stdin_open) {
    238       FD_SET(0, &read_fds);
    239     }
    240     FD_SET(sock, &read_fds);
    241 
    242     int ret = select(sock + 1, &read_fds, NULL, NULL, NULL);
    243     if (ret <= 0) {
    244       perror("select");
    245       return false;
    246     }
    247 
    248     if (FD_ISSET(0, &read_fds)) {
    249       uint8_t buffer[512];
    250       ssize_t n;
    251 
    252       do {
    253         n = read(0, buffer, sizeof(buffer));
    254       } while (n == -1 && errno == EINTR);
    255 
    256       if (n == 0) {
    257         FD_CLR(0, &read_fds);
    258         stdin_open = false;
    259 #if !defined(OPENSSL_WINDOWS)
    260         shutdown(sock, SHUT_WR);
    261 #else
    262         shutdown(sock, SD_SEND);
    263 #endif
    264         continue;
    265       } else if (n < 0) {
    266         perror("read from stdin");
    267         return false;
    268       }
    269 
    270       if (!SocketSetNonBlocking(sock, false)) {
    271         return false;
    272       }
    273       int ssl_ret = SSL_write(ssl, buffer, n);
    274       if (!SocketSetNonBlocking(sock, true)) {
    275         return false;
    276       }
    277 
    278       if (ssl_ret <= 0) {
    279         int ssl_err = SSL_get_error(ssl, ssl_ret);
    280         fprintf(stderr, "Error while writing: %d\n", ssl_err);
    281         ERR_print_errors_cb(PrintErrorCallback, stderr);
    282         return false;
    283       } else if (ssl_ret != n) {
    284         fprintf(stderr, "Short write from SSL_write.\n");
    285         return false;
    286       }
    287     }
    288 
    289     if (FD_ISSET(sock, &read_fds)) {
    290       uint8_t buffer[512];
    291       int ssl_ret = SSL_read(ssl, buffer, sizeof(buffer));
    292 
    293       if (ssl_ret < 0) {
    294         int ssl_err = SSL_get_error(ssl, ssl_ret);
    295         if (ssl_err == SSL_ERROR_WANT_READ) {
    296           continue;
    297         }
    298         fprintf(stderr, "Error while reading: %d\n", ssl_err);
    299         ERR_print_errors_cb(PrintErrorCallback, stderr);
    300         return false;
    301       } else if (ssl_ret == 0) {
    302         return true;
    303       }
    304 
    305       ssize_t n;
    306       do {
    307         n = write(1, buffer, ssl_ret);
    308       } while (n == -1 && errno == EINTR);
    309 
    310       if (n != ssl_ret) {
    311         fprintf(stderr, "Short write to stderr.\n");
    312         return false;
    313       }
    314     }
    315   }
    316 }
    317