Home | History | Annotate | Download | only in grpc
      1 /*
      2  * Copyright 2014 The gRPC Authors
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *     http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 package io.grpc;
     18 
     19 import static com.google.common.collect.Iterables.getOnlyElement;
     20 import static org.junit.Assert.assertEquals;
     21 import static org.junit.Assert.assertSame;
     22 import static org.junit.Assert.assertTrue;
     23 import static org.mockito.AdditionalAnswers.delegatesTo;
     24 import static org.mockito.Matchers.same;
     25 import static org.mockito.Mockito.mock;
     26 import static org.mockito.Mockito.times;
     27 import static org.mockito.Mockito.verify;
     28 import static org.mockito.Mockito.verifyNoMoreInteractions;
     29 import static org.mockito.Mockito.verifyZeroInteractions;
     30 
     31 import io.grpc.MethodDescriptor.Marshaller;
     32 import io.grpc.MethodDescriptor.MethodType;
     33 import io.grpc.ServerCall.Listener;
     34 import io.grpc.internal.NoopServerCall;
     35 import java.io.ByteArrayInputStream;
     36 import java.io.InputStream;
     37 import java.util.ArrayList;
     38 import java.util.Arrays;
     39 import java.util.List;
     40 import org.junit.After;
     41 import org.junit.Before;
     42 import org.junit.Test;
     43 import org.junit.runner.RunWith;
     44 import org.junit.runners.JUnit4;
     45 import org.mockito.Mock;
     46 import org.mockito.Mockito;
     47 import org.mockito.MockitoAnnotations;
     48 
     49 /** Unit tests for {@link ServerInterceptors}. */
     50 @RunWith(JUnit4.class)
     51 public class ServerInterceptorsTest {
     52   @Mock
     53   private Marshaller<String> requestMarshaller;
     54 
     55   @Mock
     56   private Marshaller<Integer> responseMarshaller;
     57 
     58   @Mock
     59   private ServerCallHandler<String, Integer> handler;
     60 
     61   @Mock
     62   private ServerCall.Listener<String> listener;
     63 
     64   private MethodDescriptor<String, Integer> flowMethod;
     65 
     66   private ServerCall<String, Integer> call = new NoopServerCall<String, Integer>();
     67 
     68   private ServerServiceDefinition serviceDefinition;
     69 
     70   private final Metadata headers = new Metadata();
     71 
     72   /** Set up for test. */
     73   @Before
     74   public void setUp() {
     75     MockitoAnnotations.initMocks(this);
     76     flowMethod = MethodDescriptor.<String, Integer>newBuilder()
     77         .setType(MethodType.UNKNOWN)
     78         .setFullMethodName("basic/flow")
     79         .setRequestMarshaller(requestMarshaller)
     80         .setResponseMarshaller(responseMarshaller)
     81         .build();
     82 
     83     Mockito.when(handler.startCall(
     84         Mockito.<ServerCall<String, Integer>>any(), Mockito.<Metadata>any()))
     85             .thenReturn(listener);
     86 
     87     serviceDefinition = ServerServiceDefinition.builder(new ServiceDescriptor("basic", flowMethod))
     88         .addMethod(flowMethod, handler).build();
     89   }
     90 
     91   /** Final checks for all tests. */
     92   @After
     93   public void makeSureExpectedMocksUnused() {
     94     verifyZeroInteractions(requestMarshaller);
     95     verifyZeroInteractions(responseMarshaller);
     96     verifyZeroInteractions(listener);
     97   }
     98 
     99   @Test(expected = NullPointerException.class)
    100   public void npeForNullServiceDefinition() {
    101     ServerServiceDefinition serviceDef = null;
    102     ServerInterceptors.intercept(serviceDef, Arrays.<ServerInterceptor>asList());
    103   }
    104 
    105   @Test(expected = NullPointerException.class)
    106   public void npeForNullInterceptorList() {
    107     ServerInterceptors.intercept(serviceDefinition, (List<ServerInterceptor>) null);
    108   }
    109 
    110   @Test(expected = NullPointerException.class)
    111   public void npeForNullInterceptor() {
    112     ServerInterceptors.intercept(serviceDefinition, Arrays.asList((ServerInterceptor) null));
    113   }
    114 
    115   @Test
    116   public void noop() {
    117     assertSame(serviceDefinition,
    118         ServerInterceptors.intercept(serviceDefinition, Arrays.<ServerInterceptor>asList()));
    119   }
    120 
    121   @Test
    122   public void multipleInvocationsOfHandler() {
    123     ServerInterceptor interceptor =
    124         mock(ServerInterceptor.class, delegatesTo(new NoopInterceptor()));
    125     ServerServiceDefinition intercepted
    126         = ServerInterceptors.intercept(serviceDefinition, Arrays.asList(interceptor));
    127     assertSame(listener,
    128         getSoleMethod(intercepted).getServerCallHandler().startCall(call, headers));
    129     verify(interceptor).interceptCall(same(call), same(headers), anyCallHandler());
    130     verify(handler).startCall(call, headers);
    131     verifyNoMoreInteractions(interceptor, handler);
    132 
    133     assertSame(listener,
    134         getSoleMethod(intercepted).getServerCallHandler().startCall(call, headers));
    135     verify(interceptor, times(2))
    136         .interceptCall(same(call), same(headers), anyCallHandler());
    137     verify(handler, times(2)).startCall(call, headers);
    138     verifyNoMoreInteractions(interceptor, handler);
    139   }
    140 
    141   @Test
    142   public void correctHandlerCalled() {
    143     @SuppressWarnings("unchecked")
    144     ServerCallHandler<String, Integer> handler2 = mock(ServerCallHandler.class);
    145     MethodDescriptor<String, Integer> flowMethod2 =
    146         flowMethod.toBuilder().setFullMethodName("basic/flow2").build();
    147     serviceDefinition = ServerServiceDefinition.builder(
    148         new ServiceDescriptor("basic", flowMethod, flowMethod2))
    149         .addMethod(flowMethod, handler)
    150         .addMethod(flowMethod2, handler2).build();
    151     ServerServiceDefinition intercepted = ServerInterceptors.intercept(
    152         serviceDefinition, Arrays.<ServerInterceptor>asList(new NoopInterceptor()));
    153     getMethod(intercepted, "basic/flow").getServerCallHandler().startCall(call, headers);
    154     verify(handler).startCall(call, headers);
    155     verifyNoMoreInteractions(handler);
    156     verifyNoMoreInteractions(handler2);
    157 
    158     getMethod(intercepted, "basic/flow2").getServerCallHandler().startCall(call, headers);
    159     verify(handler2).startCall(call, headers);
    160     verifyNoMoreInteractions(handler);
    161     verifyNoMoreInteractions(handler2);
    162   }
    163 
    164   @Test
    165   public void callNextTwice() {
    166     ServerInterceptor interceptor = new ServerInterceptor() {
    167       @Override
    168       public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
    169           ServerCall<ReqT, RespT> call,
    170           Metadata headers,
    171           ServerCallHandler<ReqT, RespT> next) {
    172         // Calling next twice is permitted, although should only rarely be useful.
    173         assertSame(listener, next.startCall(call, headers));
    174         return next.startCall(call, headers);
    175       }
    176     };
    177     ServerServiceDefinition intercepted = ServerInterceptors.intercept(serviceDefinition,
    178         interceptor);
    179     assertSame(listener,
    180         getSoleMethod(intercepted).getServerCallHandler().startCall(call, headers));
    181     verify(handler, times(2)).startCall(same(call), same(headers));
    182     verifyNoMoreInteractions(handler);
    183   }
    184 
    185   @Test
    186   public void ordered() {
    187     final List<String> order = new ArrayList<>();
    188     handler = new ServerCallHandler<String, Integer>() {
    189           @Override
    190           public ServerCall.Listener<String> startCall(
    191               ServerCall<String, Integer> call,
    192               Metadata headers) {
    193             order.add("handler");
    194             return listener;
    195           }
    196         };
    197     ServerInterceptor interceptor1 = new ServerInterceptor() {
    198           @Override
    199           public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
    200               ServerCall<ReqT, RespT> call,
    201               Metadata headers,
    202               ServerCallHandler<ReqT, RespT> next) {
    203             order.add("i1");
    204             return next.startCall(call, headers);
    205           }
    206         };
    207     ServerInterceptor interceptor2 = new ServerInterceptor() {
    208           @Override
    209           public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
    210               ServerCall<ReqT, RespT> call,
    211               Metadata headers,
    212               ServerCallHandler<ReqT, RespT> next) {
    213             order.add("i2");
    214             return next.startCall(call, headers);
    215           }
    216         };
    217     ServerServiceDefinition serviceDefinition = ServerServiceDefinition.builder(
    218         new ServiceDescriptor("basic", flowMethod))
    219         .addMethod(flowMethod, handler).build();
    220     ServerServiceDefinition intercepted = ServerInterceptors.intercept(
    221         serviceDefinition, Arrays.asList(interceptor1, interceptor2));
    222     assertSame(listener,
    223         getSoleMethod(intercepted).getServerCallHandler().startCall(call, headers));
    224     assertEquals(Arrays.asList("i2", "i1", "handler"), order);
    225   }
    226 
    227   @Test
    228   public void orderedForward() {
    229     final List<String> order = new ArrayList<>();
    230     handler = new ServerCallHandler<String, Integer>() {
    231       @Override
    232       public ServerCall.Listener<String> startCall(
    233           ServerCall<String, Integer> call,
    234           Metadata headers) {
    235         order.add("handler");
    236         return listener;
    237       }
    238     };
    239     ServerInterceptor interceptor1 = new ServerInterceptor() {
    240       @Override
    241       public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
    242           ServerCall<ReqT, RespT> call,
    243           Metadata headers,
    244           ServerCallHandler<ReqT, RespT> next) {
    245         order.add("i1");
    246         return next.startCall(call, headers);
    247       }
    248     };
    249     ServerInterceptor interceptor2 = new ServerInterceptor() {
    250       @Override
    251       public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
    252           ServerCall<ReqT, RespT> call,
    253           Metadata headers,
    254           ServerCallHandler<ReqT, RespT> next) {
    255         order.add("i2");
    256         return next.startCall(call, headers);
    257       }
    258     };
    259     ServerServiceDefinition serviceDefinition = ServerServiceDefinition.builder(
    260         new ServiceDescriptor("basic", flowMethod))
    261         .addMethod(flowMethod, handler).build();
    262     ServerServiceDefinition intercepted = ServerInterceptors.interceptForward(
    263         serviceDefinition, interceptor1, interceptor2);
    264     assertSame(listener,
    265         getSoleMethod(intercepted).getServerCallHandler().startCall(call, headers));
    266     assertEquals(Arrays.asList("i1", "i2", "handler"), order);
    267   }
    268 
    269   @Test
    270   public void argumentsPassed() {
    271     @SuppressWarnings("unchecked")
    272     final ServerCall<String, Integer> call2 = new NoopServerCall<String, Integer>();
    273     @SuppressWarnings("unchecked")
    274     final ServerCall.Listener<String> listener2 = mock(ServerCall.Listener.class);
    275 
    276     ServerInterceptor interceptor = new ServerInterceptor() {
    277         @SuppressWarnings("unchecked") // Lot's of casting for no benefit.  Not intended use.
    278         @Override
    279         public <R1, R2> ServerCall.Listener<R1> interceptCall(
    280             ServerCall<R1, R2> call,
    281             Metadata headers,
    282             ServerCallHandler<R1, R2> next) {
    283           assertSame(call, ServerInterceptorsTest.this.call);
    284           assertSame(listener,
    285               next.startCall((ServerCall<R1, R2>)call2, headers));
    286           return (ServerCall.Listener<R1>) listener2;
    287         }
    288       };
    289     ServerServiceDefinition intercepted = ServerInterceptors.intercept(
    290         serviceDefinition, Arrays.asList(interceptor));
    291     assertSame(listener2,
    292         getSoleMethod(intercepted).getServerCallHandler().startCall(call, headers));
    293     verify(handler).startCall(call2, headers);
    294   }
    295 
    296   @Test
    297   @SuppressWarnings("unchecked")
    298   public void typedMarshalledMessages() {
    299     final List<String> order = new ArrayList<>();
    300     Marshaller<Holder> marshaller = new Marshaller<Holder>() {
    301       @Override
    302       public InputStream stream(Holder value) {
    303         return value.get();
    304       }
    305 
    306       @Override
    307       public Holder parse(InputStream stream) {
    308         return new Holder(stream);
    309       }
    310     };
    311 
    312     ServerCallHandler<Holder, Holder> handler2 = new ServerCallHandler<Holder, Holder>() {
    313       @Override
    314       public Listener<Holder> startCall(final ServerCall<Holder, Holder> call,
    315                                         final Metadata headers) {
    316         return new Listener<Holder>() {
    317           @Override
    318           public void onMessage(Holder message) {
    319             order.add("handler");
    320             call.sendMessage(message);
    321           }
    322         };
    323       }
    324     };
    325 
    326     MethodDescriptor<Holder, Holder> wrappedMethod = MethodDescriptor.<Holder, Holder>newBuilder()
    327         .setType(MethodType.UNKNOWN)
    328         .setFullMethodName("basic/wrapped")
    329         .setRequestMarshaller(marshaller)
    330         .setResponseMarshaller(marshaller)
    331         .build();
    332     ServerServiceDefinition serviceDef = ServerServiceDefinition.builder(
    333         new ServiceDescriptor("basic", wrappedMethod))
    334         .addMethod(wrappedMethod, handler2).build();
    335 
    336     ServerInterceptor interceptor1 = new ServerInterceptor() {
    337       @Override
    338       public <ReqT, RespT> Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call,
    339                                                         Metadata headers,
    340                                                         ServerCallHandler<ReqT, RespT> next) {
    341         ServerCall<ReqT, RespT> interceptedCall = new ForwardingServerCall
    342             .SimpleForwardingServerCall<ReqT, RespT>(call) {
    343           @Override
    344           public void sendMessage(RespT message) {
    345             order.add("i1sendMessage");
    346             assertTrue(message instanceof Holder);
    347             super.sendMessage(message);
    348           }
    349         };
    350 
    351         ServerCall.Listener<ReqT> originalListener = next
    352             .startCall(interceptedCall, headers);
    353         return new ForwardingServerCallListener
    354             .SimpleForwardingServerCallListener<ReqT>(originalListener) {
    355           @Override
    356           public void onMessage(ReqT message) {
    357             order.add("i1onMessage");
    358             assertTrue(message instanceof Holder);
    359             super.onMessage(message);
    360           }
    361         };
    362       }
    363     };
    364 
    365     ServerInterceptor interceptor2 = new ServerInterceptor() {
    366       @Override
    367       public <ReqT, RespT> Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call,
    368                                                         Metadata headers,
    369                                                         ServerCallHandler<ReqT, RespT> next) {
    370         ServerCall<ReqT, RespT> interceptedCall = new ForwardingServerCall
    371             .SimpleForwardingServerCall<ReqT, RespT>(call) {
    372           @Override
    373           public void sendMessage(RespT message) {
    374             order.add("i2sendMessage");
    375             assertTrue(message instanceof InputStream);
    376             super.sendMessage(message);
    377           }
    378         };
    379 
    380         ServerCall.Listener<ReqT> originalListener = next
    381             .startCall(interceptedCall, headers);
    382         return new ForwardingServerCallListener
    383             .SimpleForwardingServerCallListener<ReqT>(originalListener) {
    384           @Override
    385           public void onMessage(ReqT message) {
    386             order.add("i2onMessage");
    387             assertTrue(message instanceof InputStream);
    388             super.onMessage(message);
    389           }
    390         };
    391       }
    392     };
    393 
    394     ServerServiceDefinition intercepted = ServerInterceptors.intercept(serviceDef, interceptor1);
    395     ServerServiceDefinition inputStreamMessageService = ServerInterceptors
    396         .useInputStreamMessages(intercepted);
    397     ServerServiceDefinition intercepted2 = ServerInterceptors
    398         .intercept(inputStreamMessageService, interceptor2);
    399     ServerMethodDefinition<InputStream, InputStream> serverMethod =
    400         (ServerMethodDefinition<InputStream, InputStream>) intercepted2.getMethod("basic/wrapped");
    401     ServerCall<InputStream, InputStream> call2 = new NoopServerCall<InputStream, InputStream>();
    402     byte[] bytes = {};
    403     serverMethod
    404         .getServerCallHandler()
    405         .startCall(call2, headers)
    406         .onMessage(new ByteArrayInputStream(bytes));
    407     assertEquals(
    408         Arrays.asList("i2onMessage", "i1onMessage", "handler", "i1sendMessage", "i2sendMessage"),
    409         order);
    410   }
    411 
    412   @SuppressWarnings("unchecked")
    413   private static ServerMethodDefinition<String, Integer> getSoleMethod(
    414       ServerServiceDefinition serviceDef) {
    415     if (serviceDef.getMethods().size() != 1) {
    416       throw new AssertionError("Not exactly one method present");
    417     }
    418     return (ServerMethodDefinition<String, Integer>) getOnlyElement(serviceDef.getMethods());
    419   }
    420 
    421   @SuppressWarnings("unchecked")
    422   private static ServerMethodDefinition<String, Integer> getMethod(
    423       ServerServiceDefinition serviceDef, String name) {
    424     return (ServerMethodDefinition<String, Integer>) serviceDef.getMethod(name);
    425   }
    426 
    427   private ServerCallHandler<String, Integer> anyCallHandler() {
    428     return Mockito.<ServerCallHandler<String, Integer>>any();
    429   }
    430 
    431   private static class NoopInterceptor implements ServerInterceptor {
    432     @Override
    433     public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
    434         ServerCall<ReqT, RespT> call,
    435         Metadata headers,
    436         ServerCallHandler<ReqT, RespT> next) {
    437       return next.startCall(call, headers);
    438     }
    439   }
    440 
    441   private static class Holder {
    442     private final InputStream inputStream;
    443 
    444     Holder(InputStream inputStream) {
    445       this.inputStream = inputStream;
    446     }
    447 
    448     public InputStream get() {
    449       return inputStream;
    450     }
    451   }
    452 }
    453