Home | History | Annotate | Download | only in transport
      1 /*
      2  *
      3  * Copyright 2015 gRPC authors.
      4  *
      5  * Licensed under the Apache License, Version 2.0 (the "License");
      6  * you may not use this file except in compliance with the License.
      7  * You may obtain a copy of the License at
      8  *
      9  *     http://www.apache.org/licenses/LICENSE-2.0
     10  *
     11  * Unless required by applicable law or agreed to in writing, software
     12  * distributed under the License is distributed on an "AS IS" BASIS,
     13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     14  * See the License for the specific language governing permissions and
     15  * limitations under the License.
     16  *
     17  */
     18 
     19 #include <grpc/support/port_platform.h>
     20 
     21 #include <string.h>
     22 
     23 #include <grpc/support/alloc.h>
     24 #include <grpc/support/log.h>
     25 
     26 #include "src/core/lib/security/context/security_context.h"
     27 #include "src/core/lib/security/credentials/credentials.h"
     28 #include "src/core/lib/security/transport/auth_filters.h"
     29 #include "src/core/lib/slice/slice_internal.h"
     30 
     31 namespace {
     32 enum async_state {
     33   STATE_INIT = 0,
     34   STATE_DONE,
     35   STATE_CANCELLED,
     36 };
     37 
     38 struct call_data {
     39   grpc_call_combiner* call_combiner;
     40   grpc_call_stack* owning_call;
     41   grpc_transport_stream_op_batch* recv_initial_metadata_batch;
     42   grpc_closure* original_recv_initial_metadata_ready;
     43   grpc_closure recv_initial_metadata_ready;
     44   grpc_error* error;
     45   grpc_closure recv_trailing_metadata_ready;
     46   grpc_closure* original_recv_trailing_metadata_ready;
     47   grpc_metadata_array md;
     48   const grpc_metadata* consumed_md;
     49   size_t num_consumed_md;
     50   grpc_closure cancel_closure;
     51   gpr_atm state;  // async_state
     52 };
     53 
     54 struct channel_data {
     55   grpc_auth_context* auth_context;
     56   grpc_server_credentials* creds;
     57 };
     58 }  // namespace
     59 
     60 static grpc_metadata_array metadata_batch_to_md_array(
     61     const grpc_metadata_batch* batch) {
     62   grpc_linked_mdelem* l;
     63   grpc_metadata_array result;
     64   grpc_metadata_array_init(&result);
     65   for (l = batch->list.head; l != nullptr; l = l->next) {
     66     grpc_metadata* usr_md = nullptr;
     67     grpc_mdelem md = l->md;
     68     grpc_slice key = GRPC_MDKEY(md);
     69     grpc_slice value = GRPC_MDVALUE(md);
     70     if (result.count == result.capacity) {
     71       result.capacity = GPR_MAX(result.capacity + 8, result.capacity * 2);
     72       result.metadata = static_cast<grpc_metadata*>(gpr_realloc(
     73           result.metadata, result.capacity * sizeof(grpc_metadata)));
     74     }
     75     usr_md = &result.metadata[result.count++];
     76     usr_md->key = grpc_slice_ref_internal(key);
     77     usr_md->value = grpc_slice_ref_internal(value);
     78   }
     79   return result;
     80 }
     81 
     82 static grpc_filtered_mdelem remove_consumed_md(void* user_data,
     83                                                grpc_mdelem md) {
     84   grpc_call_element* elem = static_cast<grpc_call_element*>(user_data);
     85   call_data* calld = static_cast<call_data*>(elem->call_data);
     86   size_t i;
     87   for (i = 0; i < calld->num_consumed_md; i++) {
     88     const grpc_metadata* consumed_md = &calld->consumed_md[i];
     89     if (grpc_slice_eq(GRPC_MDKEY(md), consumed_md->key) &&
     90         grpc_slice_eq(GRPC_MDVALUE(md), consumed_md->value))
     91       return GRPC_FILTERED_REMOVE();
     92   }
     93   return GRPC_FILTERED_MDELEM(md);
     94 }
     95 
     96 static void on_md_processing_done_inner(grpc_call_element* elem,
     97                                         const grpc_metadata* consumed_md,
     98                                         size_t num_consumed_md,
     99                                         const grpc_metadata* response_md,
    100                                         size_t num_response_md,
    101                                         grpc_error* error) {
    102   call_data* calld = static_cast<call_data*>(elem->call_data);
    103   grpc_transport_stream_op_batch* batch = calld->recv_initial_metadata_batch;
    104   /* TODO(jboeuf): Implement support for response_md. */
    105   if (response_md != nullptr && num_response_md > 0) {
    106     gpr_log(GPR_INFO,
    107             "response_md in auth metadata processing not supported for now. "
    108             "Ignoring...");
    109   }
    110   if (error == GRPC_ERROR_NONE) {
    111     calld->consumed_md = consumed_md;
    112     calld->num_consumed_md = num_consumed_md;
    113     error = grpc_metadata_batch_filter(
    114         batch->payload->recv_initial_metadata.recv_initial_metadata,
    115         remove_consumed_md, elem, "Response metadata filtering error");
    116   }
    117   calld->error = GRPC_ERROR_REF(error);
    118   GRPC_CLOSURE_SCHED(calld->original_recv_initial_metadata_ready, error);
    119 }
    120 
    121 // Called from application code.
    122 static void on_md_processing_done(
    123     void* user_data, const grpc_metadata* consumed_md, size_t num_consumed_md,
    124     const grpc_metadata* response_md, size_t num_response_md,
    125     grpc_status_code status, const char* error_details) {
    126   grpc_call_element* elem = static_cast<grpc_call_element*>(user_data);
    127   call_data* calld = static_cast<call_data*>(elem->call_data);
    128   grpc_core::ExecCtx exec_ctx;
    129   // If the call was not cancelled while we were in flight, process the result.
    130   if (gpr_atm_full_cas(&calld->state, static_cast<gpr_atm>(STATE_INIT),
    131                        static_cast<gpr_atm>(STATE_DONE))) {
    132     grpc_error* error = GRPC_ERROR_NONE;
    133     if (status != GRPC_STATUS_OK) {
    134       if (error_details == nullptr) {
    135         error_details = "Authentication metadata processing failed.";
    136       }
    137       error = grpc_error_set_int(
    138           GRPC_ERROR_CREATE_FROM_COPIED_STRING(error_details),
    139           GRPC_ERROR_INT_GRPC_STATUS, status);
    140     }
    141     on_md_processing_done_inner(elem, consumed_md, num_consumed_md, response_md,
    142                                 num_response_md, error);
    143   }
    144   // Clean up.
    145   for (size_t i = 0; i < calld->md.count; i++) {
    146     grpc_slice_unref_internal(calld->md.metadata[i].key);
    147     grpc_slice_unref_internal(calld->md.metadata[i].value);
    148   }
    149   grpc_metadata_array_destroy(&calld->md);
    150   GRPC_CALL_STACK_UNREF(calld->owning_call, "server_auth_metadata");
    151 }
    152 
    153 static void cancel_call(void* arg, grpc_error* error) {
    154   grpc_call_element* elem = static_cast<grpc_call_element*>(arg);
    155   call_data* calld = static_cast<call_data*>(elem->call_data);
    156   // If the result was not already processed, invoke the callback now.
    157   if (error != GRPC_ERROR_NONE &&
    158       gpr_atm_full_cas(&calld->state, static_cast<gpr_atm>(STATE_INIT),
    159                        static_cast<gpr_atm>(STATE_CANCELLED))) {
    160     on_md_processing_done_inner(elem, nullptr, 0, nullptr, 0,
    161                                 GRPC_ERROR_REF(error));
    162   }
    163 }
    164 
    165 static void recv_initial_metadata_ready(void* arg, grpc_error* error) {
    166   grpc_call_element* elem = static_cast<grpc_call_element*>(arg);
    167   channel_data* chand = static_cast<channel_data*>(elem->channel_data);
    168   call_data* calld = static_cast<call_data*>(elem->call_data);
    169   grpc_transport_stream_op_batch* batch = calld->recv_initial_metadata_batch;
    170   if (error == GRPC_ERROR_NONE) {
    171     if (chand->creds != nullptr && chand->creds->processor.process != nullptr) {
    172       // We're calling out to the application, so we need to make sure
    173       // to drop the call combiner early if we get cancelled.
    174       GRPC_CLOSURE_INIT(&calld->cancel_closure, cancel_call, elem,
    175                         grpc_schedule_on_exec_ctx);
    176       grpc_call_combiner_set_notify_on_cancel(calld->call_combiner,
    177                                               &calld->cancel_closure);
    178       GRPC_CALL_STACK_REF(calld->owning_call, "server_auth_metadata");
    179       calld->md = metadata_batch_to_md_array(
    180           batch->payload->recv_initial_metadata.recv_initial_metadata);
    181       chand->creds->processor.process(
    182           chand->creds->processor.state, chand->auth_context,
    183           calld->md.metadata, calld->md.count, on_md_processing_done, elem);
    184       return;
    185     }
    186   }
    187   GRPC_CLOSURE_RUN(calld->original_recv_initial_metadata_ready,
    188                    GRPC_ERROR_REF(error));
    189 }
    190 
    191 static void recv_trailing_metadata_ready(void* user_data, grpc_error* err) {
    192   grpc_call_element* elem = static_cast<grpc_call_element*>(user_data);
    193   call_data* calld = static_cast<call_data*>(elem->call_data);
    194   err = grpc_error_add_child(GRPC_ERROR_REF(err), GRPC_ERROR_REF(calld->error));
    195   GRPC_CLOSURE_RUN(calld->original_recv_trailing_metadata_ready, err);
    196 }
    197 
    198 static void auth_start_transport_stream_op_batch(
    199     grpc_call_element* elem, grpc_transport_stream_op_batch* batch) {
    200   call_data* calld = static_cast<call_data*>(elem->call_data);
    201   if (batch->recv_initial_metadata) {
    202     // Inject our callback.
    203     calld->recv_initial_metadata_batch = batch;
    204     calld->original_recv_initial_metadata_ready =
    205         batch->payload->recv_initial_metadata.recv_initial_metadata_ready;
    206     batch->payload->recv_initial_metadata.recv_initial_metadata_ready =
    207         &calld->recv_initial_metadata_ready;
    208   }
    209   if (batch->recv_trailing_metadata) {
    210     calld->original_recv_trailing_metadata_ready =
    211         batch->payload->recv_trailing_metadata.recv_trailing_metadata_ready;
    212     batch->payload->recv_trailing_metadata.recv_trailing_metadata_ready =
    213         &calld->recv_trailing_metadata_ready;
    214   }
    215   grpc_call_next_op(elem, batch);
    216 }
    217 
    218 /* Constructor for call_data */
    219 static grpc_error* init_call_elem(grpc_call_element* elem,
    220                                   const grpc_call_element_args* args) {
    221   call_data* calld = static_cast<call_data*>(elem->call_data);
    222   channel_data* chand = static_cast<channel_data*>(elem->channel_data);
    223   calld->call_combiner = args->call_combiner;
    224   calld->owning_call = args->call_stack;
    225   GRPC_CLOSURE_INIT(&calld->recv_initial_metadata_ready,
    226                     recv_initial_metadata_ready, elem,
    227                     grpc_schedule_on_exec_ctx);
    228   GRPC_CLOSURE_INIT(&calld->recv_trailing_metadata_ready,
    229                     recv_trailing_metadata_ready, elem,
    230                     grpc_schedule_on_exec_ctx);
    231   // Create server security context.  Set its auth context from channel
    232   // data and save it in the call context.
    233   grpc_server_security_context* server_ctx =
    234       grpc_server_security_context_create(args->arena);
    235   server_ctx->auth_context =
    236       GRPC_AUTH_CONTEXT_REF(chand->auth_context, "server_auth_filter");
    237   if (args->context[GRPC_CONTEXT_SECURITY].value != nullptr) {
    238     args->context[GRPC_CONTEXT_SECURITY].destroy(
    239         args->context[GRPC_CONTEXT_SECURITY].value);
    240   }
    241   args->context[GRPC_CONTEXT_SECURITY].value = server_ctx;
    242   args->context[GRPC_CONTEXT_SECURITY].destroy =
    243       grpc_server_security_context_destroy;
    244   return GRPC_ERROR_NONE;
    245 }
    246 
    247 /* Destructor for call_data */
    248 static void destroy_call_elem(grpc_call_element* elem,
    249                               const grpc_call_final_info* final_info,
    250                               grpc_closure* ignored) {
    251   call_data* calld = static_cast<call_data*>(elem->call_data);
    252   GRPC_ERROR_UNREF(calld->error);
    253 }
    254 
    255 /* Constructor for channel_data */
    256 static grpc_error* init_channel_elem(grpc_channel_element* elem,
    257                                      grpc_channel_element_args* args) {
    258   GPR_ASSERT(!args->is_last);
    259   channel_data* chand = static_cast<channel_data*>(elem->channel_data);
    260   grpc_auth_context* auth_context =
    261       grpc_find_auth_context_in_args(args->channel_args);
    262   GPR_ASSERT(auth_context != nullptr);
    263   chand->auth_context =
    264       GRPC_AUTH_CONTEXT_REF(auth_context, "server_auth_filter");
    265   grpc_server_credentials* creds =
    266       grpc_find_server_credentials_in_args(args->channel_args);
    267   chand->creds = grpc_server_credentials_ref(creds);
    268   return GRPC_ERROR_NONE;
    269 }
    270 
    271 /* Destructor for channel data */
    272 static void destroy_channel_elem(grpc_channel_element* elem) {
    273   channel_data* chand = static_cast<channel_data*>(elem->channel_data);
    274   GRPC_AUTH_CONTEXT_UNREF(chand->auth_context, "server_auth_filter");
    275   grpc_server_credentials_unref(chand->creds);
    276 }
    277 
/* Channel filter vtable for the server-side auth filter. Entries are
 * positional and must follow the field order of grpc_channel_filter. */
const grpc_channel_filter grpc_server_auth_filter = {
    auth_start_transport_stream_op_batch,  // per-call stream op hook
    grpc_channel_next_op,                  // channel ops passed through
    sizeof(call_data),
    init_call_elem,
    grpc_call_stack_ignore_set_pollset_or_pollset_set,
    destroy_call_elem,
    sizeof(channel_data),
    init_channel_elem,
    destroy_channel_elem,
    grpc_channel_next_get_info,
    "server-auth"};
    290