Home | History | Annotate | Download | only in test
      1 /* Copyright (c) 2014, Google Inc.
      2  *
      3  * Permission to use, copy, modify, and/or distribute this software for any
      4  * purpose with or without fee is hereby granted, provided that the above
      5  * copyright notice and this permission notice appear in all copies.
      6  *
      7  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
      8  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
      9  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
     10  * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
     11  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
     12  * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
     13  * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */
     14 
     15 #include <openssl/base.h>
     16 
     17 #if !defined(OPENSSL_WINDOWS)
     18 #include <arpa/inet.h>
     19 #include <netinet/in.h>
     20 #include <netinet/tcp.h>
     21 #include <signal.h>
     22 #include <sys/socket.h>
     23 #include <sys/types.h>
     24 #include <unistd.h>
     25 #else
     26 #include <io.h>
     27 #pragma warning(push, 3)
     28 #include <winsock2.h>
     29 #include <ws2tcpip.h>
     30 #pragma warning(pop)
     31 
     32 #pragma comment(lib, "Ws2_32.lib")
     33 #endif
     34 
     35 #include <string.h>
     36 #include <sys/types.h>
     37 
     38 #include <openssl/bio.h>
     39 #include <openssl/buf.h>
     40 #include <openssl/bytestring.h>
     41 #include <openssl/err.h>
     42 #include <openssl/ssl.h>
     43 
     44 #include <memory>
     45 #include <vector>
     46 
     47 #include "../../crypto/test/scoped_types.h"
     48 #include "async_bio.h"
     49 #include "packeted_bio.h"
     50 #include "scoped_types.h"
     51 #include "test_config.h"
     52 
     53 
     54 #if !defined(OPENSSL_WINDOWS)
     55 static int closesocket(int sock) {
     56   return close(sock);
     57 }
     58 
     59 static void PrintSocketError(const char *func) {
     60   perror(func);
     61 }
     62 #else
     63 static void PrintSocketError(const char *func) {
     64   fprintf(stderr, "%s: %d\n", func, WSAGetLastError());
     65 }
     66 #endif
     67 
     68 static int Usage(const char *program) {
     69   fprintf(stderr, "Usage: %s [flags...]\n", program);
     70   return 1;
     71 }
     72 
     73 struct TestState {
     74   TestState() {
     75     // MSVC cannot initialize these inline.
     76     memset(&clock, 0, sizeof(clock));
     77     memset(&clock_delta, 0, sizeof(clock_delta));
     78   }
     79 
     80   // async_bio is async BIO which pauses reads and writes.
     81   BIO *async_bio = nullptr;
     82   // clock is the current time for the SSL connection.
     83   timeval clock;
     84   // clock_delta is how far the clock advanced in the most recent failed
     85   // |BIO_read|.
     86   timeval clock_delta;
     87   ScopedEVP_PKEY channel_id;
     88   bool cert_ready = false;
     89   ScopedSSL_SESSION session;
     90   ScopedSSL_SESSION pending_session;
     91   bool early_callback_called = false;
     92   bool handshake_done = false;
     93 };
     94 
     95 static void TestStateExFree(void *parent, void *ptr, CRYPTO_EX_DATA *ad,
     96 			    int index, long argl, void *argp) {
     97   delete ((TestState *)ptr);
     98 }
     99 
    100 static int g_config_index = 0;
    101 static int g_state_index = 0;
    102 
    103 static bool SetConfigPtr(SSL *ssl, const TestConfig *config) {
    104   return SSL_set_ex_data(ssl, g_config_index, (void *)config) == 1;
    105 }
    106 
    107 static const TestConfig *GetConfigPtr(const SSL *ssl) {
    108   return (const TestConfig *)SSL_get_ex_data(ssl, g_config_index);
    109 }
    110 
    111 static bool SetTestState(SSL *ssl, std::unique_ptr<TestState> async) {
    112   if (SSL_set_ex_data(ssl, g_state_index, (void *)async.get()) == 1) {
    113     async.release();
    114     return true;
    115   }
    116   return false;
    117 }
    118 
    119 static TestState *GetTestState(const SSL *ssl) {
    120   return (TestState *)SSL_get_ex_data(ssl, g_state_index);
    121 }
    122 
    123 static ScopedEVP_PKEY LoadPrivateKey(const std::string &file) {
    124   ScopedBIO bio(BIO_new(BIO_s_file()));
    125   if (!bio || !BIO_read_filename(bio.get(), file.c_str())) {
    126     return nullptr;
    127   }
    128   ScopedEVP_PKEY pkey(PEM_read_bio_PrivateKey(bio.get(), NULL, NULL, NULL));
    129   return pkey;
    130 }
    131 
    132 static bool InstallCertificate(SSL *ssl) {
    133   const TestConfig *config = GetConfigPtr(ssl);
    134   if (!config->key_file.empty() &&
    135       !SSL_use_PrivateKey_file(ssl, config->key_file.c_str(),
    136                                SSL_FILETYPE_PEM)) {
    137     return false;
    138   }
    139   if (!config->cert_file.empty() &&
    140       !SSL_use_certificate_file(ssl, config->cert_file.c_str(),
    141                                 SSL_FILETYPE_PEM)) {
    142     return false;
    143   }
    144   return true;
    145 }
    146 
    147 static int SelectCertificateCallback(const struct ssl_early_callback_ctx *ctx) {
    148   const TestConfig *config = GetConfigPtr(ctx->ssl);
    149   GetTestState(ctx->ssl)->early_callback_called = true;
    150 
    151   if (!config->expected_server_name.empty()) {
    152     const uint8_t *extension_data;
    153     size_t extension_len;
    154     CBS extension, server_name_list, host_name;
    155     uint8_t name_type;
    156 
    157     if (!SSL_early_callback_ctx_extension_get(ctx, TLSEXT_TYPE_server_name,
    158                                               &extension_data,
    159                                               &extension_len)) {
    160       fprintf(stderr, "Could not find server_name extension.\n");
    161       return -1;
    162     }
    163 
    164     CBS_init(&extension, extension_data, extension_len);
    165     if (!CBS_get_u16_length_prefixed(&extension, &server_name_list) ||
    166         CBS_len(&extension) != 0 ||
    167         !CBS_get_u8(&server_name_list, &name_type) ||
    168         name_type != TLSEXT_NAMETYPE_host_name ||
    169         !CBS_get_u16_length_prefixed(&server_name_list, &host_name) ||
    170         CBS_len(&server_name_list) != 0) {
    171       fprintf(stderr, "Could not decode server_name extension.\n");
    172       return -1;
    173     }
    174 
    175     if (!CBS_mem_equal(&host_name,
    176                        (const uint8_t*)config->expected_server_name.data(),
    177                        config->expected_server_name.size())) {
    178       fprintf(stderr, "Server name mismatch.\n");
    179     }
    180   }
    181 
    182   if (config->fail_early_callback) {
    183     return -1;
    184   }
    185 
    186   // Install the certificate in the early callback.
    187   if (config->use_early_callback) {
    188     if (config->async) {
    189       // Install the certificate asynchronously.
    190       return 0;
    191     }
    192     if (!InstallCertificate(ctx->ssl)) {
    193       return -1;
    194     }
    195   }
    196   return 1;
    197 }
    198 
    199 static int SkipVerify(int preverify_ok, X509_STORE_CTX *store_ctx) {
    200   return 1;
    201 }
    202 
    203 static int NextProtosAdvertisedCallback(SSL *ssl, const uint8_t **out,
    204                                         unsigned int *out_len, void *arg) {
    205   const TestConfig *config = GetConfigPtr(ssl);
    206   if (config->advertise_npn.empty()) {
    207     return SSL_TLSEXT_ERR_NOACK;
    208   }
    209 
    210   *out = (const uint8_t*)config->advertise_npn.data();
    211   *out_len = config->advertise_npn.size();
    212   return SSL_TLSEXT_ERR_OK;
    213 }
    214 
    215 static int NextProtoSelectCallback(SSL* ssl, uint8_t** out, uint8_t* outlen,
    216                                    const uint8_t* in, unsigned inlen, void* arg) {
    217   const TestConfig *config = GetConfigPtr(ssl);
    218   if (config->select_next_proto.empty()) {
    219     return SSL_TLSEXT_ERR_NOACK;
    220   }
    221 
    222   *out = (uint8_t*)config->select_next_proto.data();
    223   *outlen = config->select_next_proto.size();
    224   return SSL_TLSEXT_ERR_OK;
    225 }
    226 
    227 static int AlpnSelectCallback(SSL* ssl, const uint8_t** out, uint8_t* outlen,
    228                               const uint8_t* in, unsigned inlen, void* arg) {
    229   const TestConfig *config = GetConfigPtr(ssl);
    230   if (config->select_alpn.empty()) {
    231     return SSL_TLSEXT_ERR_NOACK;
    232   }
    233 
    234   if (!config->expected_advertised_alpn.empty() &&
    235       (config->expected_advertised_alpn.size() != inlen ||
    236        memcmp(config->expected_advertised_alpn.data(),
    237               in, inlen) != 0)) {
    238     fprintf(stderr, "bad ALPN select callback inputs\n");
    239     exit(1);
    240   }
    241 
    242   *out = (const uint8_t*)config->select_alpn.data();
    243   *outlen = config->select_alpn.size();
    244   return SSL_TLSEXT_ERR_OK;
    245 }
    246 
    247 static unsigned PskClientCallback(SSL *ssl, const char *hint,
    248                                   char *out_identity,
    249                                   unsigned max_identity_len,
    250                                   uint8_t *out_psk, unsigned max_psk_len) {
    251   const TestConfig *config = GetConfigPtr(ssl);
    252 
    253   if (strcmp(hint ? hint : "", config->psk_identity.c_str()) != 0) {
    254     fprintf(stderr, "Server PSK hint did not match.\n");
    255     return 0;
    256   }
    257 
    258   // Account for the trailing '\0' for the identity.
    259   if (config->psk_identity.size() >= max_identity_len ||
    260       config->psk.size() > max_psk_len) {
    261     fprintf(stderr, "PSK buffers too small\n");
    262     return 0;
    263   }
    264 
    265   BUF_strlcpy(out_identity, config->psk_identity.c_str(),
    266               max_identity_len);
    267   memcpy(out_psk, config->psk.data(), config->psk.size());
    268   return config->psk.size();
    269 }
    270 
    271 static unsigned PskServerCallback(SSL *ssl, const char *identity,
    272                                   uint8_t *out_psk, unsigned max_psk_len) {
    273   const TestConfig *config = GetConfigPtr(ssl);
    274 
    275   if (strcmp(identity, config->psk_identity.c_str()) != 0) {
    276     fprintf(stderr, "Client PSK identity did not match.\n");
    277     return 0;
    278   }
    279 
    280   if (config->psk.size() > max_psk_len) {
    281     fprintf(stderr, "PSK buffers too small\n");
    282     return 0;
    283   }
    284 
    285   memcpy(out_psk, config->psk.data(), config->psk.size());
    286   return config->psk.size();
    287 }
    288 
    289 static void CurrentTimeCallback(const SSL *ssl, timeval *out_clock) {
    290   *out_clock = GetTestState(ssl)->clock;
    291 }
    292 
    293 static void ChannelIdCallback(SSL *ssl, EVP_PKEY **out_pkey) {
    294   *out_pkey = GetTestState(ssl)->channel_id.release();
    295 }
    296 
    297 static int CertCallback(SSL *ssl, void *arg) {
    298   if (!GetTestState(ssl)->cert_ready) {
    299     return -1;
    300   }
    301   if (!InstallCertificate(ssl)) {
    302     return 0;
    303   }
    304   return 1;
    305 }
    306 
    307 static SSL_SESSION *GetSessionCallback(SSL *ssl, uint8_t *data, int len,
    308                                        int *copy) {
    309   TestState *async_state = GetTestState(ssl);
    310   if (async_state->session) {
    311     *copy = 0;
    312     return async_state->session.release();
    313   } else if (async_state->pending_session) {
    314     return SSL_magic_pending_session_ptr();
    315   } else {
    316     return NULL;
    317   }
    318 }
    319 
    320 static int DDoSCallback(const struct ssl_early_callback_ctx *early_context) {
    321   const TestConfig *config = GetConfigPtr(early_context->ssl);
    322   static int callback_num = 0;
    323 
    324   callback_num++;
    325   if (config->fail_ddos_callback ||
    326       (config->fail_second_ddos_callback && callback_num == 2)) {
    327     return 0;
    328   }
    329   return 1;
    330 }
    331 
    332 static void InfoCallback(const SSL *ssl, int type, int val) {
    333   if (type == SSL_CB_HANDSHAKE_DONE) {
    334     if (GetConfigPtr(ssl)->handshake_never_done) {
    335       fprintf(stderr, "handshake completed\n");
    336       // Abort before any expected error code is printed, to ensure the overall
    337       // test fails.
    338       abort();
    339     }
    340     GetTestState(ssl)->handshake_done = true;
    341   }
    342 }
    343 
    344 // Connect returns a new socket connected to localhost on |port| or -1 on
    345 // error.
    346 static int Connect(uint16_t port) {
    347   int sock = socket(AF_INET, SOCK_STREAM, 0);
    348   if (sock == -1) {
    349     PrintSocketError("socket");
    350     return -1;
    351   }
    352   int nodelay = 1;
    353   if (setsockopt(sock, IPPROTO_TCP, TCP_NODELAY,
    354           reinterpret_cast<const char*>(&nodelay), sizeof(nodelay)) != 0) {
    355     PrintSocketError("setsockopt");
    356     closesocket(sock);
    357     return -1;
    358   }
    359   sockaddr_in sin;
    360   memset(&sin, 0, sizeof(sin));
    361   sin.sin_family = AF_INET;
    362   sin.sin_port = htons(port);
    363   if (!inet_pton(AF_INET, "127.0.0.1", &sin.sin_addr)) {
    364     PrintSocketError("inet_pton");
    365     closesocket(sock);
    366     return -1;
    367   }
    368   if (connect(sock, reinterpret_cast<const sockaddr*>(&sin),
    369               sizeof(sin)) != 0) {
    370     PrintSocketError("connect");
    371     closesocket(sock);
    372     return -1;
    373   }
    374   return sock;
    375 }
    376 
    377 class SocketCloser {
    378  public:
    379   explicit SocketCloser(int sock) : sock_(sock) {}
    380   ~SocketCloser() {
    381     // Half-close and drain the socket before releasing it. This seems to be
    382     // necessary for graceful shutdown on Windows. It will also avoid write
    383     // failures in the test runner.
    384 #if defined(OPENSSL_WINDOWS)
    385     shutdown(sock_, SD_SEND);
    386 #else
    387     shutdown(sock_, SHUT_WR);
    388 #endif
    389     while (true) {
    390       char buf[1024];
    391       if (recv(sock_, buf, sizeof(buf), 0) <= 0) {
    392         break;
    393       }
    394     }
    395     closesocket(sock_);
    396   }
    397 
    398  private:
    399   const int sock_;
    400 };
    401 
    402 static ScopedSSL_CTX SetupCtx(const TestConfig *config) {
    403   ScopedSSL_CTX ssl_ctx(SSL_CTX_new(
    404       config->is_dtls ? DTLS_method() : TLS_method()));
    405   if (!ssl_ctx) {
    406     return nullptr;
    407   }
    408 
    409   if (!SSL_CTX_set_cipher_list(ssl_ctx.get(), "ALL")) {
    410     return nullptr;
    411   }
    412 
    413   ScopedDH dh(DH_get_2048_256(NULL));
    414   if (!dh || !SSL_CTX_set_tmp_dh(ssl_ctx.get(), dh.get())) {
    415     return nullptr;
    416   }
    417 
    418   if (config->async && config->is_server) {
    419     // Disable the internal session cache. To test asynchronous session lookup,
    420     // we use an external session cache.
    421     SSL_CTX_set_session_cache_mode(
    422         ssl_ctx.get(), SSL_SESS_CACHE_BOTH | SSL_SESS_CACHE_NO_INTERNAL);
    423     SSL_CTX_sess_set_get_cb(ssl_ctx.get(), GetSessionCallback);
    424   } else {
    425     SSL_CTX_set_session_cache_mode(ssl_ctx.get(), SSL_SESS_CACHE_BOTH);
    426   }
    427 
    428   ssl_ctx->select_certificate_cb = SelectCertificateCallback;
    429 
    430   SSL_CTX_set_next_protos_advertised_cb(
    431       ssl_ctx.get(), NextProtosAdvertisedCallback, NULL);
    432   if (!config->select_next_proto.empty()) {
    433     SSL_CTX_set_next_proto_select_cb(ssl_ctx.get(), NextProtoSelectCallback,
    434                                      NULL);
    435   }
    436 
    437   if (!config->select_alpn.empty()) {
    438     SSL_CTX_set_alpn_select_cb(ssl_ctx.get(), AlpnSelectCallback, NULL);
    439   }
    440 
    441   ssl_ctx->tlsext_channel_id_enabled_new = 1;
    442   SSL_CTX_set_channel_id_cb(ssl_ctx.get(), ChannelIdCallback);
    443 
    444   ssl_ctx->current_time_cb = CurrentTimeCallback;
    445 
    446   SSL_CTX_set_info_callback(ssl_ctx.get(), InfoCallback);
    447 
    448   return ssl_ctx;
    449 }
    450 
    451 // RetryAsync is called after a failed operation on |ssl| with return code
    452 // |ret|. If the operation should be retried, it simulates one asynchronous
    453 // event and returns true. Otherwise it returns false.
    454 static bool RetryAsync(SSL *ssl, int ret) {
    455   // No error; don't retry.
    456   if (ret >= 0) {
    457     return false;
    458   }
    459 
    460   TestState *test_state = GetTestState(ssl);
    461   if (test_state->clock_delta.tv_usec != 0 ||
    462       test_state->clock_delta.tv_sec != 0) {
    463     // Process the timeout and retry.
    464     test_state->clock.tv_usec += test_state->clock_delta.tv_usec;
    465     test_state->clock.tv_sec += test_state->clock.tv_usec / 1000000;
    466     test_state->clock.tv_usec %= 1000000;
    467     test_state->clock.tv_sec += test_state->clock_delta.tv_sec;
    468     memset(&test_state->clock_delta, 0, sizeof(test_state->clock_delta));
    469 
    470     if (DTLSv1_handle_timeout(ssl) < 0) {
    471       fprintf(stderr, "Error retransmitting.\n");
    472       return false;
    473     }
    474     return true;
    475   }
    476 
    477   // See if we needed to read or write more. If so, allow one byte through on
    478   // the appropriate end to maximally stress the state machine.
    479   switch (SSL_get_error(ssl, ret)) {
    480     case SSL_ERROR_WANT_READ:
    481       AsyncBioAllowRead(test_state->async_bio, 1);
    482       return true;
    483     case SSL_ERROR_WANT_WRITE:
    484       AsyncBioAllowWrite(test_state->async_bio, 1);
    485       return true;
    486     case SSL_ERROR_WANT_CHANNEL_ID_LOOKUP: {
    487       ScopedEVP_PKEY pkey = LoadPrivateKey(GetConfigPtr(ssl)->send_channel_id);
    488       if (!pkey) {
    489         return false;
    490       }
    491       test_state->channel_id = std::move(pkey);
    492       return true;
    493     }
    494     case SSL_ERROR_WANT_X509_LOOKUP:
    495       test_state->cert_ready = true;
    496       return true;
    497     case SSL_ERROR_PENDING_SESSION:
    498       test_state->session = std::move(test_state->pending_session);
    499       return true;
    500     case SSL_ERROR_PENDING_CERTIFICATE:
    501       // The handshake will resume without a second call to the early callback.
    502       return InstallCertificate(ssl);
    503     default:
    504       return false;
    505   }
    506 }
    507 
    508 // DoRead reads from |ssl|, resolving any asynchronous operations. It returns
    509 // the result value of the final |SSL_read| call.
    510 static int DoRead(SSL *ssl, uint8_t *out, size_t max_out) {
    511   const TestConfig *config = GetConfigPtr(ssl);
    512   int ret;
    513   do {
    514     ret = SSL_read(ssl, out, max_out);
    515   } while (config->async && RetryAsync(ssl, ret));
    516   return ret;
    517 }
    518 
    519 // WriteAll writes |in_len| bytes from |in| to |ssl|, resolving any asynchronous
    520 // operations. It returns the result of the final |SSL_write| call.
    521 static int WriteAll(SSL *ssl, const uint8_t *in, size_t in_len) {
    522   const TestConfig *config = GetConfigPtr(ssl);
    523   int ret;
    524   do {
    525     ret = SSL_write(ssl, in, in_len);
    526     if (ret > 0) {
    527       in += ret;
    528       in_len -= ret;
    529     }
    530   } while ((config->async && RetryAsync(ssl, ret)) || (ret > 0 && in_len > 0));
    531   return ret;
    532 }
    533 
    534 // DoExchange runs a test SSL exchange against the peer. On success, it returns
    535 // true and sets |*out_session| to the negotiated SSL session. If the test is a
    536 // resumption attempt, |is_resume| is true and |session| is the session from the
    537 // previous exchange.
    538 static bool DoExchange(ScopedSSL_SESSION *out_session, SSL_CTX *ssl_ctx,
    539                        const TestConfig *config, bool is_resume,
    540                        SSL_SESSION *session) {
    541   ScopedSSL ssl(SSL_new(ssl_ctx));
    542   if (!ssl) {
    543     return false;
    544   }
    545 
    546   if (!SetConfigPtr(ssl.get(), config) ||
    547       !SetTestState(ssl.get(), std::unique_ptr<TestState>(new TestState))) {
    548     return false;
    549   }
    550 
    551   if (config->fallback_scsv &&
    552       !SSL_set_mode(ssl.get(), SSL_MODE_SEND_FALLBACK_SCSV)) {
    553     return false;
    554   }
    555   if (!config->use_early_callback) {
    556     if (config->async) {
    557       // TODO(davidben): Also test |s->ctx->client_cert_cb| on the client.
    558       SSL_set_cert_cb(ssl.get(), CertCallback, NULL);
    559     } else if (!InstallCertificate(ssl.get())) {
    560       return false;
    561     }
    562   }
    563   if (config->require_any_client_certificate) {
    564     SSL_set_verify(ssl.get(), SSL_VERIFY_PEER|SSL_VERIFY_FAIL_IF_NO_PEER_CERT,
    565                    SkipVerify);
    566   }
    567   if (config->false_start) {
    568     SSL_set_mode(ssl.get(), SSL_MODE_ENABLE_FALSE_START);
    569   }
    570   if (config->cbc_record_splitting) {
    571     SSL_set_mode(ssl.get(), SSL_MODE_CBC_RECORD_SPLITTING);
    572   }
    573   if (config->partial_write) {
    574     SSL_set_mode(ssl.get(), SSL_MODE_ENABLE_PARTIAL_WRITE);
    575   }
    576   if (config->no_tls12) {
    577     SSL_set_options(ssl.get(), SSL_OP_NO_TLSv1_2);
    578   }
    579   if (config->no_tls11) {
    580     SSL_set_options(ssl.get(), SSL_OP_NO_TLSv1_1);
    581   }
    582   if (config->no_tls1) {
    583     SSL_set_options(ssl.get(), SSL_OP_NO_TLSv1);
    584   }
    585   if (config->no_ssl3) {
    586     SSL_set_options(ssl.get(), SSL_OP_NO_SSLv3);
    587   }
    588   if (config->tls_d5_bug) {
    589     SSL_set_options(ssl.get(), SSL_OP_TLS_D5_BUG);
    590   }
    591   if (config->allow_unsafe_legacy_renegotiation) {
    592     SSL_set_options(ssl.get(), SSL_OP_ALLOW_UNSAFE_LEGACY_RENEGOTIATION);
    593   }
    594   if (config->no_legacy_server_connect) {
    595     SSL_clear_options(ssl.get(), SSL_OP_LEGACY_SERVER_CONNECT);
    596   }
    597   if (!config->expected_channel_id.empty()) {
    598     SSL_enable_tls_channel_id(ssl.get());
    599   }
    600   if (!config->send_channel_id.empty()) {
    601     SSL_enable_tls_channel_id(ssl.get());
    602     if (!config->async) {
    603       // The async case will be supplied by |ChannelIdCallback|.
    604       ScopedEVP_PKEY pkey = LoadPrivateKey(config->send_channel_id);
    605       if (!pkey || !SSL_set1_tls_channel_id(ssl.get(), pkey.get())) {
    606         return false;
    607       }
    608     }
    609   }
    610   if (!config->host_name.empty() &&
    611       !SSL_set_tlsext_host_name(ssl.get(), config->host_name.c_str())) {
    612     return false;
    613   }
    614   if (!config->advertise_alpn.empty() &&
    615       SSL_set_alpn_protos(ssl.get(),
    616                           (const uint8_t *)config->advertise_alpn.data(),
    617                           config->advertise_alpn.size()) != 0) {
    618     return false;
    619   }
    620   if (!config->psk.empty()) {
    621     SSL_set_psk_client_callback(ssl.get(), PskClientCallback);
    622     SSL_set_psk_server_callback(ssl.get(), PskServerCallback);
    623   }
    624   if (!config->psk_identity.empty() &&
    625       !SSL_use_psk_identity_hint(ssl.get(), config->psk_identity.c_str())) {
    626     return false;
    627   }
    628   if (!config->srtp_profiles.empty() &&
    629       !SSL_set_srtp_profiles(ssl.get(), config->srtp_profiles.c_str())) {
    630     return false;
    631   }
    632   if (config->enable_ocsp_stapling &&
    633       !SSL_enable_ocsp_stapling(ssl.get())) {
    634     return false;
    635   }
    636   if (config->enable_signed_cert_timestamps &&
    637       !SSL_enable_signed_cert_timestamps(ssl.get())) {
    638     return false;
    639   }
    640   SSL_enable_fastradio_padding(ssl.get(), config->fastradio_padding);
    641   if (config->min_version != 0) {
    642     SSL_set_min_version(ssl.get(), (uint16_t)config->min_version);
    643   }
    644   if (config->max_version != 0) {
    645     SSL_set_max_version(ssl.get(), (uint16_t)config->max_version);
    646   }
    647   if (config->mtu != 0) {
    648     SSL_set_options(ssl.get(), SSL_OP_NO_QUERY_MTU);
    649     SSL_set_mtu(ssl.get(), config->mtu);
    650   }
    651   if (config->install_ddos_callback) {
    652     SSL_CTX_set_dos_protection_cb(ssl_ctx, DDoSCallback);
    653   }
    654   if (!config->cipher.empty() &&
    655       !SSL_set_cipher_list(ssl.get(), config->cipher.c_str())) {
    656     return false;
    657   }
    658   if (!config->reject_peer_renegotiations) {
    659     /* Renegotiations are disabled by default. */
    660     SSL_set_reject_peer_renegotiations(ssl.get(), 0);
    661   }
    662 
    663   int sock = Connect(config->port);
    664   if (sock == -1) {
    665     return false;
    666   }
    667   SocketCloser closer(sock);
    668 
    669   ScopedBIO bio(BIO_new_socket(sock, BIO_NOCLOSE));
    670   if (!bio) {
    671     return false;
    672   }
    673   if (config->is_dtls) {
    674     ScopedBIO packeted =
    675         PacketedBioCreate(&GetTestState(ssl.get())->clock_delta);
    676     BIO_push(packeted.get(), bio.release());
    677     bio = std::move(packeted);
    678   }
    679   if (config->async) {
    680     ScopedBIO async_scoped =
    681         config->is_dtls ? AsyncBioCreateDatagram() : AsyncBioCreate();
    682     BIO_push(async_scoped.get(), bio.release());
    683     GetTestState(ssl.get())->async_bio = async_scoped.get();
    684     bio = std::move(async_scoped);
    685   }
    686   SSL_set_bio(ssl.get(), bio.get(), bio.get());
    687   bio.release();  // SSL_set_bio takes ownership.
    688 
    689   if (session != NULL) {
    690     if (!config->is_server) {
    691       if (SSL_set_session(ssl.get(), session) != 1) {
    692         return false;
    693       }
    694     } else if (config->async) {
    695       // The internal session cache is disabled, so install the session
    696       // manually.
    697       GetTestState(ssl.get())->pending_session.reset(
    698           SSL_SESSION_up_ref(session));
    699     }
    700   }
    701 
    702   if (SSL_get_current_cipher(ssl.get()) != nullptr) {
    703     fprintf(stderr, "non-null cipher before handshake\n");
    704     return false;
    705   }
    706 
    707   int ret;
    708   if (config->implicit_handshake) {
    709     if (config->is_server) {
    710       SSL_set_accept_state(ssl.get());
    711     } else {
    712       SSL_set_connect_state(ssl.get());
    713     }
    714   } else {
    715     do {
    716       if (config->is_server) {
    717         ret = SSL_accept(ssl.get());
    718       } else {
    719         ret = SSL_connect(ssl.get());
    720       }
    721     } while (config->async && RetryAsync(ssl.get(), ret));
    722     if (ret != 1) {
    723       return false;
    724     }
    725 
    726     if (SSL_get_current_cipher(ssl.get()) == nullptr) {
    727       fprintf(stderr, "null cipher after handshake\n");
    728       return false;
    729     }
    730 
    731     if (is_resume &&
    732         (!!SSL_session_reused(ssl.get()) == config->expect_session_miss)) {
    733       fprintf(stderr, "session was%s reused\n",
    734               SSL_session_reused(ssl.get()) ? "" : " not");
    735       return false;
    736     }
    737 
    738     bool expect_handshake_done = is_resume || !config->false_start;
    739     if (expect_handshake_done != GetTestState(ssl.get())->handshake_done) {
    740       fprintf(stderr, "handshake was%s completed\n",
    741               GetTestState(ssl.get())->handshake_done ? "" : " not");
    742       return false;
    743     }
    744 
    745     if (config->is_server && !GetTestState(ssl.get())->early_callback_called) {
    746       fprintf(stderr, "early callback not called\n");
    747       return false;
    748     }
    749 
    750     if (!config->expected_server_name.empty()) {
    751       const char *server_name =
    752         SSL_get_servername(ssl.get(), TLSEXT_NAMETYPE_host_name);
    753       if (server_name != config->expected_server_name) {
    754         fprintf(stderr, "servername mismatch (got %s; want %s)\n",
    755                 server_name, config->expected_server_name.c_str());
    756         return false;
    757       }
    758     }
    759 
    760     if (!config->expected_certificate_types.empty()) {
    761       uint8_t *certificate_types;
    762       int num_certificate_types =
    763         SSL_get0_certificate_types(ssl.get(), &certificate_types);
    764       if (num_certificate_types !=
    765           (int)config->expected_certificate_types.size() ||
    766           memcmp(certificate_types,
    767                  config->expected_certificate_types.data(),
    768                  num_certificate_types) != 0) {
    769         fprintf(stderr, "certificate types mismatch\n");
    770         return false;
    771       }
    772     }
    773 
    774     if (!config->expected_next_proto.empty()) {
    775       const uint8_t *next_proto;
    776       unsigned next_proto_len;
    777       SSL_get0_next_proto_negotiated(ssl.get(), &next_proto, &next_proto_len);
    778       if (next_proto_len != config->expected_next_proto.size() ||
    779           memcmp(next_proto, config->expected_next_proto.data(),
    780                  next_proto_len) != 0) {
    781         fprintf(stderr, "negotiated next proto mismatch\n");
    782         return false;
    783       }
    784     }
    785 
    786     if (!config->expected_alpn.empty()) {
    787       const uint8_t *alpn_proto;
    788       unsigned alpn_proto_len;
    789       SSL_get0_alpn_selected(ssl.get(), &alpn_proto, &alpn_proto_len);
    790       if (alpn_proto_len != config->expected_alpn.size() ||
    791           memcmp(alpn_proto, config->expected_alpn.data(),
    792                  alpn_proto_len) != 0) {
    793         fprintf(stderr, "negotiated alpn proto mismatch\n");
    794         return false;
    795       }
    796     }
    797 
    798     if (!config->expected_channel_id.empty()) {
    799       uint8_t channel_id[64];
    800       if (!SSL_get_tls_channel_id(ssl.get(), channel_id, sizeof(channel_id))) {
    801         fprintf(stderr, "no channel id negotiated\n");
    802         return false;
    803       }
    804       if (config->expected_channel_id.size() != 64 ||
    805           memcmp(config->expected_channel_id.data(),
    806                  channel_id, 64) != 0) {
    807         fprintf(stderr, "channel id mismatch\n");
    808         return false;
    809       }
    810     }
    811 
    812     if (config->expect_extended_master_secret) {
    813       if (!ssl->session->extended_master_secret) {
    814         fprintf(stderr, "No EMS for session when expected");
    815         return false;
    816       }
    817     }
    818 
    819     if (!config->expected_ocsp_response.empty()) {
    820       const uint8_t *data;
    821       size_t len;
    822       SSL_get0_ocsp_response(ssl.get(), &data, &len);
    823       if (config->expected_ocsp_response.size() != len ||
    824           memcmp(config->expected_ocsp_response.data(), data, len) != 0) {
    825         fprintf(stderr, "OCSP response mismatch\n");
    826         return false;
    827       }
    828     }
    829 
    830     if (!config->expected_signed_cert_timestamps.empty()) {
    831       const uint8_t *data;
    832       size_t len;
    833       SSL_get0_signed_cert_timestamp_list(ssl.get(), &data, &len);
    834       if (config->expected_signed_cert_timestamps.size() != len ||
    835           memcmp(config->expected_signed_cert_timestamps.data(),
    836                  data, len) != 0) {
    837         fprintf(stderr, "SCT list mismatch\n");
    838         return false;
    839       }
    840     }
    841   }
    842 
    843   if (config->export_keying_material > 0) {
    844     std::vector<uint8_t> result(
    845         static_cast<size_t>(config->export_keying_material));
    846     if (!SSL_export_keying_material(
    847             ssl.get(), result.data(), result.size(),
    848             config->export_label.data(), config->export_label.size(),
    849             reinterpret_cast<const uint8_t*>(config->export_context.data()),
    850             config->export_context.size(), config->use_export_context)) {
    851       fprintf(stderr, "failed to export keying material\n");
    852       return false;
    853     }
    854     if (WriteAll(ssl.get(), result.data(), result.size()) < 0) {
    855       return false;
    856     }
    857   }
    858 
    859   if (config->tls_unique) {
    860     uint8_t tls_unique[16];
    861     size_t tls_unique_len;
    862     if (!SSL_get_tls_unique(ssl.get(), tls_unique, &tls_unique_len,
    863                             sizeof(tls_unique))) {
    864       fprintf(stderr, "failed to get tls-unique\n");
    865       return false;
    866     }
    867 
    868     if (tls_unique_len != 12) {
    869       fprintf(stderr, "expected 12 bytes of tls-unique but got %u",
    870               static_cast<unsigned>(tls_unique_len));
    871       return false;
    872     }
    873 
    874     if (WriteAll(ssl.get(), tls_unique, tls_unique_len) < 0) {
    875       return false;
    876     }
    877   }
    878 
    879   if (config->write_different_record_sizes) {
    880     if (config->is_dtls) {
    881       fprintf(stderr, "write_different_record_sizes not supported for DTLS\n");
    882       return false;
    883     }
    884     // This mode writes a number of different record sizes in an attempt to
    885     // trip up the CBC record splitting code.
    886     uint8_t buf[32769];
    887     memset(buf, 0x42, sizeof(buf));
    888     static const size_t kRecordSizes[] = {
    889         0, 1, 255, 256, 257, 16383, 16384, 16385, 32767, 32768, 32769};
    890     for (size_t i = 0; i < sizeof(kRecordSizes) / sizeof(kRecordSizes[0]);
    891          i++) {
    892       const size_t len = kRecordSizes[i];
    893       if (len > sizeof(buf)) {
    894         fprintf(stderr, "Bad kRecordSizes value.\n");
    895         return false;
    896       }
    897       if (WriteAll(ssl.get(), buf, len) < 0) {
    898         return false;
    899       }
    900     }
    901   } else {
    902     if (config->shim_writes_first) {
    903       if (WriteAll(ssl.get(), reinterpret_cast<const uint8_t *>("hello"),
    904                    5) < 0) {
    905         return false;
    906       }
    907     }
    908     for (;;) {
    909       uint8_t buf[512];
    910       int n = DoRead(ssl.get(), buf, sizeof(buf));
    911       int err = SSL_get_error(ssl.get(), n);
    912       if (err == SSL_ERROR_ZERO_RETURN ||
    913           (n == 0 && err == SSL_ERROR_SYSCALL)) {
    914         if (n != 0) {
    915           fprintf(stderr, "Invalid SSL_get_error output\n");
    916           return false;
    917         }
    918         // Accept shutdowns with or without close_notify.
    919         // TODO(davidben): Write tests which distinguish these two cases.
    920         break;
    921       } else if (err != SSL_ERROR_NONE) {
    922         if (n > 0) {
    923           fprintf(stderr, "Invalid SSL_get_error output\n");
    924           return false;
    925         }
    926         return false;
    927       }
    928       // Successfully read data.
    929       if (n <= 0) {
    930         fprintf(stderr, "Invalid SSL_get_error output\n");
    931         return false;
    932       }
    933 
    934       // After a successful read, with or without False Start, the handshake
    935       // must be complete.
    936       if (!GetTestState(ssl.get())->handshake_done) {
    937         fprintf(stderr, "handshake was not completed after SSL_read\n");
    938         return false;
    939       }
    940 
    941       for (int i = 0; i < n; i++) {
    942         buf[i] ^= 0xff;
    943       }
    944       if (WriteAll(ssl.get(), buf, n) < 0) {
    945         return false;
    946       }
    947     }
    948   }
    949 
    950   if (out_session) {
    951     out_session->reset(SSL_get1_session(ssl.get()));
    952   }
    953 
    954   SSL_shutdown(ssl.get());
    955   return true;
    956 }
    957 
    958 int main(int argc, char **argv) {
    959 #if defined(OPENSSL_WINDOWS)
    960   /* Initialize Winsock. */
    961   WORD wsa_version = MAKEWORD(2, 2);
    962   WSADATA wsa_data;
    963   int wsa_err = WSAStartup(wsa_version, &wsa_data);
    964   if (wsa_err != 0) {
    965     fprintf(stderr, "WSAStartup failed: %d\n", wsa_err);
    966     return 1;
    967   }
    968   if (wsa_data.wVersion != wsa_version) {
    969     fprintf(stderr, "Didn't get expected version: %x\n", wsa_data.wVersion);
    970     return 1;
    971   }
    972 #else
    973   signal(SIGPIPE, SIG_IGN);
    974 #endif
    975 
    976   if (!SSL_library_init()) {
    977     return 1;
    978   }
    979   g_config_index = SSL_get_ex_new_index(0, NULL, NULL, NULL, NULL);
    980   g_state_index = SSL_get_ex_new_index(0, NULL, NULL, NULL, TestStateExFree);
    981   if (g_config_index < 0 || g_state_index < 0) {
    982     return 1;
    983   }
    984 
    985   TestConfig config;
    986   if (!ParseConfig(argc - 1, argv + 1, &config)) {
    987     return Usage(argv[0]);
    988   }
    989 
    990   ScopedSSL_CTX ssl_ctx = SetupCtx(&config);
    991   if (!ssl_ctx) {
    992     ERR_print_errors_fp(stderr);
    993     return 1;
    994   }
    995 
    996   ScopedSSL_SESSION session;
    997   if (!DoExchange(&session, ssl_ctx.get(), &config, false /* is_resume */,
    998                   NULL /* session */)) {
    999     ERR_print_errors_fp(stderr);
   1000     return 1;
   1001   }
   1002 
   1003   if (config.resume &&
   1004       !DoExchange(NULL, ssl_ctx.get(), &config, true /* is_resume */,
   1005                   session.get())) {
   1006     ERR_print_errors_fp(stderr);
   1007     return 1;
   1008   }
   1009 
   1010   return 0;
   1011 }
   1012