1 /* 2 * Copyright 2017 The gRPC Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package io.grpc.services; 18 19 import static com.google.common.base.Charsets.UTF_8; 20 import static com.google.common.truth.Truth.assertThat; 21 import static org.junit.Assert.assertEquals; 22 import static org.junit.Assert.assertSame; 23 import static org.junit.Assert.assertTrue; 24 25 import com.google.common.io.ByteStreams; 26 import io.grpc.CallOptions; 27 import io.grpc.Channel; 28 import io.grpc.ClientCall; 29 import io.grpc.ClientInterceptor; 30 import io.grpc.ForwardingClientCall.SimpleForwardingClientCall; 31 import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener; 32 import io.grpc.ForwardingServerCall.SimpleForwardingServerCall; 33 import io.grpc.ForwardingServerCallListener.SimpleForwardingServerCallListener; 34 import io.grpc.Metadata; 35 import io.grpc.MethodDescriptor; 36 import io.grpc.MethodDescriptor.Marshaller; 37 import io.grpc.MethodDescriptor.MethodType; 38 import io.grpc.ServerCall; 39 import io.grpc.ServerCall.Listener; 40 import io.grpc.ServerCallHandler; 41 import io.grpc.ServerInterceptor; 42 import io.grpc.ServerMethodDefinition; 43 import io.grpc.internal.IoUtils; 44 import io.grpc.internal.NoopClientCall; 45 import io.grpc.internal.NoopServerCall; 46 import java.io.ByteArrayInputStream; 47 import java.io.IOException; 48 import java.io.InputStream; 49 import java.util.ArrayList; 50 import java.util.List; 51 import java.util.concurrent.atomic.AtomicReference; 52 import org.junit.Test; 53 import org.junit.runner.RunWith; 54 import org.junit.runners.JUnit4; 55 56 /** Unit tests for {@link BinaryLogProvider}. */ 57 @RunWith(JUnit4.class) 58 public class BinaryLogProviderTest { 59 private final InvocationCountMarshaller<String> reqMarshaller = 60 new InvocationCountMarshaller<String>() { 61 @Override 62 Marshaller<String> delegate() { 63 return StringMarshaller.INSTANCE; 64 } 65 }; 66 private final InvocationCountMarshaller<Integer> respMarshaller = 67 new InvocationCountMarshaller<Integer>() { 68 @Override 69 Marshaller<Integer> delegate() { 70 return IntegerMarshaller.INSTANCE; 71 } 72 }; 73 private final MethodDescriptor<String, Integer> method = 74 MethodDescriptor 75 .newBuilder(reqMarshaller, respMarshaller) 76 .setFullMethodName("myservice/mymethod") 77 .setType(MethodType.UNARY) 78 .setSchemaDescriptor(new Object()) 79 .setIdempotent(true) 80 .setSafe(true) 81 .setSampledToLocalTracing(true) 82 .build(); 83 private final List<byte[]> binlogReq = new ArrayList<byte[]>(); 84 private final List<byte[]> binlogResp = new ArrayList<byte[]>(); 85 private final BinaryLogProvider binlogProvider = new BinaryLogProvider() { 86 @Override 87 public ServerInterceptor getServerInterceptor(String fullMethodName) { 88 return new TestBinaryLogServerInterceptor(); 89 } 90 91 @Override 92 public ClientInterceptor getClientInterceptor( 93 String fullMethodName, CallOptions callOptions) { 94 return new TestBinaryLogClientInterceptor(); 95 } 96 }; 97 98 @Test 99 public void wrapChannel_methodDescriptor() throws Exception { 100 final AtomicReference<MethodDescriptor<?, ?>> methodRef = 101 new AtomicReference<MethodDescriptor<?, ?>>(); 102 Channel channel = new Channel() { 103 @Override 104 public <RequestT, ResponseT> ClientCall<RequestT, ResponseT> newCall( 105 MethodDescriptor<RequestT, ResponseT> method, CallOptions callOptions) { 106 methodRef.set(method); 107 return new NoopClientCall<RequestT, ResponseT>(); 108 } 109 110 @Override 111 public String authority() { 112 throw new UnsupportedOperationException(); 113 } 114 }; 115 Channel wChannel = binlogProvider.wrapChannel(channel); 116 ClientCall<String, Integer> unusedClientCall = wChannel.newCall(method, CallOptions.DEFAULT); 117 validateWrappedMethod(methodRef.get()); 118 } 119 120 @Test 121 public void wrapChannel_handler() throws Exception { 122 final List<byte[]> serializedReq = new ArrayList<byte[]>(); 123 final AtomicReference<ClientCall.Listener<?>> listener = 124 new AtomicReference<ClientCall.Listener<?>>(); 125 Channel channel = new Channel() { 126 @Override 127 public <RequestT, ResponseT> ClientCall<RequestT, ResponseT> newCall( 128 MethodDescriptor<RequestT, ResponseT> methodDescriptor, CallOptions callOptions) { 129 return new NoopClientCall<RequestT, ResponseT>() { 130 @Override 131 public void start(Listener<ResponseT> responseListener, Metadata headers) { 132 listener.set(responseListener); 133 } 134 135 @Override 136 public void sendMessage(RequestT message) { 137 serializedReq.add((byte[]) message); 138 } 139 }; 140 } 141 142 @Override 143 public String authority() { 144 throw new UnsupportedOperationException(); 145 } 146 }; 147 Channel wChannel = binlogProvider.wrapChannel(channel); 148 ClientCall<String, Integer> clientCall = wChannel.newCall(method, CallOptions.DEFAULT); 149 final List<Integer> observedResponse = new ArrayList<>(); 150 clientCall.start( 151 new NoopClientCall.NoopClientCallListener<Integer>() { 152 @Override 153 public void onMessage(Integer message) { 154 observedResponse.add(message); 155 } 156 }, 157 new Metadata()); 158 159 String expectedRequest = "hello world"; 160 assertThat(binlogReq).isEmpty(); 161 assertThat(serializedReq).isEmpty(); 162 assertEquals(0, reqMarshaller.streamInvocations); 163 clientCall.sendMessage(expectedRequest); 164 // it is unacceptably expensive for the binlog to double parse every logged message 165 assertEquals(1, reqMarshaller.streamInvocations); 166 assertEquals(0, reqMarshaller.parseInvocations); 167 assertThat(binlogReq).hasSize(1); 168 assertThat(serializedReq).hasSize(1); 169 assertEquals( 170 expectedRequest, 171 StringMarshaller.INSTANCE.parse(new ByteArrayInputStream(binlogReq.get(0)))); 172 assertEquals( 173 expectedRequest, 174 StringMarshaller.INSTANCE.parse(new ByteArrayInputStream(serializedReq.get(0)))); 175 176 int expectedResponse = 12345; 177 assertThat(binlogResp).isEmpty(); 178 assertThat(observedResponse).isEmpty(); 179 assertEquals(0, respMarshaller.parseInvocations); 180 onClientMessageHelper(listener.get(), IntegerMarshaller.INSTANCE.stream(expectedResponse)); 181 // it is unacceptably expensive for the binlog to double parse every logged message 182 assertEquals(1, respMarshaller.parseInvocations); 183 assertEquals(0, respMarshaller.streamInvocations); 184 assertThat(binlogResp).hasSize(1); 185 assertThat(observedResponse).hasSize(1); 186 assertEquals( 187 expectedResponse, 188 (int) IntegerMarshaller.INSTANCE.parse(new ByteArrayInputStream(binlogResp.get(0)))); 189 assertEquals(expectedResponse, (int) observedResponse.get(0)); 190 } 191 192 @SuppressWarnings({"rawtypes", "unchecked"}) 193 private static void onClientMessageHelper(ClientCall.Listener listener, Object request) { 194 listener.onMessage(request); 195 } 196 197 private void validateWrappedMethod(MethodDescriptor<?, ?> wMethod) { 198 assertSame(BinaryLogProvider.BYTEARRAY_MARSHALLER, wMethod.getRequestMarshaller()); 199 assertSame(BinaryLogProvider.BYTEARRAY_MARSHALLER, wMethod.getResponseMarshaller()); 200 assertEquals(method.getType(), wMethod.getType()); 201 assertEquals(method.getFullMethodName(), wMethod.getFullMethodName()); 202 assertEquals(method.getSchemaDescriptor(), wMethod.getSchemaDescriptor()); 203 assertEquals(method.isIdempotent(), wMethod.isIdempotent()); 204 assertEquals(method.isSafe(), wMethod.isSafe()); 205 assertEquals(method.isSampledToLocalTracing(), wMethod.isSampledToLocalTracing()); 206 } 207 208 @Test 209 public void wrapMethodDefinition_methodDescriptor() throws Exception { 210 ServerMethodDefinition<String, Integer> methodDef = 211 ServerMethodDefinition.create( 212 method, 213 new ServerCallHandler<String, Integer>() { 214 @Override 215 public Listener<String> startCall( 216 ServerCall<String, Integer> call, Metadata headers) { 217 throw new UnsupportedOperationException(); 218 } 219 }); 220 ServerMethodDefinition<?, ?> wMethodDef = binlogProvider.wrapMethodDefinition(methodDef); 221 validateWrappedMethod(wMethodDef.getMethodDescriptor()); 222 } 223 224 @Test 225 public void wrapMethodDefinition_handler() throws Exception { 226 // The request as seen by the user supplied server code 227 final List<String> observedRequest = new ArrayList<>(); 228 final AtomicReference<ServerCall<String, Integer>> serverCall = 229 new AtomicReference<ServerCall<String, Integer>>(); 230 ServerMethodDefinition<String, Integer> methodDef = 231 ServerMethodDefinition.create( 232 method, 233 new ServerCallHandler<String, Integer>() { 234 @Override 235 public ServerCall.Listener<String> startCall( 236 ServerCall<String, Integer> call, Metadata headers) { 237 serverCall.set(call); 238 return new ServerCall.Listener<String>() { 239 @Override 240 public void onMessage(String message) { 241 observedRequest.add(message); 242 } 243 }; 244 } 245 }); 246 ServerMethodDefinition<?, ?> wDef = binlogProvider.wrapMethodDefinition(methodDef); 247 List<Object> serializedResp = new ArrayList<>(); 248 ServerCall.Listener<?> wListener = startServerCallHelper(wDef, serializedResp); 249 250 String expectedRequest = "hello world"; 251 assertThat(binlogReq).isEmpty(); 252 assertThat(observedRequest).isEmpty(); 253 assertEquals(0, reqMarshaller.parseInvocations); 254 onServerMessageHelper(wListener, StringMarshaller.INSTANCE.stream(expectedRequest)); 255 // it is unacceptably expensive for the binlog to double parse every logged message 256 assertEquals(1, reqMarshaller.parseInvocations); 257 assertEquals(0, reqMarshaller.streamInvocations); 258 assertThat(binlogReq).hasSize(1); 259 assertThat(observedRequest).hasSize(1); 260 assertEquals( 261 expectedRequest, 262 StringMarshaller.INSTANCE.parse(new ByteArrayInputStream(binlogReq.get(0)))); 263 assertEquals(expectedRequest, observedRequest.get(0)); 264 265 int expectedResponse = 12345; 266 assertThat(binlogResp).isEmpty(); 267 assertThat(serializedResp).isEmpty(); 268 assertEquals(0, respMarshaller.streamInvocations); 269 serverCall.get().sendMessage(expectedResponse); 270 // it is unacceptably expensive for the binlog to double parse every logged message 271 assertEquals(0, respMarshaller.parseInvocations); 272 assertEquals(1, respMarshaller.streamInvocations); 273 assertThat(binlogResp).hasSize(1); 274 assertThat(serializedResp).hasSize(1); 275 assertEquals( 276 expectedResponse, 277 (int) IntegerMarshaller.INSTANCE.parse(new ByteArrayInputStream(binlogResp.get(0)))); 278 assertEquals(expectedResponse, 279 (int) method.parseResponse(new ByteArrayInputStream((byte[]) serializedResp.get(0)))); 280 } 281 282 @SuppressWarnings({"rawtypes", "unchecked"}) 283 private static void onServerMessageHelper(ServerCall.Listener listener, Object request) { 284 listener.onMessage(request); 285 } 286 287 private static <ReqT, RespT> ServerCall.Listener<ReqT> startServerCallHelper( 288 final ServerMethodDefinition<ReqT, RespT> methodDef, 289 final List<Object> serializedResp) { 290 ServerCall<ReqT, RespT> serverCall = new NoopServerCall<ReqT, RespT>() { 291 @Override 292 public void sendMessage(RespT message) { 293 serializedResp.add(message); 294 } 295 296 @Override 297 public MethodDescriptor<ReqT, RespT> getMethodDescriptor() { 298 return methodDef.getMethodDescriptor(); 299 } 300 }; 301 return methodDef.getServerCallHandler().startCall(serverCall, new Metadata()); 302 } 303 304 private final class TestBinaryLogClientInterceptor implements ClientInterceptor { 305 @Override 306 public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall( 307 final MethodDescriptor<ReqT, RespT> method, 308 CallOptions callOptions, 309 Channel next) { 310 assertSame(BinaryLogProvider.BYTEARRAY_MARSHALLER, method.getRequestMarshaller()); 311 assertSame(BinaryLogProvider.BYTEARRAY_MARSHALLER, method.getResponseMarshaller()); 312 return new SimpleForwardingClientCall<ReqT, RespT>(next.newCall(method, callOptions)) { 313 @Override 314 public void start(Listener<RespT> responseListener, Metadata headers) { 315 super.start( 316 new SimpleForwardingClientCallListener<RespT>(responseListener) { 317 @Override 318 public void onMessage(RespT message) { 319 assertTrue(message instanceof InputStream); 320 try { 321 byte[] bytes = IoUtils.toByteArray((InputStream) message); 322 binlogResp.add(bytes); 323 ByteArrayInputStream input = new ByteArrayInputStream(bytes); 324 RespT dup = method.parseResponse(input); 325 super.onMessage(dup); 326 } catch (IOException e) { 327 throw new RuntimeException(e); 328 } 329 } 330 }, 331 headers); 332 } 333 334 @Override 335 public void sendMessage(ReqT message) { 336 byte[] bytes = (byte[]) message; 337 binlogReq.add(bytes); 338 ByteArrayInputStream input = new ByteArrayInputStream(bytes); 339 ReqT dup = method.parseRequest(input); 340 super.sendMessage(dup); 341 } 342 }; 343 } 344 } 345 346 private final class TestBinaryLogServerInterceptor implements ServerInterceptor { 347 @Override 348 public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall( 349 final ServerCall<ReqT, RespT> call, 350 Metadata headers, 351 ServerCallHandler<ReqT, RespT> next) { 352 assertSame( 353 BinaryLogProvider.BYTEARRAY_MARSHALLER, 354 call.getMethodDescriptor().getRequestMarshaller()); 355 assertSame( 356 BinaryLogProvider.BYTEARRAY_MARSHALLER, 357 call.getMethodDescriptor().getResponseMarshaller()); 358 ServerCall<ReqT, RespT> wCall = new SimpleForwardingServerCall<ReqT, RespT>(call) { 359 @Override 360 public void sendMessage(RespT message) { 361 byte[] bytes = (byte[]) message; 362 binlogResp.add(bytes); 363 ByteArrayInputStream input = new ByteArrayInputStream(bytes); 364 RespT dup = call.getMethodDescriptor().parseResponse(input); 365 super.sendMessage(dup); 366 } 367 }; 368 final ServerCall.Listener<ReqT> oListener = next.startCall(wCall, headers); 369 return new SimpleForwardingServerCallListener<ReqT>(oListener) { 370 @Override 371 public void onMessage(ReqT message) { 372 assertTrue(message instanceof InputStream); 373 try { 374 byte[] bytes = IoUtils.toByteArray((InputStream) message); 375 binlogReq.add(bytes); 376 ByteArrayInputStream input = new ByteArrayInputStream(bytes); 377 ReqT dup = call.getMethodDescriptor().parseRequest(input); 378 super.onMessage(dup); 379 } catch (IOException e) { 380 throw new RuntimeException(e); 381 } 382 } 383 }; 384 } 385 } 386 387 private abstract static class InvocationCountMarshaller<T> 388 implements MethodDescriptor.Marshaller<T> { 389 private int streamInvocations = 0; 390 private int parseInvocations = 0; 391 392 abstract MethodDescriptor.Marshaller<T> delegate(); 393 394 @Override 395 public InputStream stream(T value) { 396 streamInvocations++; 397 return delegate().stream(value); 398 } 399 400 @Override 401 public T parse(InputStream stream) { 402 parseInvocations++; 403 return delegate().parse(stream); 404 } 405 } 406 407 408 private static class StringMarshaller implements MethodDescriptor.Marshaller<String> { 409 public static final StringMarshaller INSTANCE = new StringMarshaller(); 410 411 @Override 412 public InputStream stream(String value) { 413 return new ByteArrayInputStream(value.getBytes(UTF_8)); 414 } 415 416 @Override 417 public String parse(InputStream stream) { 418 try { 419 return new String(ByteStreams.toByteArray(stream), UTF_8); 420 } catch (IOException ex) { 421 throw new RuntimeException(ex); 422 } 423 } 424 } 425 426 private static class IntegerMarshaller implements MethodDescriptor.Marshaller<Integer> { 427 public static final IntegerMarshaller INSTANCE = new IntegerMarshaller(); 428 429 @Override 430 public InputStream stream(Integer value) { 431 return StringMarshaller.INSTANCE.stream(value.toString()); 432 } 433 434 @Override 435 public Integer parse(InputStream stream) { 436 return Integer.valueOf(StringMarshaller.INSTANCE.parse(stream)); 437 } 438 } 439 } 440