Home | History | Annotate | Download | only in services
      1 /*
      2  * Copyright 2017 The gRPC Authors
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *     http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 package io.grpc.services;
     18 
     19 import static com.google.common.base.Charsets.UTF_8;
     20 import static com.google.common.truth.Truth.assertThat;
     21 import static org.junit.Assert.assertEquals;
     22 import static org.junit.Assert.assertSame;
     23 import static org.junit.Assert.assertTrue;
     24 
     25 import com.google.common.io.ByteStreams;
     26 import io.grpc.CallOptions;
     27 import io.grpc.Channel;
     28 import io.grpc.ClientCall;
     29 import io.grpc.ClientInterceptor;
     30 import io.grpc.ForwardingClientCall.SimpleForwardingClientCall;
     31 import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener;
     32 import io.grpc.ForwardingServerCall.SimpleForwardingServerCall;
     33 import io.grpc.ForwardingServerCallListener.SimpleForwardingServerCallListener;
     34 import io.grpc.Metadata;
     35 import io.grpc.MethodDescriptor;
     36 import io.grpc.MethodDescriptor.Marshaller;
     37 import io.grpc.MethodDescriptor.MethodType;
     38 import io.grpc.ServerCall;
     39 import io.grpc.ServerCall.Listener;
     40 import io.grpc.ServerCallHandler;
     41 import io.grpc.ServerInterceptor;
     42 import io.grpc.ServerMethodDefinition;
     43 import io.grpc.internal.IoUtils;
     44 import io.grpc.internal.NoopClientCall;
     45 import io.grpc.internal.NoopServerCall;
     46 import java.io.ByteArrayInputStream;
     47 import java.io.IOException;
     48 import java.io.InputStream;
     49 import java.util.ArrayList;
     50 import java.util.List;
     51 import java.util.concurrent.atomic.AtomicReference;
     52 import org.junit.Test;
     53 import org.junit.runner.RunWith;
     54 import org.junit.runners.JUnit4;
     55 
     56 /** Unit tests for {@link BinaryLogProvider}. */
     57 @RunWith(JUnit4.class)
     58 public class BinaryLogProviderTest {
     59   private final InvocationCountMarshaller<String> reqMarshaller =
     60       new InvocationCountMarshaller<String>() {
     61         @Override
     62         Marshaller<String> delegate() {
     63           return StringMarshaller.INSTANCE;
     64         }
     65       };
     66   private final InvocationCountMarshaller<Integer> respMarshaller =
     67       new InvocationCountMarshaller<Integer>() {
     68         @Override
     69         Marshaller<Integer> delegate() {
     70           return IntegerMarshaller.INSTANCE;
     71         }
     72       };
     73   private final MethodDescriptor<String, Integer> method =
     74       MethodDescriptor
     75           .newBuilder(reqMarshaller, respMarshaller)
     76           .setFullMethodName("myservice/mymethod")
     77           .setType(MethodType.UNARY)
     78           .setSchemaDescriptor(new Object())
     79           .setIdempotent(true)
     80           .setSafe(true)
     81           .setSampledToLocalTracing(true)
     82           .build();
     83   private final List<byte[]> binlogReq = new ArrayList<byte[]>();
     84   private final List<byte[]> binlogResp = new ArrayList<byte[]>();
     85   private final BinaryLogProvider binlogProvider = new BinaryLogProvider() {
     86     @Override
     87     public ServerInterceptor getServerInterceptor(String fullMethodName) {
     88       return new TestBinaryLogServerInterceptor();
     89     }
     90 
     91     @Override
     92     public ClientInterceptor getClientInterceptor(
     93         String fullMethodName, CallOptions callOptions) {
     94       return new TestBinaryLogClientInterceptor();
     95     }
     96   };
     97 
     98   @Test
     99   public void wrapChannel_methodDescriptor() throws Exception {
    100     final AtomicReference<MethodDescriptor<?, ?>> methodRef =
    101         new AtomicReference<MethodDescriptor<?, ?>>();
    102     Channel channel = new Channel() {
    103       @Override
    104       public <RequestT, ResponseT> ClientCall<RequestT, ResponseT> newCall(
    105           MethodDescriptor<RequestT, ResponseT> method, CallOptions callOptions) {
    106         methodRef.set(method);
    107         return new NoopClientCall<RequestT, ResponseT>();
    108       }
    109 
    110       @Override
    111       public String authority() {
    112         throw new UnsupportedOperationException();
    113       }
    114     };
    115     Channel wChannel = binlogProvider.wrapChannel(channel);
    116     ClientCall<String, Integer> unusedClientCall = wChannel.newCall(method, CallOptions.DEFAULT);
    117     validateWrappedMethod(methodRef.get());
    118   }
    119 
    120   @Test
    121   public void wrapChannel_handler() throws Exception {
    122     final List<byte[]> serializedReq = new ArrayList<byte[]>();
    123     final AtomicReference<ClientCall.Listener<?>> listener =
    124         new AtomicReference<ClientCall.Listener<?>>();
    125     Channel channel = new Channel() {
    126       @Override
    127       public <RequestT, ResponseT> ClientCall<RequestT, ResponseT> newCall(
    128           MethodDescriptor<RequestT, ResponseT> methodDescriptor, CallOptions callOptions) {
    129         return new NoopClientCall<RequestT, ResponseT>() {
    130           @Override
    131           public void start(Listener<ResponseT> responseListener, Metadata headers) {
    132             listener.set(responseListener);
    133           }
    134 
    135           @Override
    136           public void sendMessage(RequestT message) {
    137             serializedReq.add((byte[]) message);
    138           }
    139         };
    140       }
    141 
    142       @Override
    143       public String authority() {
    144         throw new UnsupportedOperationException();
    145       }
    146     };
    147     Channel wChannel = binlogProvider.wrapChannel(channel);
    148     ClientCall<String, Integer> clientCall = wChannel.newCall(method, CallOptions.DEFAULT);
    149     final List<Integer> observedResponse = new ArrayList<>();
    150     clientCall.start(
    151         new NoopClientCall.NoopClientCallListener<Integer>() {
    152           @Override
    153           public void onMessage(Integer message) {
    154             observedResponse.add(message);
    155           }
    156         },
    157         new Metadata());
    158 
    159     String expectedRequest = "hello world";
    160     assertThat(binlogReq).isEmpty();
    161     assertThat(serializedReq).isEmpty();
    162     assertEquals(0, reqMarshaller.streamInvocations);
    163     clientCall.sendMessage(expectedRequest);
    164     // it is unacceptably expensive for the binlog to double parse every logged message
    165     assertEquals(1, reqMarshaller.streamInvocations);
    166     assertEquals(0, reqMarshaller.parseInvocations);
    167     assertThat(binlogReq).hasSize(1);
    168     assertThat(serializedReq).hasSize(1);
    169     assertEquals(
    170         expectedRequest,
    171         StringMarshaller.INSTANCE.parse(new ByteArrayInputStream(binlogReq.get(0))));
    172     assertEquals(
    173         expectedRequest,
    174         StringMarshaller.INSTANCE.parse(new ByteArrayInputStream(serializedReq.get(0))));
    175 
    176     int expectedResponse = 12345;
    177     assertThat(binlogResp).isEmpty();
    178     assertThat(observedResponse).isEmpty();
    179     assertEquals(0, respMarshaller.parseInvocations);
    180     onClientMessageHelper(listener.get(), IntegerMarshaller.INSTANCE.stream(expectedResponse));
    181     // it is unacceptably expensive for the binlog to double parse every logged message
    182     assertEquals(1, respMarshaller.parseInvocations);
    183     assertEquals(0, respMarshaller.streamInvocations);
    184     assertThat(binlogResp).hasSize(1);
    185     assertThat(observedResponse).hasSize(1);
    186     assertEquals(
    187         expectedResponse,
    188         (int) IntegerMarshaller.INSTANCE.parse(new ByteArrayInputStream(binlogResp.get(0))));
    189     assertEquals(expectedResponse, (int) observedResponse.get(0));
    190   }
    191 
    192   @SuppressWarnings({"rawtypes", "unchecked"})
    193   private static void onClientMessageHelper(ClientCall.Listener listener, Object request) {
    194     listener.onMessage(request);
    195   }
    196 
    197   private void validateWrappedMethod(MethodDescriptor<?, ?> wMethod) {
    198     assertSame(BinaryLogProvider.BYTEARRAY_MARSHALLER, wMethod.getRequestMarshaller());
    199     assertSame(BinaryLogProvider.BYTEARRAY_MARSHALLER, wMethod.getResponseMarshaller());
    200     assertEquals(method.getType(), wMethod.getType());
    201     assertEquals(method.getFullMethodName(), wMethod.getFullMethodName());
    202     assertEquals(method.getSchemaDescriptor(), wMethod.getSchemaDescriptor());
    203     assertEquals(method.isIdempotent(), wMethod.isIdempotent());
    204     assertEquals(method.isSafe(), wMethod.isSafe());
    205     assertEquals(method.isSampledToLocalTracing(), wMethod.isSampledToLocalTracing());
    206   }
    207 
    208   @Test
    209   public void wrapMethodDefinition_methodDescriptor() throws Exception {
    210     ServerMethodDefinition<String, Integer> methodDef =
    211         ServerMethodDefinition.create(
    212             method,
    213             new ServerCallHandler<String, Integer>() {
    214               @Override
    215               public Listener<String> startCall(
    216                   ServerCall<String, Integer> call, Metadata headers) {
    217                 throw new UnsupportedOperationException();
    218               }
    219             });
    220     ServerMethodDefinition<?, ?> wMethodDef = binlogProvider.wrapMethodDefinition(methodDef);
    221     validateWrappedMethod(wMethodDef.getMethodDescriptor());
    222   }
    223 
    224   @Test
    225   public void wrapMethodDefinition_handler() throws Exception {
    226     // The request as seen by the user supplied server code
    227     final List<String> observedRequest = new ArrayList<>();
    228     final AtomicReference<ServerCall<String, Integer>> serverCall =
    229         new AtomicReference<ServerCall<String, Integer>>();
    230     ServerMethodDefinition<String, Integer> methodDef =
    231         ServerMethodDefinition.create(
    232             method,
    233             new ServerCallHandler<String, Integer>() {
    234               @Override
    235               public ServerCall.Listener<String> startCall(
    236                   ServerCall<String, Integer> call, Metadata headers) {
    237                 serverCall.set(call);
    238                 return new ServerCall.Listener<String>() {
    239                   @Override
    240                   public void onMessage(String message) {
    241                     observedRequest.add(message);
    242                   }
    243                 };
    244               }
    245             });
    246     ServerMethodDefinition<?, ?> wDef = binlogProvider.wrapMethodDefinition(methodDef);
    247     List<Object> serializedResp = new ArrayList<>();
    248     ServerCall.Listener<?> wListener = startServerCallHelper(wDef, serializedResp);
    249 
    250     String expectedRequest = "hello world";
    251     assertThat(binlogReq).isEmpty();
    252     assertThat(observedRequest).isEmpty();
    253     assertEquals(0, reqMarshaller.parseInvocations);
    254     onServerMessageHelper(wListener, StringMarshaller.INSTANCE.stream(expectedRequest));
    255     // it is unacceptably expensive for the binlog to double parse every logged message
    256     assertEquals(1, reqMarshaller.parseInvocations);
    257     assertEquals(0, reqMarshaller.streamInvocations);
    258     assertThat(binlogReq).hasSize(1);
    259     assertThat(observedRequest).hasSize(1);
    260     assertEquals(
    261         expectedRequest,
    262         StringMarshaller.INSTANCE.parse(new ByteArrayInputStream(binlogReq.get(0))));
    263     assertEquals(expectedRequest, observedRequest.get(0));
    264 
    265     int expectedResponse = 12345;
    266     assertThat(binlogResp).isEmpty();
    267     assertThat(serializedResp).isEmpty();
    268     assertEquals(0, respMarshaller.streamInvocations);
    269     serverCall.get().sendMessage(expectedResponse);
    270     // it is unacceptably expensive for the binlog to double parse every logged message
    271     assertEquals(0, respMarshaller.parseInvocations);
    272     assertEquals(1, respMarshaller.streamInvocations);
    273     assertThat(binlogResp).hasSize(1);
    274     assertThat(serializedResp).hasSize(1);
    275     assertEquals(
    276         expectedResponse,
    277         (int) IntegerMarshaller.INSTANCE.parse(new ByteArrayInputStream(binlogResp.get(0))));
    278     assertEquals(expectedResponse,
    279         (int) method.parseResponse(new ByteArrayInputStream((byte[]) serializedResp.get(0))));
    280   }
    281 
    282   @SuppressWarnings({"rawtypes", "unchecked"})
    283   private static void onServerMessageHelper(ServerCall.Listener listener, Object request) {
    284     listener.onMessage(request);
    285   }
    286 
    287   private static <ReqT, RespT> ServerCall.Listener<ReqT> startServerCallHelper(
    288       final ServerMethodDefinition<ReqT, RespT> methodDef,
    289       final List<Object> serializedResp) {
    290     ServerCall<ReqT, RespT> serverCall = new NoopServerCall<ReqT, RespT>() {
    291       @Override
    292       public void sendMessage(RespT message) {
    293         serializedResp.add(message);
    294       }
    295 
    296       @Override
    297       public MethodDescriptor<ReqT, RespT> getMethodDescriptor() {
    298         return methodDef.getMethodDescriptor();
    299       }
    300     };
    301     return methodDef.getServerCallHandler().startCall(serverCall, new Metadata());
    302   }
    303 
    304   private final class TestBinaryLogClientInterceptor implements ClientInterceptor {
    305     @Override
    306     public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
    307         final MethodDescriptor<ReqT, RespT> method,
    308         CallOptions callOptions,
    309         Channel next) {
    310       assertSame(BinaryLogProvider.BYTEARRAY_MARSHALLER, method.getRequestMarshaller());
    311       assertSame(BinaryLogProvider.BYTEARRAY_MARSHALLER, method.getResponseMarshaller());
    312       return new SimpleForwardingClientCall<ReqT, RespT>(next.newCall(method, callOptions)) {
    313         @Override
    314         public void start(Listener<RespT> responseListener, Metadata headers) {
    315           super.start(
    316               new SimpleForwardingClientCallListener<RespT>(responseListener) {
    317                 @Override
    318                 public void onMessage(RespT message) {
    319                   assertTrue(message instanceof InputStream);
    320                   try {
    321                     byte[] bytes = IoUtils.toByteArray((InputStream) message);
    322                     binlogResp.add(bytes);
    323                     ByteArrayInputStream input = new ByteArrayInputStream(bytes);
    324                     RespT dup = method.parseResponse(input);
    325                     super.onMessage(dup);
    326                   } catch (IOException e) {
    327                     throw new RuntimeException(e);
    328                   }
    329                 }
    330               },
    331               headers);
    332         }
    333 
    334         @Override
    335         public void sendMessage(ReqT message) {
    336             byte[] bytes = (byte[]) message;
    337             binlogReq.add(bytes);
    338             ByteArrayInputStream input = new ByteArrayInputStream(bytes);
    339             ReqT dup = method.parseRequest(input);
    340             super.sendMessage(dup);
    341         }
    342       };
    343     }
    344   }
    345 
    346   private final class TestBinaryLogServerInterceptor implements ServerInterceptor {
    347     @Override
    348     public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
    349         final ServerCall<ReqT, RespT> call,
    350         Metadata headers,
    351         ServerCallHandler<ReqT, RespT> next) {
    352       assertSame(
    353           BinaryLogProvider.BYTEARRAY_MARSHALLER,
    354           call.getMethodDescriptor().getRequestMarshaller());
    355       assertSame(
    356           BinaryLogProvider.BYTEARRAY_MARSHALLER,
    357           call.getMethodDescriptor().getResponseMarshaller());
    358       ServerCall<ReqT, RespT> wCall = new SimpleForwardingServerCall<ReqT, RespT>(call) {
    359         @Override
    360         public void sendMessage(RespT message) {
    361             byte[] bytes = (byte[]) message;
    362             binlogResp.add(bytes);
    363             ByteArrayInputStream input = new ByteArrayInputStream(bytes);
    364             RespT dup = call.getMethodDescriptor().parseResponse(input);
    365             super.sendMessage(dup);
    366         }
    367       };
    368       final ServerCall.Listener<ReqT> oListener = next.startCall(wCall, headers);
    369       return new SimpleForwardingServerCallListener<ReqT>(oListener) {
    370         @Override
    371         public void onMessage(ReqT message) {
    372           assertTrue(message instanceof InputStream);
    373           try {
    374             byte[] bytes = IoUtils.toByteArray((InputStream) message);
    375             binlogReq.add(bytes);
    376             ByteArrayInputStream input = new ByteArrayInputStream(bytes);
    377             ReqT dup = call.getMethodDescriptor().parseRequest(input);
    378             super.onMessage(dup);
    379           } catch (IOException e) {
    380             throw new RuntimeException(e);
    381           }
    382         }
    383       };
    384     }
    385   }
    386 
    387   private abstract static class InvocationCountMarshaller<T>
    388       implements MethodDescriptor.Marshaller<T> {
    389     private int streamInvocations = 0;
    390     private int parseInvocations = 0;
    391 
    392     abstract MethodDescriptor.Marshaller<T> delegate();
    393 
    394     @Override
    395     public InputStream stream(T value) {
    396       streamInvocations++;
    397       return delegate().stream(value);
    398     }
    399 
    400     @Override
    401     public T parse(InputStream stream) {
    402       parseInvocations++;
    403       return delegate().parse(stream);
    404     }
    405   }
    406 
    407 
    408   private static class StringMarshaller implements MethodDescriptor.Marshaller<String> {
    409     public static final StringMarshaller INSTANCE = new StringMarshaller();
    410 
    411     @Override
    412     public InputStream stream(String value) {
    413       return new ByteArrayInputStream(value.getBytes(UTF_8));
    414     }
    415 
    416     @Override
    417     public String parse(InputStream stream) {
    418       try {
    419         return new String(ByteStreams.toByteArray(stream), UTF_8);
    420       } catch (IOException ex) {
    421         throw new RuntimeException(ex);
    422       }
    423     }
    424   }
    425 
    426   private static class IntegerMarshaller implements MethodDescriptor.Marshaller<Integer> {
    427     public static final IntegerMarshaller INSTANCE = new IntegerMarshaller();
    428 
    429     @Override
    430     public InputStream stream(Integer value) {
    431       return StringMarshaller.INSTANCE.stream(value.toString());
    432     }
    433 
    434     @Override
    435     public Integer parse(InputStream stream) {
    436       return Integer.valueOf(StringMarshaller.INSTANCE.parse(stream));
    437     }
    438   }
    439 }
    440