Home | History | Annotate | Download | only in handshaker
      1 /*
      2  *
      3  * Copyright 2018 gRPC authors.
      4  *
      5  * Licensed under the Apache License, Version 2.0 (the "License");
      6  * you may not use this file except in compliance with the License.
      7  * You may obtain a copy of the License at
      8  *
      9  *     http://www.apache.org/licenses/LICENSE-2.0
     10  *
     11  * Unless required by applicable law or agreed to in writing, software
     12  * distributed under the License is distributed on an "AS IS" BASIS,
     13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     14  * See the License for the specific language governing permissions and
     15  * limitations under the License.
     16  *
     17  */
     18 
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <grpc/grpc.h>
#include <grpc/support/sync.h>

#include "src/core/lib/gprpp/thd.h"
#include "src/core/tsi/alts/handshaker/alts_handshaker_client.h"
#include "src/core/tsi/alts/handshaker/alts_tsi_event.h"
#include "src/core/tsi/alts/handshaker/alts_tsi_handshaker.h"
#include "src/core/tsi/alts/handshaker/alts_tsi_handshaker_private.h"
#include "test/core/tsi/alts/handshaker/alts_handshaker_service_api_test_lib.h"
     31 
     32 #define ALTS_TSI_HANDSHAKER_TEST_RECV_BYTES "Hello World"
     33 #define ALTS_TSI_HANDSHAKER_TEST_OUT_FRAME "Hello Google"
     34 #define ALTS_TSI_HANDSHAKER_TEST_CONSUMED_BYTES "Hello "
     35 #define ALTS_TSI_HANDSHAKER_TEST_REMAIN_BYTES "Google"
     36 #define ALTS_TSI_HANDSHAKER_TEST_PEER_IDENTITY "chapi (at) service.google.com"
     37 #define ALTS_TSI_HANDSHAKER_TEST_KEY_DATA \
     38   "ABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKL"
     39 #define ALTS_TSI_HANDSHAKER_TEST_BUFFER_SIZE 100
     40 #define ALTS_TSI_HANDSHAKER_TEST_SLEEP_TIME_IN_SECONDS 2
     41 #define ALTS_TSI_HANDSHAKER_TEST_MAX_RPC_VERSION_MAJOR 3
     42 #define ALTS_TSI_HANDSHAKER_TEST_MAX_RPC_VERSION_MINOR 2
     43 #define ALTS_TSI_HANDSHAKER_TEST_MIN_RPC_VERSION_MAJOR 2
     44 #define ALTS_TSI_HANDSHAKER_TEST_MIN_RPC_VERSION_MINOR 1
     45 
     46 using grpc_core::internal::
     47     alts_tsi_handshaker_get_has_sent_start_message_for_testing;
     48 using grpc_core::internal::alts_tsi_handshaker_get_is_client_for_testing;
     49 using grpc_core::internal::alts_tsi_handshaker_get_recv_bytes_for_testing;
     50 using grpc_core::internal::alts_tsi_handshaker_set_client_for_testing;
     51 using grpc_core::internal::alts_tsi_handshaker_set_recv_bytes_for_testing;
     52 
     53 /* ALTS mock notification. */
     54 typedef struct notification {
     55   gpr_cv cv;
     56   gpr_mu mu;
     57   bool notified;
     58 } notification;
     59 
     60 /* ALTS mock handshaker client. */
     61 typedef struct alts_mock_handshaker_client {
     62   alts_handshaker_client base;
     63   bool used_for_success_test;
     64 } alts_mock_handshaker_client;
     65 
     66 /* Type of ALTS handshaker response. */
     67 typedef enum {
     68   INVALID,
     69   FAILED,
     70   CLIENT_START,
     71   SERVER_START,
     72   CLIENT_NEXT,
     73   SERVER_NEXT,
     74 } alts_handshaker_response_type;
     75 
     76 static alts_tsi_event* client_start_event;
     77 static alts_tsi_event* client_next_event;
     78 static alts_tsi_event* server_start_event;
     79 static alts_tsi_event* server_next_event;
     80 static notification caller_to_tsi_notification;
     81 static notification tsi_to_caller_notification;
     82 
     83 static void notification_init(notification* n) {
     84   gpr_mu_init(&n->mu);
     85   gpr_cv_init(&n->cv);
     86   n->notified = false;
     87 }
     88 
     89 static void notification_destroy(notification* n) {
     90   gpr_mu_destroy(&n->mu);
     91   gpr_cv_destroy(&n->cv);
     92 }
     93 
     94 static void signal(notification* n) {
     95   gpr_mu_lock(&n->mu);
     96   n->notified = true;
     97   gpr_cv_signal(&n->cv);
     98   gpr_mu_unlock(&n->mu);
     99 }
    100 
    101 static void wait(notification* n) {
    102   gpr_mu_lock(&n->mu);
    103   while (!n->notified) {
    104     gpr_cv_wait(&n->cv, &n->mu, gpr_inf_future(GPR_CLOCK_REALTIME));
    105   }
    106   n->notified = false;
    107   gpr_mu_unlock(&n->mu);
    108 }
    109 
    110 /**
    111  * This method mocks ALTS handshaker service to generate handshaker response
    112  * for a specific request.
    113  */
    114 static grpc_byte_buffer* generate_handshaker_response(
    115     alts_handshaker_response_type type) {
    116   grpc_gcp_handshaker_resp* resp = grpc_gcp_handshaker_resp_create();
    117   GPR_ASSERT(grpc_gcp_handshaker_resp_set_code(resp, 0));
    118   switch (type) {
    119     case INVALID:
    120       break;
    121     case CLIENT_START:
    122     case SERVER_START:
    123       GPR_ASSERT(grpc_gcp_handshaker_resp_set_out_frames(
    124           resp, ALTS_TSI_HANDSHAKER_TEST_OUT_FRAME,
    125           strlen(ALTS_TSI_HANDSHAKER_TEST_OUT_FRAME)));
    126       break;
    127     case CLIENT_NEXT:
    128       GPR_ASSERT(grpc_gcp_handshaker_resp_set_out_frames(
    129           resp, ALTS_TSI_HANDSHAKER_TEST_OUT_FRAME,
    130           strlen(ALTS_TSI_HANDSHAKER_TEST_OUT_FRAME)));
    131       GPR_ASSERT(grpc_gcp_handshaker_resp_set_peer_identity_service_account(
    132           resp, ALTS_TSI_HANDSHAKER_TEST_PEER_IDENTITY));
    133       GPR_ASSERT(grpc_gcp_handshaker_resp_set_bytes_consumed(
    134           resp, strlen(ALTS_TSI_HANDSHAKER_TEST_CONSUMED_BYTES)));
    135       GPR_ASSERT(grpc_gcp_handshaker_resp_set_key_data(
    136           resp, ALTS_TSI_HANDSHAKER_TEST_KEY_DATA,
    137           strlen(ALTS_TSI_HANDSHAKER_TEST_KEY_DATA)));
    138       GPR_ASSERT(grpc_gcp_handshaker_resp_set_peer_rpc_versions(
    139           resp, ALTS_TSI_HANDSHAKER_TEST_MAX_RPC_VERSION_MAJOR,
    140           ALTS_TSI_HANDSHAKER_TEST_MAX_RPC_VERSION_MINOR,
    141           ALTS_TSI_HANDSHAKER_TEST_MIN_RPC_VERSION_MAJOR,
    142           ALTS_TSI_HANDSHAKER_TEST_MIN_RPC_VERSION_MINOR));
    143       break;
    144     case SERVER_NEXT:
    145       GPR_ASSERT(grpc_gcp_handshaker_resp_set_peer_identity_service_account(
    146           resp, ALTS_TSI_HANDSHAKER_TEST_PEER_IDENTITY));
    147       GPR_ASSERT(grpc_gcp_handshaker_resp_set_bytes_consumed(
    148           resp, strlen(ALTS_TSI_HANDSHAKER_TEST_OUT_FRAME)));
    149       GPR_ASSERT(grpc_gcp_handshaker_resp_set_key_data(
    150           resp, ALTS_TSI_HANDSHAKER_TEST_KEY_DATA,
    151           strlen(ALTS_TSI_HANDSHAKER_TEST_KEY_DATA)));
    152       GPR_ASSERT(grpc_gcp_handshaker_resp_set_peer_rpc_versions(
    153           resp, ALTS_TSI_HANDSHAKER_TEST_MAX_RPC_VERSION_MAJOR,
    154           ALTS_TSI_HANDSHAKER_TEST_MAX_RPC_VERSION_MINOR,
    155           ALTS_TSI_HANDSHAKER_TEST_MIN_RPC_VERSION_MAJOR,
    156           ALTS_TSI_HANDSHAKER_TEST_MIN_RPC_VERSION_MINOR));
    157       break;
    158     case FAILED:
    159       GPR_ASSERT(
    160           grpc_gcp_handshaker_resp_set_code(resp, 3 /* INVALID ARGUMENT */));
    161       break;
    162   }
    163   grpc_slice slice;
    164   GPR_ASSERT(grpc_gcp_handshaker_resp_encode(resp, &slice));
    165   if (type == INVALID) {
    166     grpc_slice bad_slice =
    167         grpc_slice_split_head(&slice, GRPC_SLICE_LENGTH(slice) - 1);
    168     grpc_slice_unref(slice);
    169     slice = grpc_slice_ref(bad_slice);
    170     grpc_slice_unref(bad_slice);
    171   }
    172   grpc_byte_buffer* buffer =
    173       grpc_raw_byte_buffer_create(&slice, 1 /* number of slices */);
    174   grpc_slice_unref(slice);
    175   grpc_gcp_handshaker_resp_destroy(resp);
    176   return buffer;
    177 }
    178 
    179 static void check_must_not_be_called(tsi_result status, void* user_data,
    180                                      const unsigned char* bytes_to_send,
    181                                      size_t bytes_to_send_size,
    182                                      tsi_handshaker_result* result) {
    183   GPR_ASSERT(0);
    184 }
    185 
    186 static void on_client_start_success_cb(tsi_result status, void* user_data,
    187                                        const unsigned char* bytes_to_send,
    188                                        size_t bytes_to_send_size,
    189                                        tsi_handshaker_result* result) {
    190   GPR_ASSERT(status == TSI_OK);
    191   GPR_ASSERT(user_data == nullptr);
    192   GPR_ASSERT(bytes_to_send_size == strlen(ALTS_TSI_HANDSHAKER_TEST_OUT_FRAME));
    193   GPR_ASSERT(memcmp(bytes_to_send, ALTS_TSI_HANDSHAKER_TEST_OUT_FRAME,
    194                     bytes_to_send_size) == 0);
    195   GPR_ASSERT(result == nullptr);
    196   /* Validate peer identity. */
    197   tsi_peer peer;
    198   GPR_ASSERT(tsi_handshaker_result_extract_peer(result, &peer) ==
    199              TSI_INVALID_ARGUMENT);
    200   /* Validate frame protector. */
    201   tsi_frame_protector* protector = nullptr;
    202   GPR_ASSERT(tsi_handshaker_result_create_frame_protector(
    203                  result, nullptr, &protector) == TSI_INVALID_ARGUMENT);
    204   /* Validate unused bytes. */
    205   const unsigned char* unused_bytes = nullptr;
    206   size_t unused_bytes_size = 0;
    207   GPR_ASSERT(tsi_handshaker_result_get_unused_bytes(result, &unused_bytes,
    208                                                     &unused_bytes_size) ==
    209              TSI_INVALID_ARGUMENT);
    210   signal(&tsi_to_caller_notification);
    211 }
    212 
    213 static void on_server_start_success_cb(tsi_result status, void* user_data,
    214                                        const unsigned char* bytes_to_send,
    215                                        size_t bytes_to_send_size,
    216                                        tsi_handshaker_result* result) {
    217   GPR_ASSERT(status == TSI_OK);
    218   GPR_ASSERT(user_data == nullptr);
    219   GPR_ASSERT(bytes_to_send_size == strlen(ALTS_TSI_HANDSHAKER_TEST_OUT_FRAME));
    220   GPR_ASSERT(memcmp(bytes_to_send, ALTS_TSI_HANDSHAKER_TEST_OUT_FRAME,
    221                     bytes_to_send_size) == 0);
    222   GPR_ASSERT(result == nullptr);
    223   /* Validate peer identity. */
    224   tsi_peer peer;
    225   GPR_ASSERT(tsi_handshaker_result_extract_peer(result, &peer) ==
    226              TSI_INVALID_ARGUMENT);
    227   /* Validate frame protector. */
    228   tsi_frame_protector* protector = nullptr;
    229   GPR_ASSERT(tsi_handshaker_result_create_frame_protector(
    230                  result, nullptr, &protector) == TSI_INVALID_ARGUMENT);
    231   /* Validate unused bytes. */
    232   const unsigned char* unused_bytes = nullptr;
    233   size_t unused_bytes_size = 0;
    234   GPR_ASSERT(tsi_handshaker_result_get_unused_bytes(result, &unused_bytes,
    235                                                     &unused_bytes_size) ==
    236              TSI_INVALID_ARGUMENT);
    237   signal(&tsi_to_caller_notification);
    238 }
    239 
    240 static void on_client_next_success_cb(tsi_result status, void* user_data,
    241                                       const unsigned char* bytes_to_send,
    242                                       size_t bytes_to_send_size,
    243                                       tsi_handshaker_result* result) {
    244   GPR_ASSERT(status == TSI_OK);
    245   GPR_ASSERT(user_data == nullptr);
    246   GPR_ASSERT(bytes_to_send_size == strlen(ALTS_TSI_HANDSHAKER_TEST_OUT_FRAME));
    247   GPR_ASSERT(memcmp(bytes_to_send, ALTS_TSI_HANDSHAKER_TEST_OUT_FRAME,
    248                     bytes_to_send_size) == 0);
    249   GPR_ASSERT(result != nullptr);
    250   /* Validate peer identity. */
    251   tsi_peer peer;
    252   GPR_ASSERT(tsi_handshaker_result_extract_peer(result, &peer) == TSI_OK);
    253   GPR_ASSERT(peer.property_count == kTsiAltsNumOfPeerProperties);
    254   GPR_ASSERT(memcmp(TSI_ALTS_CERTIFICATE_TYPE, peer.properties[0].value.data,
    255                     peer.properties[0].value.length) == 0);
    256   GPR_ASSERT(memcmp(ALTS_TSI_HANDSHAKER_TEST_PEER_IDENTITY,
    257                     peer.properties[1].value.data,
    258                     peer.properties[1].value.length) == 0);
    259   tsi_peer_destruct(&peer);
    260   /* Validate unused bytes. */
    261   const unsigned char* bytes = nullptr;
    262   size_t bytes_size = 0;
    263   GPR_ASSERT(tsi_handshaker_result_get_unused_bytes(result, &bytes,
    264                                                     &bytes_size) == TSI_OK);
    265   GPR_ASSERT(bytes_size == strlen(ALTS_TSI_HANDSHAKER_TEST_REMAIN_BYTES));
    266   GPR_ASSERT(memcmp(bytes, ALTS_TSI_HANDSHAKER_TEST_REMAIN_BYTES, bytes_size) ==
    267              0);
    268   /* Validate frame protector. */
    269   tsi_frame_protector* protector = nullptr;
    270   GPR_ASSERT(tsi_handshaker_result_create_frame_protector(
    271                  result, nullptr, &protector) == TSI_OK);
    272   GPR_ASSERT(protector != nullptr);
    273   tsi_frame_protector_destroy(protector);
    274   tsi_handshaker_result_destroy(result);
    275   signal(&tsi_to_caller_notification);
    276 }
    277 
    278 static void on_server_next_success_cb(tsi_result status, void* user_data,
    279                                       const unsigned char* bytes_to_send,
    280                                       size_t bytes_to_send_size,
    281                                       tsi_handshaker_result* result) {
    282   GPR_ASSERT(status == TSI_OK);
    283   GPR_ASSERT(user_data == nullptr);
    284   GPR_ASSERT(bytes_to_send_size == 0);
    285   GPR_ASSERT(bytes_to_send == nullptr);
    286   GPR_ASSERT(result != nullptr);
    287   /* Validate peer identity. */
    288   tsi_peer peer;
    289   GPR_ASSERT(tsi_handshaker_result_extract_peer(result, &peer) == TSI_OK);
    290   GPR_ASSERT(peer.property_count == kTsiAltsNumOfPeerProperties);
    291   GPR_ASSERT(memcmp(TSI_ALTS_CERTIFICATE_TYPE, peer.properties[0].value.data,
    292                     peer.properties[0].value.length) == 0);
    293   GPR_ASSERT(memcmp(ALTS_TSI_HANDSHAKER_TEST_PEER_IDENTITY,
    294                     peer.properties[1].value.data,
    295                     peer.properties[1].value.length) == 0);
    296   tsi_peer_destruct(&peer);
    297   /* Validate unused bytes. */
    298   const unsigned char* bytes = nullptr;
    299   size_t bytes_size = 0;
    300   GPR_ASSERT(tsi_handshaker_result_get_unused_bytes(result, &bytes,
    301                                                     &bytes_size) == TSI_OK);
    302   GPR_ASSERT(bytes_size == 0);
    303   GPR_ASSERT(bytes == nullptr);
    304   /* Validate frame protector. */
    305   tsi_frame_protector* protector = nullptr;
    306   GPR_ASSERT(tsi_handshaker_result_create_frame_protector(
    307                  result, nullptr, &protector) == TSI_OK);
    308   GPR_ASSERT(protector != nullptr);
    309   tsi_frame_protector_destroy(protector);
    310   tsi_handshaker_result_destroy(result);
    311   signal(&tsi_to_caller_notification);
    312 }
    313 
    314 static tsi_result mock_client_start(alts_handshaker_client* self,
    315                                     alts_tsi_event* event) {
    316   alts_mock_handshaker_client* client =
    317       reinterpret_cast<alts_mock_handshaker_client*>(self);
    318   if (!client->used_for_success_test) {
    319     alts_tsi_event_destroy(event);
    320     return TSI_INTERNAL_ERROR;
    321   }
    322   GPR_ASSERT(event->cb == on_client_start_success_cb);
    323   GPR_ASSERT(event->user_data == nullptr);
    324   GPR_ASSERT(!alts_tsi_handshaker_get_has_sent_start_message_for_testing(
    325       event->handshaker));
    326   /* Populate handshaker response for client_start request. */
    327   event->recv_buffer = generate_handshaker_response(CLIENT_START);
    328   client_start_event = event;
    329   signal(&caller_to_tsi_notification);
    330   return TSI_OK;
    331 }
    332 
    333 static void mock_shutdown(alts_handshaker_client* self) {}
    334 
    335 static tsi_result mock_server_start(alts_handshaker_client* self,
    336                                     alts_tsi_event* event,
    337                                     grpc_slice* bytes_received) {
    338   alts_mock_handshaker_client* client =
    339       reinterpret_cast<alts_mock_handshaker_client*>(self);
    340   if (!client->used_for_success_test) {
    341     alts_tsi_event_destroy(event);
    342     return TSI_INTERNAL_ERROR;
    343   }
    344   GPR_ASSERT(event->cb == on_server_start_success_cb);
    345   GPR_ASSERT(event->user_data == nullptr);
    346   grpc_slice slice = grpc_empty_slice();
    347   GPR_ASSERT(grpc_slice_cmp(*bytes_received, slice) == 0);
    348   GPR_ASSERT(!alts_tsi_handshaker_get_has_sent_start_message_for_testing(
    349       event->handshaker));
    350   /* Populate handshaker response for server_start request. */
    351   event->recv_buffer = generate_handshaker_response(SERVER_START);
    352   server_start_event = event;
    353   grpc_slice_unref(slice);
    354   signal(&caller_to_tsi_notification);
    355   return TSI_OK;
    356 }
    357 
    358 static tsi_result mock_next(alts_handshaker_client* self, alts_tsi_event* event,
    359                             grpc_slice* bytes_received) {
    360   alts_mock_handshaker_client* client =
    361       reinterpret_cast<alts_mock_handshaker_client*>(self);
    362   if (!client->used_for_success_test) {
    363     alts_tsi_event_destroy(event);
    364     return TSI_INTERNAL_ERROR;
    365   }
    366   bool is_client =
    367       alts_tsi_handshaker_get_is_client_for_testing(event->handshaker);
    368   if (is_client) {
    369     GPR_ASSERT(event->cb == on_client_next_success_cb);
    370   } else {
    371     GPR_ASSERT(event->cb == on_server_next_success_cb);
    372   }
    373   GPR_ASSERT(event->user_data == nullptr);
    374   GPR_ASSERT(bytes_received != nullptr);
    375   GPR_ASSERT(memcmp(GRPC_SLICE_START_PTR(*bytes_received),
    376                     ALTS_TSI_HANDSHAKER_TEST_RECV_BYTES,
    377                     GRPC_SLICE_LENGTH(*bytes_received)) == 0);
    378   GPR_ASSERT(grpc_slice_cmp(alts_tsi_handshaker_get_recv_bytes_for_testing(
    379                                 event->handshaker),
    380                             *bytes_received) == 0);
    381   GPR_ASSERT(alts_tsi_handshaker_get_has_sent_start_message_for_testing(
    382       event->handshaker));
    383   /* Populate handshaker response for next request. */
    384   grpc_slice out_frame =
    385       grpc_slice_from_static_string(ALTS_TSI_HANDSHAKER_TEST_OUT_FRAME);
    386   if (is_client) {
    387     event->recv_buffer = generate_handshaker_response(CLIENT_NEXT);
    388   } else {
    389     event->recv_buffer = generate_handshaker_response(SERVER_NEXT);
    390   }
    391   alts_tsi_handshaker_set_recv_bytes_for_testing(event->handshaker, &out_frame);
    392   if (is_client) {
    393     client_next_event = event;
    394   } else {
    395     server_next_event = event;
    396   }
    397   signal(&caller_to_tsi_notification);
    398   grpc_slice_unref(out_frame);
    399   return TSI_OK;
    400 }
    401 
    402 static void mock_destruct(alts_handshaker_client* client) {}
    403 
    404 static const alts_handshaker_client_vtable vtable = {
    405     mock_client_start, mock_server_start, mock_next, mock_shutdown,
    406     mock_destruct};
    407 
    408 static alts_handshaker_client* alts_mock_handshaker_client_create(
    409     bool used_for_success_test) {
    410   alts_mock_handshaker_client* client =
    411       static_cast<alts_mock_handshaker_client*>(gpr_zalloc(sizeof(*client)));
    412   client->base.vtable = &vtable;
    413   client->used_for_success_test = used_for_success_test;
    414   return &client->base;
    415 }
    416 
    417 static tsi_handshaker* create_test_handshaker(bool used_for_success_test,
    418                                               bool is_client) {
    419   tsi_handshaker* handshaker = nullptr;
    420   alts_handshaker_client* client =
    421       alts_mock_handshaker_client_create(used_for_success_test);
    422   grpc_alts_credentials_options* options =
    423       grpc_alts_credentials_client_options_create();
    424   alts_tsi_handshaker_create(options, "target_name", "lame", is_client,
    425                              &handshaker);
    426   alts_tsi_handshaker* alts_handshaker =
    427       reinterpret_cast<alts_tsi_handshaker*>(handshaker);
    428   alts_tsi_handshaker_set_client_for_testing(alts_handshaker, client);
    429   grpc_alts_credentials_options_destroy(options);
    430   return handshaker;
    431 }
    432 
/* Verifies that tsi_handshaker_next() returns TSI_INVALID_ARGUMENT for a null
 * handshaker and for a null completion callback. In the first case
 * check_must_not_be_called aborts the test if it is ever invoked. */
static void check_handshaker_next_invalid_input() {
  /* Initialization. */
  tsi_handshaker* handshaker = create_test_handshaker(true, true);
  /* Check nullptr handshaker. */
  GPR_ASSERT(tsi_handshaker_next(nullptr, nullptr, 0, nullptr, nullptr, nullptr,
                                 check_must_not_be_called,
                                 nullptr) == TSI_INVALID_ARGUMENT);
  /* Check nullptr callback. */
  GPR_ASSERT(tsi_handshaker_next(handshaker, nullptr, 0, nullptr, nullptr,
                                 nullptr, nullptr,
                                 nullptr) == TSI_INVALID_ARGUMENT);
  /* Cleanup. */
  tsi_handshaker_destroy(handshaker);
}
    447 
    448 static void check_handshaker_shutdown_invalid_input() {
    449   /* Initialization. */
    450   tsi_handshaker* handshaker = create_test_handshaker(
    451       false /* used_for_success_test */, true /* is_client */);
    452   /* Check nullptr handshaker. */
    453   tsi_handshaker_shutdown(nullptr);
    454   /* Cleanup. */
    455   tsi_handshaker_destroy(handshaker);
    456 }
    457 
/* Drives a full, successful handshake on both sides through the mock client,
 * in this order: client start, client next, server start, server next. Each
 * tsi_handshaker_next() call must return TSI_ASYNC. The test then blocks until
 * the matching *_success_cb has validated the outcome and signalled
 * tsi_to_caller_notification. NOTE(review): the mock events appear to be
 * dispatched elsewhere (check_handle_response_success presumably runs on
 * another thread) — confirm against the caller. */
static void check_handshaker_next_success() {
  /**
   * Create handshakers for which internal mock client is going to do
   * correctness check.
   */
  tsi_handshaker* client_handshaker = create_test_handshaker(
      true /* used_for_success_test */, true /* is_client */);
  tsi_handshaker* server_handshaker = create_test_handshaker(
      true /* used_for_success_test */, false /* is_client */);
  /* Client start. */
  GPR_ASSERT(tsi_handshaker_next(client_handshaker, nullptr, 0, nullptr,
                                 nullptr, nullptr, on_client_start_success_cb,
                                 nullptr) == TSI_ASYNC);
  /* Block until the completion callback has run. */
  wait(&tsi_to_caller_notification);
  /* Client next. */
  GPR_ASSERT(tsi_handshaker_next(
                 client_handshaker,
                 (const unsigned char*)ALTS_TSI_HANDSHAKER_TEST_RECV_BYTES,
                 strlen(ALTS_TSI_HANDSHAKER_TEST_RECV_BYTES), nullptr, nullptr,
                 nullptr, on_client_next_success_cb, nullptr) == TSI_ASYNC);
  wait(&tsi_to_caller_notification);
  /* Server start. */
  GPR_ASSERT(tsi_handshaker_next(server_handshaker, nullptr, 0, nullptr,
                                 nullptr, nullptr, on_server_start_success_cb,
                                 nullptr) == TSI_ASYNC);
  wait(&tsi_to_caller_notification);
  /* Server next. */
  GPR_ASSERT(tsi_handshaker_next(
                 server_handshaker,
                 (const unsigned char*)ALTS_TSI_HANDSHAKER_TEST_RECV_BYTES,
                 strlen(ALTS_TSI_HANDSHAKER_TEST_RECV_BYTES), nullptr, nullptr,
                 nullptr, on_server_next_success_cb, nullptr) == TSI_ASYNC);
  wait(&tsi_to_caller_notification);
  /* Cleanup. */
  tsi_handshaker_destroy(server_handshaker);
  tsi_handshaker_destroy(client_handshaker);
}
    495 
/* Verifies that once a client handshaker has been shut down, a further
 * tsi_handshaker_next() call fails synchronously with TSI_HANDSHAKE_SHUTDOWN.
 * The first call, made before shutdown, must still complete asynchronously. */
static void check_handshaker_next_with_shutdown() {
  /* Initialization. */
  tsi_handshaker* handshaker = create_test_handshaker(
      true /* used_for_success_test */, true /* is_client*/);
  /* next(success) -- shutdown(success) -- next (fail) */
  GPR_ASSERT(tsi_handshaker_next(handshaker, nullptr, 0, nullptr, nullptr,
                                 nullptr, on_client_start_success_cb,
                                 nullptr) == TSI_ASYNC);
  /* Block until on_client_start_success_cb has run. */
  wait(&tsi_to_caller_notification);
  tsi_handshaker_shutdown(handshaker);
  GPR_ASSERT(tsi_handshaker_next(
                 handshaker,
                 (const unsigned char*)ALTS_TSI_HANDSHAKER_TEST_RECV_BYTES,
                 strlen(ALTS_TSI_HANDSHAKER_TEST_RECV_BYTES), nullptr, nullptr,
                 nullptr, on_client_next_success_cb,
                 nullptr) == TSI_HANDSHAKE_SHUTDOWN);
  /* Cleanup. */
  tsi_handshaker_destroy(handshaker);
}
    515 
    516 static void check_handle_response_with_shutdown(void* unused) {
    517   /* Client start. */
    518   wait(&caller_to_tsi_notification);
    519   alts_tsi_event_dispatch_to_handshaker(client_start_event, true /* is_ok */);
    520   alts_tsi_event_destroy(client_start_event);
    521 }
    522 
/* Verifies that when every mock client operation fails, each
 * tsi_handshaker_next() call (start and next, on both sides) returns
 * TSI_INTERNAL_ERROR synchronously. check_must_not_be_called aborts the test
 * if any completion callback is invoked. */
static void check_handshaker_next_failure() {
  /**
   * Create handshakers for which internal mock client is always going to fail.
   */
  tsi_handshaker* client_handshaker = create_test_handshaker(
      false /* used_for_success_test */, true /* is_client */);
  tsi_handshaker* server_handshaker = create_test_handshaker(
      false /* used_for_success_test */, false /* is_client */);
  /* Client start. */
  GPR_ASSERT(tsi_handshaker_next(client_handshaker, nullptr, 0, nullptr,
                                 nullptr, nullptr, check_must_not_be_called,
                                 nullptr) == TSI_INTERNAL_ERROR);
  /* Server start. */
  GPR_ASSERT(tsi_handshaker_next(server_handshaker, nullptr, 0, nullptr,
                                 nullptr, nullptr, check_must_not_be_called,
                                 nullptr) == TSI_INTERNAL_ERROR);
  /* Server next. */
  GPR_ASSERT(tsi_handshaker_next(
                 server_handshaker,
                 (const unsigned char*)ALTS_TSI_HANDSHAKER_TEST_RECV_BYTES,
                 strlen(ALTS_TSI_HANDSHAKER_TEST_RECV_BYTES), nullptr, nullptr,
                 nullptr, check_must_not_be_called,
                 nullptr) == TSI_INTERNAL_ERROR);
  /* Client next. */
  GPR_ASSERT(tsi_handshaker_next(
                 client_handshaker,
                 (const unsigned char*)ALTS_TSI_HANDSHAKER_TEST_RECV_BYTES,
                 strlen(ALTS_TSI_HANDSHAKER_TEST_RECV_BYTES), nullptr, nullptr,
                 nullptr, check_must_not_be_called,
                 nullptr) == TSI_INTERNAL_ERROR);
  /* Cleanup. */
  tsi_handshaker_destroy(server_handshaker);
  tsi_handshaker_destroy(client_handshaker);
}
    557 
/* Expected completion when alts_tsi_handshaker_handle_response() is given a
 * null handshaker or a null receive buffer: TSI_INTERNAL_ERROR, no user data,
 * nothing to send and no handshaker result.
 * NOTE(review): this callback records nothing, so the test cannot detect the
 * case where it is never invoked at all. */
static void on_invalid_input_cb(tsi_result status, void* user_data,
                                const unsigned char* bytes_to_send,
                                size_t bytes_to_send_size,
                                tsi_handshaker_result* result) {
  GPR_ASSERT(status == TSI_INTERNAL_ERROR);
  GPR_ASSERT(user_data == nullptr);
  GPR_ASSERT(bytes_to_send == nullptr);
  GPR_ASSERT(bytes_to_send_size == 0);
  GPR_ASSERT(result == nullptr);
}
    568 
/* Expected completion when the call to the handshaker service failed, either
 * with a non-OK status or with is_ok == false: TSI_INTERNAL_ERROR, no user
 * data, nothing to send and no handshaker result.
 * NOTE(review): this callback records nothing, so the test cannot detect the
 * case where it is never invoked at all. */
static void on_failed_grpc_call_cb(tsi_result status, void* user_data,
                                   const unsigned char* bytes_to_send,
                                   size_t bytes_to_send_size,
                                   tsi_handshaker_result* result) {
  GPR_ASSERT(status == TSI_INTERNAL_ERROR);
  GPR_ASSERT(user_data == nullptr);
  GPR_ASSERT(bytes_to_send == nullptr);
  GPR_ASSERT(bytes_to_send_size == 0);
  GPR_ASSERT(result == nullptr);
}
    579 
    580 static void check_handle_response_invalid_input() {
    581   /**
    582    * Create a handshaker at the client side, for which internal mock client is
    583    * always going to fail.
    584    */
    585   tsi_handshaker* handshaker = create_test_handshaker(
    586       false /* used_for_success_test */, true /* is_client */);
    587   alts_tsi_handshaker* alts_handshaker =
    588       reinterpret_cast<alts_tsi_handshaker*>(handshaker);
    589   grpc_byte_buffer recv_buffer;
    590   /* Check nullptr handshaker. */
    591   alts_tsi_handshaker_handle_response(nullptr, &recv_buffer, GRPC_STATUS_OK,
    592                                       nullptr, on_invalid_input_cb, nullptr,
    593                                       true);
    594   /* Check nullptr recv_bytes. */
    595   alts_tsi_handshaker_handle_response(alts_handshaker, nullptr, GRPC_STATUS_OK,
    596                                       nullptr, on_invalid_input_cb, nullptr,
    597                                       true);
    598   /* Check failed grpc call made to handshaker service. */
    599   alts_tsi_handshaker_handle_response(alts_handshaker, &recv_buffer,
    600                                       GRPC_STATUS_UNKNOWN, nullptr,
    601                                       on_failed_grpc_call_cb, nullptr, true);
    602 
    603   alts_tsi_handshaker_handle_response(alts_handshaker, &recv_buffer,
    604                                       GRPC_STATUS_OK, nullptr,
    605                                       on_failed_grpc_call_cb, nullptr, false);
    606 
    607   /* Cleanup. */
    608   tsi_handshaker_destroy(handshaker);
    609 }
    610 
/* Expected completion when the handshaker service response cannot be decoded
 * (check_handle_response_invalid_resp feeds it a truncated encoding):
 * TSI_DATA_CORRUPTED, no user data, nothing to send and no handshaker result.
 * NOTE(review): this callback records nothing, so the test cannot detect the
 * case where it is never invoked at all. */
static void on_invalid_resp_cb(tsi_result status, void* user_data,
                               const unsigned char* bytes_to_send,
                               size_t bytes_to_send_size,
                               tsi_handshaker_result* result) {
  GPR_ASSERT(status == TSI_DATA_CORRUPTED);
  GPR_ASSERT(user_data == nullptr);
  GPR_ASSERT(bytes_to_send == nullptr);
  GPR_ASSERT(bytes_to_send_size == 0);
  GPR_ASSERT(result == nullptr);
}
    621 
    622 static void check_handle_response_invalid_resp() {
    623   /**
    624    * Create a handshaker at the client side, for which internal mock client is
    625    * always going to fail.
    626    */
    627   tsi_handshaker* handshaker = create_test_handshaker(
    628       false /* used_for_success_test */, true /* is_client */);
    629   alts_tsi_handshaker* alts_handshaker =
    630       reinterpret_cast<alts_tsi_handshaker*>(handshaker);
    631   /* Tests. */
    632   grpc_byte_buffer* recv_buffer = generate_handshaker_response(INVALID);
    633   alts_tsi_handshaker_handle_response(alts_handshaker, recv_buffer,
    634                                       GRPC_STATUS_OK, nullptr,
    635                                       on_invalid_resp_cb, nullptr, true);
    636   /* Cleanup. */
    637   grpc_byte_buffer_destroy(recv_buffer);
    638   tsi_handshaker_destroy(handshaker);
    639 }
    640 
    641 static void check_handle_response_success(void* unused) {
    642   /* Client start. */
    643   wait(&caller_to_tsi_notification);
    644   alts_tsi_event_dispatch_to_handshaker(client_start_event, true /* is_ok */);
    645   alts_tsi_event_destroy(client_start_event);
    646   /* Client next. */
    647   wait(&caller_to_tsi_notification);
    648   alts_tsi_event_dispatch_to_handshaker(client_next_event, true /* is_ok */);
    649   alts_tsi_event_destroy(client_next_event);
    650   /* Server start. */
    651   wait(&caller_to_tsi_notification);
    652   alts_tsi_event_dispatch_to_handshaker(server_start_event, true /* is_ok */);
    653   alts_tsi_event_destroy(server_start_event);
    654   /* Server next. */
    655   wait(&caller_to_tsi_notification);
    656   alts_tsi_event_dispatch_to_handshaker(server_next_event, true /* is_ok */);
    657   alts_tsi_event_destroy(server_next_event);
    658 }
    659 
    660 static void on_failed_resp_cb(tsi_result status, void* user_data,
    661                               const unsigned char* bytes_to_send,
    662                               size_t bytes_to_send_size,
    663                               tsi_handshaker_result* result) {
    664   GPR_ASSERT(status == TSI_INVALID_ARGUMENT);
    665   GPR_ASSERT(user_data == nullptr);
    666   GPR_ASSERT(bytes_to_send == nullptr);
    667   GPR_ASSERT(bytes_to_send_size == 0);
    668   GPR_ASSERT(result == nullptr);
    669 }
    670 
    671 static void check_handle_response_failure() {
    672   /**
    673    * Create a handshaker at the client side, for which internal mock client is
    674    * always going to fail.
    675    */
    676   tsi_handshaker* handshaker = create_test_handshaker(
    677       false /* used_for_success_test */, true /* is_client */);
    678   alts_tsi_handshaker* alts_handshaker =
    679       reinterpret_cast<alts_tsi_handshaker*>(handshaker);
    680   /* Tests. */
    681   grpc_byte_buffer* recv_buffer = generate_handshaker_response(FAILED);
    682   alts_tsi_handshaker_handle_response(alts_handshaker, recv_buffer,
    683                                       GRPC_STATUS_OK, nullptr,
    684                                       on_failed_resp_cb, nullptr, true);
    685   grpc_byte_buffer_destroy(recv_buffer);
    686   /* Cleanup. */
    687   tsi_handshaker_destroy(handshaker);
    688 }
    689 
    690 static void on_shutdown_resp_cb(tsi_result status, void* user_data,
    691                                 const unsigned char* bytes_to_send,
    692                                 size_t bytes_to_send_size,
    693                                 tsi_handshaker_result* result) {
    694   GPR_ASSERT(status == TSI_HANDSHAKE_SHUTDOWN);
    695   GPR_ASSERT(user_data == nullptr);
    696   GPR_ASSERT(bytes_to_send == nullptr);
    697   GPR_ASSERT(bytes_to_send_size == 0);
    698   GPR_ASSERT(result == nullptr);
    699 }
    700 
    701 static void check_handle_response_after_shutdown() {
    702   tsi_handshaker* handshaker = create_test_handshaker(
    703       true /* used_for_success_test */, true /* is_client */);
    704   alts_tsi_handshaker* alts_handshaker =
    705       reinterpret_cast<alts_tsi_handshaker*>(handshaker);
    706   /* Tests. */
    707   tsi_handshaker_shutdown(handshaker);
    708   grpc_byte_buffer* recv_buffer = generate_handshaker_response(CLIENT_START);
    709   alts_tsi_handshaker_handle_response(alts_handshaker, recv_buffer,
    710                                       GRPC_STATUS_OK, nullptr,
    711                                       on_shutdown_resp_cb, nullptr, true);
    712   grpc_byte_buffer_destroy(recv_buffer);
    713   /* Cleanup. */
    714   tsi_handshaker_destroy(handshaker);
    715 }
    716 
    717 void check_handshaker_next_fails_after_shutdown() {
    718   /* Initialization. */
    719   notification_init(&caller_to_tsi_notification);
    720   notification_init(&tsi_to_caller_notification);
    721   client_start_event = nullptr;
    722   /* Tests. */
    723   grpc_core::Thread thd("alts_tsi_handshaker_test",
    724                         &check_handle_response_with_shutdown, nullptr);
    725   thd.Start();
    726   check_handshaker_next_with_shutdown();
    727   thd.Join();
    728   /* Cleanup. */
    729   notification_destroy(&caller_to_tsi_notification);
    730   notification_destroy(&tsi_to_caller_notification);
    731 }
    732 
    733 void check_handshaker_success() {
    734   /* Initialization. */
    735   notification_init(&caller_to_tsi_notification);
    736   notification_init(&tsi_to_caller_notification);
    737   client_start_event = nullptr;
    738   client_next_event = nullptr;
    739   server_start_event = nullptr;
    740   server_next_event = nullptr;
    741   /* Tests. */
    742   grpc_core::Thread thd("alts_tsi_handshaker_test",
    743                         &check_handle_response_success, nullptr);
    744   thd.Start();
    745   check_handshaker_next_success();
    746   thd.Join();
    747   /* Cleanup. */
    748   notification_destroy(&caller_to_tsi_notification);
    749   notification_destroy(&tsi_to_caller_notification);
    750 }
    751 
    752 int main(int argc, char** argv) {
    753   /* Initialization. */
    754   grpc_init();
    755   /* Tests. */
    756   check_handshaker_success();
    757   check_handshaker_next_invalid_input();
    758   check_handshaker_shutdown_invalid_input();
    759   check_handshaker_next_fails_after_shutdown();
    760   check_handshaker_next_failure();
    761   check_handle_response_invalid_input();
    762   check_handle_response_invalid_resp();
    763   check_handle_response_failure();
    764   check_handle_response_after_shutdown();
    765   /* Cleanup. */
    766   grpc_shutdown();
    767   return 0;
    768 }
    769