Home | History | Annotate | Download | only in tool
      1 /* Copyright (c) 2014, Google Inc.
      2  *
      3  * Permission to use, copy, modify, and/or distribute this software for any
      4  * purpose with or without fee is hereby granted, provided that the above
      5  * copyright notice and this permission notice appear in all copies.
      6  *
      7  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
      8  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
      9  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
     10  * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
     11  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
     12  * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
     13  * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */
     14 
     15 #include <openssl/base.h>
     16 
     17 #include <string>
     18 #include <vector>
     19 
     20 #include <errno.h>
     21 #include <stdlib.h>
     22 #include <sys/types.h>
     23 #include <sys/socket.h>
     24 
     25 #if !defined(OPENSSL_WINDOWS)
     26 #include <arpa/inet.h>
     27 #include <fcntl.h>
     28 #include <netdb.h>
     29 #include <sys/select.h>
     30 #include <unistd.h>
     31 #else
     32 #include <WinSock2.h>
     33 #include <WS2tcpip.h>
     34 typedef int socklen_t;
     35 #endif
     36 
     37 #include <openssl/err.h>
     38 #include <openssl/ssl.h>
     39 
     40 #include "internal.h"
     41 
     42 
     43 static const struct argument kArguments[] = {
     44     {
     45      "-connect", true,
     46      "The hostname and port of the server to connect to, e.g. foo.com:443",
     47     },
     48     {
     49      "", false, "",
     50     },
     51 };
     52 
     53 // Connect sets |*out_sock| to be a socket connected to the destination given
     54 // in |hostname_and_port|, which should be of the form "www.example.com:123".
     55 // It returns true on success and false otherwise.
     56 static bool Connect(int *out_sock, const std::string &hostname_and_port) {
     57   const size_t colon_offset = hostname_and_port.find_last_of(':');
     58   std::string hostname, port;
     59 
     60   if (colon_offset == std::string::npos) {
     61     hostname = hostname_and_port;
     62     port = "443";
     63   } else {
     64     hostname = hostname_and_port.substr(0, colon_offset);
     65     port = hostname_and_port.substr(colon_offset + 1);
     66   }
     67 
     68   struct addrinfo hint, *result;
     69   memset(&hint, 0, sizeof(hint));
     70   hint.ai_family = AF_UNSPEC;
     71   hint.ai_socktype = SOCK_STREAM;
     72 
     73   int ret = getaddrinfo(hostname.c_str(), port.c_str(), &hint, &result);
     74   if (ret != 0) {
     75     fprintf(stderr, "getaddrinfo returned: %s\n", gai_strerror(ret));
     76     return false;
     77   }
     78 
     79   bool ok = false;
     80   char buf[256];
     81 
     82   *out_sock =
     83       socket(result->ai_family, result->ai_socktype, result->ai_protocol);
     84   if (*out_sock < 0) {
     85     perror("socket");
     86     goto out;
     87   }
     88 
     89   switch (result->ai_family) {
     90     case AF_INET: {
     91       struct sockaddr_in *sin =
     92           reinterpret_cast<struct sockaddr_in *>(result->ai_addr);
     93       fprintf(stderr, "Connecting to %s:%d\n",
     94               inet_ntop(result->ai_family, &sin->sin_addr, buf, sizeof(buf)),
     95               ntohs(sin->sin_port));
     96       break;
     97     }
     98     case AF_INET6: {
     99       struct sockaddr_in6 *sin6 =
    100           reinterpret_cast<struct sockaddr_in6 *>(result->ai_addr);
    101       fprintf(stderr, "Connecting to [%s]:%d\n",
    102               inet_ntop(result->ai_family, &sin6->sin6_addr, buf, sizeof(buf)),
    103               ntohs(sin6->sin6_port));
    104       break;
    105     }
    106   }
    107 
    108   if (connect(*out_sock, result->ai_addr, result->ai_addrlen) != 0) {
    109     perror("connect");
    110     goto out;
    111   }
    112   ok = true;
    113 
    114 out:
    115   freeaddrinfo(result);
    116   return ok;
    117 }
    118 
    119 static void PrintConnectionInfo(const SSL *ssl) {
    120   const SSL_CIPHER *cipher = SSL_get_current_cipher(ssl);
    121 
    122   fprintf(stderr, "  Version: %s\n", SSL_get_version(ssl));
    123   fprintf(stderr, "  Cipher: %s\n", SSL_CIPHER_get_name(cipher));
    124   fprintf(stderr, "  Secure renegotiation: %s\n",
    125           SSL_get_secure_renegotiation_support(ssl) ? "yes" : "no");
    126 }
    127 
    128 static bool SocketSetNonBlocking(int sock, bool is_non_blocking) {
    129   bool ok;
    130 
    131 #if defined(OPENSSL_WINDOWS)
    132   u_long arg = is_non_blocking;
    133   ok = 0 == ioctlsocket(sock, FIOBIO, &arg);
    134 #else
    135   int flags = fcntl(sock, F_GETFL, 0);
    136   if (flags < 0) {
    137     return false;
    138   }
    139   if (is_non_blocking) {
    140     flags |= O_NONBLOCK;
    141   } else {
    142     flags &= ~O_NONBLOCK;
    143   }
    144   ok = 0 == fcntl(sock, F_SETFL, flags);
    145 #endif
    146   if (!ok) {
    147     fprintf(stderr, "Failed to set socket non-blocking.\n");
    148   }
    149   return ok;
    150 }
    151 
    152 // PrintErrorCallback is a callback function from OpenSSL's
    153 // |ERR_print_errors_cb| that writes errors to a given |FILE*|.
    154 static int PrintErrorCallback(const char *str, size_t len, void *ctx) {
    155   fwrite(str, len, 1, reinterpret_cast<FILE*>(ctx));
    156   return 1;
    157 }
    158 
    159 bool TransferData(SSL *ssl, int sock) {
    160   bool stdin_open = true;
    161 
    162   fd_set read_fds;
    163   FD_ZERO(&read_fds);
    164 
    165   if (!SocketSetNonBlocking(sock, true)) {
    166     return false;
    167   }
    168 
    169   for (;;) {
    170     if (stdin_open) {
    171       FD_SET(0, &read_fds);
    172     }
    173     FD_SET(sock, &read_fds);
    174 
    175     int ret = select(sock + 1, &read_fds, NULL, NULL, NULL);
    176     if (ret <= 0) {
    177       perror("select");
    178       return false;
    179     }
    180 
    181     if (FD_ISSET(0, &read_fds)) {
    182       uint8_t buffer[512];
    183       ssize_t n;
    184 
    185       do {
    186         n = read(0, buffer, sizeof(buffer));
    187       } while (n == -1 && errno == EINTR);
    188 
    189       if (n == 0) {
    190         FD_CLR(0, &read_fds);
    191         stdin_open = false;
    192         shutdown(sock, SHUT_WR);
    193         continue;
    194       } else if (n < 0) {
    195         perror("read from stdin");
    196         return false;
    197       }
    198 
    199       if (!SocketSetNonBlocking(sock, false)) {
    200         return false;
    201       }
    202       int ssl_ret = SSL_write(ssl, buffer, n);
    203       if (!SocketSetNonBlocking(sock, true)) {
    204         return false;
    205       }
    206 
    207       if (ssl_ret <= 0) {
    208         int ssl_err = SSL_get_error(ssl, ssl_ret);
    209         fprintf(stderr, "Error while writing: %d\n", ssl_err);
    210         ERR_print_errors_cb(PrintErrorCallback, stderr);
    211         return false;
    212       } else if (ssl_ret != n) {
    213         fprintf(stderr, "Short write from SSL_write.\n");
    214         return false;
    215       }
    216     }
    217 
    218     if (FD_ISSET(sock, &read_fds)) {
    219       uint8_t buffer[512];
    220       int ssl_ret = SSL_read(ssl, buffer, sizeof(buffer));
    221 
    222       if (ssl_ret < 0) {
    223         int ssl_err = SSL_get_error(ssl, ssl_ret);
    224         if (ssl_err == SSL_ERROR_WANT_READ) {
    225           continue;
    226         }
    227         fprintf(stderr, "Error while reading: %d\n", ssl_err);
    228         ERR_print_errors_cb(PrintErrorCallback, stderr);
    229         return false;
    230       } else if (ssl_ret == 0) {
    231         return true;
    232       }
    233 
    234       ssize_t n;
    235       do {
    236         n = write(1, buffer, ssl_ret);
    237       } while (n == -1 && errno == EINTR);
    238 
    239       if (n != ssl_ret) {
    240         fprintf(stderr, "Short write to stderr.\n");
    241         return false;
    242       }
    243     }
    244   }
    245 }
    246 
    247 bool Client(const std::vector<std::string> &args) {
    248   std::map<std::string, std::string> args_map;
    249 
    250   if (!ParseKeyValueArguments(&args_map, args, kArguments)) {
    251     PrintUsage(kArguments);
    252     return false;
    253   }
    254 
    255   SSL_CTX *ctx = SSL_CTX_new(SSLv23_client_method());
    256 
    257   const char *keylog_file = getenv("SSLKEYLOGFILE");
    258   if (keylog_file) {
    259     BIO *keylog_bio = BIO_new_file(keylog_file, "a");
    260     if (!keylog_bio) {
    261       ERR_print_errors_cb(PrintErrorCallback, stderr);
    262       return false;
    263     }
    264     SSL_CTX_set_keylog_bio(ctx, keylog_bio);
    265   }
    266 
    267   int sock = -1;
    268   if (!Connect(&sock, args_map["-connect"])) {
    269     return false;
    270   }
    271 
    272   BIO *bio = BIO_new_socket(sock, BIO_CLOSE);
    273   SSL *ssl = SSL_new(ctx);
    274   SSL_set_bio(ssl, bio, bio);
    275 
    276   int ret = SSL_connect(ssl);
    277   if (ret != 1) {
    278     int ssl_err = SSL_get_error(ssl, ret);
    279     fprintf(stderr, "Error while connecting: %d\n", ssl_err);
    280     ERR_print_errors_cb(PrintErrorCallback, stderr);
    281     return false;
    282   }
    283 
    284   fprintf(stderr, "Connected.\n");
    285   PrintConnectionInfo(ssl);
    286 
    287   bool ok = TransferData(ssl, sock);
    288 
    289   SSL_free(ssl);
    290   SSL_CTX_free(ctx);
    291   return ok;
    292 }
    293