Home | History | Annotate | Download | only in handshaker
      1 /*
      2  *
      3  * Copyright 2018 gRPC authors.
      4  *
      5  * Licensed under the Apache License, Version 2.0 (the "License");
      6  * you may not use this file except in compliance with the License.
      7  * You may obtain a copy of the License at
      8  *
      9  *     http://www.apache.org/licenses/LICENSE-2.0
     10  *
     11  * Unless required by applicable law or agreed to in writing, software
     12  * distributed under the License is distributed on an "AS IS" BASIS,
     13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     14  * See the License for the specific language governing permissions and
     15  * limitations under the License.
     16  *
     17  */
     18 
     19 #include <grpc/support/port_platform.h>
     20 
     21 #include "src/core/tsi/alts/handshaker/alts_handshaker_service_api.h"
     22 
     23 #include <stdio.h>
     24 #include <stdlib.h>
     25 
     26 #include "src/core/tsi/alts/handshaker/transport_security_common_api.h"
     27 
     28 /* HandshakerReq */
     29 grpc_gcp_handshaker_req* grpc_gcp_handshaker_req_create(
     30     grpc_gcp_handshaker_req_type type) {
     31   grpc_gcp_handshaker_req* req =
     32       static_cast<grpc_gcp_handshaker_req*>(gpr_zalloc(sizeof(*req)));
     33   switch (type) {
     34     case CLIENT_START_REQ:
     35       req->has_client_start = true;
     36       break;
     37     case SERVER_START_REQ:
     38       req->has_server_start = true;
     39       break;
     40     case NEXT_REQ:
     41       req->has_next = true;
     42       break;
     43   }
     44   return req;
     45 }
     46 
     47 void grpc_gcp_handshaker_req_destroy(grpc_gcp_handshaker_req* req) {
     48   if (req == nullptr) {
     49     return;
     50   }
     51   if (req->has_client_start) {
     52     /* Destroy client_start request. */
     53     destroy_repeated_field_list_identity(
     54         static_cast<repeated_field*>(req->client_start.target_identities.arg));
     55     destroy_repeated_field_list_string(static_cast<repeated_field*>(
     56         req->client_start.application_protocols.arg));
     57     destroy_repeated_field_list_string(
     58         static_cast<repeated_field*>(req->client_start.record_protocols.arg));
     59     if (req->client_start.has_local_identity) {
     60       destroy_slice(static_cast<grpc_slice*>(
     61           req->client_start.local_identity.hostname.arg));
     62       destroy_slice(static_cast<grpc_slice*>(
     63           req->client_start.local_identity.service_account.arg));
     64     }
     65     if (req->client_start.has_local_endpoint) {
     66       destroy_slice(static_cast<grpc_slice*>(
     67           req->client_start.local_endpoint.ip_address.arg));
     68     }
     69     if (req->client_start.has_remote_endpoint) {
     70       destroy_slice(static_cast<grpc_slice*>(
     71           req->client_start.remote_endpoint.ip_address.arg));
     72     }
     73     destroy_slice(static_cast<grpc_slice*>(req->client_start.target_name.arg));
     74   } else if (req->has_server_start) {
     75     /* Destroy server_start request. */
     76     size_t i = 0;
     77     for (i = 0; i < req->server_start.handshake_parameters_count; i++) {
     78       destroy_repeated_field_list_identity(
     79           static_cast<repeated_field*>(req->server_start.handshake_parameters[i]
     80                                            .value.local_identities.arg));
     81       destroy_repeated_field_list_string(
     82           static_cast<repeated_field*>(req->server_start.handshake_parameters[i]
     83                                            .value.record_protocols.arg));
     84     }
     85     destroy_repeated_field_list_string(static_cast<repeated_field*>(
     86         req->server_start.application_protocols.arg));
     87     if (req->server_start.has_local_endpoint) {
     88       destroy_slice(static_cast<grpc_slice*>(
     89           req->server_start.local_endpoint.ip_address.arg));
     90     }
     91     if (req->server_start.has_remote_endpoint) {
     92       destroy_slice(static_cast<grpc_slice*>(
     93           req->server_start.remote_endpoint.ip_address.arg));
     94     }
     95     destroy_slice(static_cast<grpc_slice*>(req->server_start.in_bytes.arg));
     96   } else {
     97     /* Destroy next request. */
     98     destroy_slice(static_cast<grpc_slice*>(req->next.in_bytes.arg));
     99   }
    100   gpr_free(req);
    101 }
    102 
    103 bool grpc_gcp_handshaker_req_set_handshake_protocol(
    104     grpc_gcp_handshaker_req* req,
    105     grpc_gcp_handshake_protocol handshake_protocol) {
    106   if (req == nullptr || !req->has_client_start) {
    107     gpr_log(GPR_ERROR,
    108             "Invalid arguments to "
    109             "grpc_gcp_handshaker_req_set_handshake_protocol().");
    110     return false;
    111   }
    112   req->client_start.has_handshake_security_protocol = true;
    113   req->client_start.handshake_security_protocol = handshake_protocol;
    114   return true;
    115 }
    116 
    117 bool grpc_gcp_handshaker_req_set_target_name(grpc_gcp_handshaker_req* req,
    118                                              const char* target_name) {
    119   if (req == nullptr || target_name == nullptr || !req->has_client_start) {
    120     gpr_log(GPR_ERROR,
    121             "Invalid arguments to "
    122             "grpc_gcp_handshaker_req_set_target_name().");
    123     return false;
    124   }
    125   grpc_slice* slice = create_slice(target_name, strlen(target_name));
    126   req->client_start.target_name.arg = slice;
    127   req->client_start.target_name.funcs.encode = encode_string_or_bytes_cb;
    128   return true;
    129 }
    130 
    131 bool grpc_gcp_handshaker_req_add_application_protocol(
    132     grpc_gcp_handshaker_req* req, const char* application_protocol) {
    133   if (req == nullptr || application_protocol == nullptr || req->has_next) {
    134     gpr_log(GPR_ERROR,
    135             "Invalid arguments to "
    136             "grpc_gcp_handshaker_req_add_application_protocol().");
    137     return false;
    138   }
    139   grpc_slice* slice =
    140       create_slice(application_protocol, strlen(application_protocol));
    141   if (req->has_client_start) {
    142     add_repeated_field(reinterpret_cast<repeated_field**>(
    143                            &req->client_start.application_protocols.arg),
    144                        slice);
    145     req->client_start.application_protocols.funcs.encode =
    146         encode_repeated_string_cb;
    147   } else {
    148     add_repeated_field(reinterpret_cast<repeated_field**>(
    149                            &req->server_start.application_protocols.arg),
    150                        slice);
    151     req->server_start.application_protocols.funcs.encode =
    152         encode_repeated_string_cb;
    153   }
    154   return true;
    155 }
    156 
    157 bool grpc_gcp_handshaker_req_add_record_protocol(grpc_gcp_handshaker_req* req,
    158                                                  const char* record_protocol) {
    159   if (req == nullptr || record_protocol == nullptr || !req->has_client_start) {
    160     gpr_log(GPR_ERROR,
    161             "Invalid arguments to "
    162             "grpc_gcp_handshaker_req_add_record_protocol().");
    163     return false;
    164   }
    165   grpc_slice* slice = create_slice(record_protocol, strlen(record_protocol));
    166   add_repeated_field(reinterpret_cast<repeated_field**>(
    167                          &req->client_start.record_protocols.arg),
    168                      slice);
    169   req->client_start.record_protocols.funcs.encode = encode_repeated_string_cb;
    170   return true;
    171 }
    172 
    173 static void set_identity_hostname(grpc_gcp_identity* identity,
    174                                   const char* hostname) {
    175   grpc_slice* slice = create_slice(hostname, strlen(hostname));
    176   identity->hostname.arg = slice;
    177   identity->hostname.funcs.encode = encode_string_or_bytes_cb;
    178 }
    179 
    180 static void set_identity_service_account(grpc_gcp_identity* identity,
    181                                          const char* service_account) {
    182   grpc_slice* slice = create_slice(service_account, strlen(service_account));
    183   identity->service_account.arg = slice;
    184   identity->service_account.funcs.encode = encode_string_or_bytes_cb;
    185 }
    186 
    187 bool grpc_gcp_handshaker_req_add_target_identity_hostname(
    188     grpc_gcp_handshaker_req* req, const char* hostname) {
    189   if (req == nullptr || hostname == nullptr || !req->has_client_start) {
    190     gpr_log(GPR_ERROR,
    191             "Invalid nullptr arguments to "
    192             "grpc_gcp_handshaker_req_add_target_identity_hostname().");
    193     return false;
    194   }
    195   grpc_gcp_identity* target_identity =
    196       static_cast<grpc_gcp_identity*>(gpr_zalloc(sizeof(*target_identity)));
    197   set_identity_hostname(target_identity, hostname);
    198   req->client_start.target_identities.funcs.encode =
    199       encode_repeated_identity_cb;
    200   add_repeated_field(reinterpret_cast<repeated_field**>(
    201                          &req->client_start.target_identities.arg),
    202                      target_identity);
    203   return true;
    204 }
    205 
    206 bool grpc_gcp_handshaker_req_add_target_identity_service_account(
    207     grpc_gcp_handshaker_req* req, const char* service_account) {
    208   if (req == nullptr || service_account == nullptr || !req->has_client_start) {
    209     gpr_log(GPR_ERROR,
    210             "Invalid nullptr arguments to "
    211             "grpc_gcp_handshaker_req_add_target_identity_service_account().");
    212     return false;
    213   }
    214   grpc_gcp_identity* target_identity =
    215       static_cast<grpc_gcp_identity*>(gpr_zalloc(sizeof(*target_identity)));
    216   set_identity_service_account(target_identity, service_account);
    217   req->client_start.target_identities.funcs.encode =
    218       encode_repeated_identity_cb;
    219   add_repeated_field(reinterpret_cast<repeated_field**>(
    220                          &req->client_start.target_identities.arg),
    221                      target_identity);
    222   return true;
    223 }
    224 
    225 bool grpc_gcp_handshaker_req_set_local_identity_hostname(
    226     grpc_gcp_handshaker_req* req, const char* hostname) {
    227   if (req == nullptr || hostname == nullptr || !req->has_client_start) {
    228     gpr_log(GPR_ERROR,
    229             "Invalid nullptr arguments to "
    230             "grpc_gcp_handshaker_req_set_local_identity_hostname().");
    231     return false;
    232   }
    233   req->client_start.has_local_identity = true;
    234   set_identity_hostname(&req->client_start.local_identity, hostname);
    235   return true;
    236 }
    237 
    238 bool grpc_gcp_handshaker_req_set_local_identity_service_account(
    239     grpc_gcp_handshaker_req* req, const char* service_account) {
    240   if (req == nullptr || service_account == nullptr || !req->has_client_start) {
    241     gpr_log(GPR_ERROR,
    242             "Invalid nullptr arguments to "
    243             "grpc_gcp_handshaker_req_set_local_identity_service_account().");
    244     return false;
    245   }
    246   req->client_start.has_local_identity = true;
    247   set_identity_service_account(&req->client_start.local_identity,
    248                                service_account);
    249   return true;
    250 }
    251 
    252 static void set_endpoint(grpc_gcp_endpoint* endpoint, const char* ip_address,
    253                          size_t port, grpc_gcp_network_protocol protocol) {
    254   grpc_slice* slice = create_slice(ip_address, strlen(ip_address));
    255   endpoint->ip_address.arg = slice;
    256   endpoint->ip_address.funcs.encode = encode_string_or_bytes_cb;
    257   endpoint->has_port = true;
    258   endpoint->port = static_cast<int32_t>(port);
    259   endpoint->has_protocol = true;
    260   endpoint->protocol = protocol;
    261 }
    262 
    263 bool grpc_gcp_handshaker_req_set_rpc_versions(grpc_gcp_handshaker_req* req,
    264                                               uint32_t max_major,
    265                                               uint32_t max_minor,
    266                                               uint32_t min_major,
    267                                               uint32_t min_minor) {
    268   if (req == nullptr || req->has_next) {
    269     gpr_log(GPR_ERROR,
    270             "Invalid arguments to "
    271             "grpc_gcp_handshaker_req_set_rpc_versions().");
    272     return false;
    273   }
    274   if (req->has_client_start) {
    275     req->client_start.has_rpc_versions = true;
    276     grpc_gcp_rpc_protocol_versions_set_max(&req->client_start.rpc_versions,
    277                                            max_major, max_minor);
    278     grpc_gcp_rpc_protocol_versions_set_min(&req->client_start.rpc_versions,
    279                                            min_major, min_minor);
    280   } else {
    281     req->server_start.has_rpc_versions = true;
    282     grpc_gcp_rpc_protocol_versions_set_max(&req->server_start.rpc_versions,
    283                                            max_major, max_minor);
    284     grpc_gcp_rpc_protocol_versions_set_min(&req->server_start.rpc_versions,
    285                                            min_major, min_minor);
    286   }
    287   return true;
    288 }
    289 
    290 bool grpc_gcp_handshaker_req_set_local_endpoint(
    291     grpc_gcp_handshaker_req* req, const char* ip_address, size_t port,
    292     grpc_gcp_network_protocol protocol) {
    293   if (req == nullptr || ip_address == nullptr || port > 65535 ||
    294       req->has_next) {
    295     gpr_log(GPR_ERROR,
    296             "Invalid arguments to "
    297             "grpc_gcp_handshaker_req_set_local_endpoint().");
    298     return false;
    299   }
    300   if (req->has_client_start) {
    301     req->client_start.has_local_endpoint = true;
    302     set_endpoint(&req->client_start.local_endpoint, ip_address, port, protocol);
    303   } else {
    304     req->server_start.has_local_endpoint = true;
    305     set_endpoint(&req->server_start.local_endpoint, ip_address, port, protocol);
    306   }
    307   return true;
    308 }
    309 
    310 bool grpc_gcp_handshaker_req_set_remote_endpoint(
    311     grpc_gcp_handshaker_req* req, const char* ip_address, size_t port,
    312     grpc_gcp_network_protocol protocol) {
    313   if (req == nullptr || ip_address == nullptr || port > 65535 ||
    314       req->has_next) {
    315     gpr_log(GPR_ERROR,
    316             "Invalid arguments to "
    317             "grpc_gcp_handshaker_req_set_remote_endpoint().");
    318     return false;
    319   }
    320   if (req->has_client_start) {
    321     req->client_start.has_remote_endpoint = true;
    322     set_endpoint(&req->client_start.remote_endpoint, ip_address, port,
    323                  protocol);
    324   } else {
    325     req->server_start.has_remote_endpoint = true;
    326     set_endpoint(&req->server_start.remote_endpoint, ip_address, port,
    327                  protocol);
    328   }
    329   return true;
    330 }
    331 
    332 bool grpc_gcp_handshaker_req_set_in_bytes(grpc_gcp_handshaker_req* req,
    333                                           const char* in_bytes, size_t size) {
    334   if (req == nullptr || in_bytes == nullptr || req->has_client_start) {
    335     gpr_log(GPR_ERROR,
    336             "Invalid arguments to "
    337             "grpc_gcp_handshaker_req_set_in_bytes().");
    338     return false;
    339   }
    340   grpc_slice* slice = create_slice(in_bytes, size);
    341   if (req->has_next) {
    342     req->next.in_bytes.arg = slice;
    343     req->next.in_bytes.funcs.encode = &encode_string_or_bytes_cb;
    344   } else {
    345     req->server_start.in_bytes.arg = slice;
    346     req->server_start.in_bytes.funcs.encode = &encode_string_or_bytes_cb;
    347   }
    348   return true;
    349 }
    350 
    351 static grpc_gcp_server_handshake_parameters* server_start_find_param(
    352     grpc_gcp_handshaker_req* req, int32_t key) {
    353   size_t i = 0;
    354   for (i = 0; i < req->server_start.handshake_parameters_count; i++) {
    355     if (req->server_start.handshake_parameters[i].key == key) {
    356       return &req->server_start.handshake_parameters[i].value;
    357     }
    358   }
    359   req->server_start
    360       .handshake_parameters[req->server_start.handshake_parameters_count]
    361       .has_key = true;
    362   req->server_start
    363       .handshake_parameters[req->server_start.handshake_parameters_count]
    364       .has_value = true;
    365   req->server_start
    366       .handshake_parameters[req->server_start.handshake_parameters_count++]
    367       .key = key;
    368   return &req->server_start
    369               .handshake_parameters
    370                   [req->server_start.handshake_parameters_count - 1]
    371               .value;
    372 }
    373 
    374 bool grpc_gcp_handshaker_req_param_add_record_protocol(
    375     grpc_gcp_handshaker_req* req, grpc_gcp_handshake_protocol key,
    376     const char* record_protocol) {
    377   if (req == nullptr || record_protocol == nullptr || !req->has_server_start) {
    378     gpr_log(GPR_ERROR,
    379             "Invalid arguments to "
    380             "grpc_gcp_handshaker_req_param_add_record_protocol().");
    381     return false;
    382   }
    383   grpc_gcp_server_handshake_parameters* param =
    384       server_start_find_param(req, key);
    385   grpc_slice* slice = create_slice(record_protocol, strlen(record_protocol));
    386   add_repeated_field(
    387       reinterpret_cast<repeated_field**>(&param->record_protocols.arg), slice);
    388   param->record_protocols.funcs.encode = &encode_repeated_string_cb;
    389   return true;
    390 }
    391 
    392 bool grpc_gcp_handshaker_req_param_add_local_identity_hostname(
    393     grpc_gcp_handshaker_req* req, grpc_gcp_handshake_protocol key,
    394     const char* hostname) {
    395   if (req == nullptr || hostname == nullptr || !req->has_server_start) {
    396     gpr_log(GPR_ERROR,
    397             "Invalid arguments to "
    398             "grpc_gcp_handshaker_req_param_add_local_identity_hostname().");
    399     return false;
    400   }
    401   grpc_gcp_server_handshake_parameters* param =
    402       server_start_find_param(req, key);
    403   grpc_gcp_identity* local_identity =
    404       static_cast<grpc_gcp_identity*>(gpr_zalloc(sizeof(*local_identity)));
    405   set_identity_hostname(local_identity, hostname);
    406   add_repeated_field(
    407       reinterpret_cast<repeated_field**>(&param->local_identities.arg),
    408       local_identity);
    409   param->local_identities.funcs.encode = &encode_repeated_identity_cb;
    410   return true;
    411 }
    412 
    413 bool grpc_gcp_handshaker_req_param_add_local_identity_service_account(
    414     grpc_gcp_handshaker_req* req, grpc_gcp_handshake_protocol key,
    415     const char* service_account) {
    416   if (req == nullptr || service_account == nullptr || !req->has_server_start) {
    417     gpr_log(
    418         GPR_ERROR,
    419         "Invalid arguments to "
    420         "grpc_gcp_handshaker_req_param_add_local_identity_service_account().");
    421     return false;
    422   }
    423   grpc_gcp_server_handshake_parameters* param =
    424       server_start_find_param(req, key);
    425   grpc_gcp_identity* local_identity =
    426       static_cast<grpc_gcp_identity*>(gpr_zalloc(sizeof(*local_identity)));
    427   set_identity_service_account(local_identity, service_account);
    428   add_repeated_field(
    429       reinterpret_cast<repeated_field**>(&param->local_identities.arg),
    430       local_identity);
    431   param->local_identities.funcs.encode = &encode_repeated_identity_cb;
    432   return true;
    433 }
    434 
    435 bool grpc_gcp_handshaker_req_encode(grpc_gcp_handshaker_req* req,
    436                                     grpc_slice* slice) {
    437   if (req == nullptr || slice == nullptr) {
    438     gpr_log(GPR_ERROR,
    439             "Invalid nullptr arguments to grpc_gcp_handshaker_req_encode().");
    440     return false;
    441   }
    442   pb_ostream_t size_stream;
    443   memset(&size_stream, 0, sizeof(pb_ostream_t));
    444   if (!pb_encode(&size_stream, grpc_gcp_HandshakerReq_fields, req)) {
    445     gpr_log(GPR_ERROR, "nanopb error: %s", PB_GET_ERROR(&size_stream));
    446     return false;
    447   }
    448   size_t encoded_length = size_stream.bytes_written;
    449   *slice = grpc_slice_malloc(encoded_length);
    450   pb_ostream_t output_stream =
    451       pb_ostream_from_buffer(GRPC_SLICE_START_PTR(*slice), encoded_length);
    452   if (!pb_encode(&output_stream, grpc_gcp_HandshakerReq_fields, req) != 0) {
    453     gpr_log(GPR_ERROR, "nanopb error: %s", PB_GET_ERROR(&output_stream));
    454     return false;
    455   }
    456   return true;
    457 }
    458 
    459 /* HandshakerResp. */
    460 grpc_gcp_handshaker_resp* grpc_gcp_handshaker_resp_create(void) {
    461   grpc_gcp_handshaker_resp* resp =
    462       static_cast<grpc_gcp_handshaker_resp*>(gpr_zalloc(sizeof(*resp)));
    463   return resp;
    464 }
    465 
    466 void grpc_gcp_handshaker_resp_destroy(grpc_gcp_handshaker_resp* resp) {
    467   if (resp != nullptr) {
    468     destroy_slice(static_cast<grpc_slice*>(resp->out_frames.arg));
    469     if (resp->has_status) {
    470       destroy_slice(static_cast<grpc_slice*>(resp->status.details.arg));
    471     }
    472     if (resp->has_result) {
    473       destroy_slice(
    474           static_cast<grpc_slice*>(resp->result.application_protocol.arg));
    475       destroy_slice(static_cast<grpc_slice*>(resp->result.record_protocol.arg));
    476       destroy_slice(static_cast<grpc_slice*>(resp->result.key_data.arg));
    477       if (resp->result.has_local_identity) {
    478         destroy_slice(
    479             static_cast<grpc_slice*>(resp->result.local_identity.hostname.arg));
    480         destroy_slice(static_cast<grpc_slice*>(
    481             resp->result.local_identity.service_account.arg));
    482       }
    483       if (resp->result.has_peer_identity) {
    484         destroy_slice(
    485             static_cast<grpc_slice*>(resp->result.peer_identity.hostname.arg));
    486         destroy_slice(static_cast<grpc_slice*>(
    487             resp->result.peer_identity.service_account.arg));
    488       }
    489     }
    490     gpr_free(resp);
    491   }
    492 }
    493 
    494 bool grpc_gcp_handshaker_resp_decode(grpc_slice encoded_handshaker_resp,
    495                                      grpc_gcp_handshaker_resp* resp) {
    496   if (resp == nullptr) {
    497     gpr_log(GPR_ERROR,
    498             "Invalid nullptr argument to grpc_gcp_handshaker_resp_decode().");
    499     return false;
    500   }
    501   pb_istream_t stream =
    502       pb_istream_from_buffer(GRPC_SLICE_START_PTR(encoded_handshaker_resp),
    503                              GRPC_SLICE_LENGTH(encoded_handshaker_resp));
    504   resp->out_frames.funcs.decode = decode_string_or_bytes_cb;
    505   resp->status.details.funcs.decode = decode_string_or_bytes_cb;
    506   resp->result.application_protocol.funcs.decode = decode_string_or_bytes_cb;
    507   resp->result.record_protocol.funcs.decode = decode_string_or_bytes_cb;
    508   resp->result.key_data.funcs.decode = decode_string_or_bytes_cb;
    509   resp->result.peer_identity.hostname.funcs.decode = decode_string_or_bytes_cb;
    510   resp->result.peer_identity.service_account.funcs.decode =
    511       decode_string_or_bytes_cb;
    512   resp->result.local_identity.hostname.funcs.decode = decode_string_or_bytes_cb;
    513   resp->result.local_identity.service_account.funcs.decode =
    514       decode_string_or_bytes_cb;
    515   if (!pb_decode(&stream, grpc_gcp_HandshakerResp_fields, resp)) {
    516     gpr_log(GPR_ERROR, "nanopb error: %s", PB_GET_ERROR(&stream));
    517     return false;
    518   }
    519   return true;
    520 }
    521