1 /* 2 * Copyright 2018 The gRPC Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package io.grpc.alts.internal; 18 19 import static com.google.common.truth.Truth.assertThat; 20 import static java.nio.charset.StandardCharsets.UTF_8; 21 import static org.junit.Assert.assertEquals; 22 import static org.junit.Assert.assertFalse; 23 import static org.junit.Assert.assertNotNull; 24 import static org.junit.Assert.assertTrue; 25 26 import io.grpc.Attributes; 27 import io.grpc.CallCredentials; 28 import io.grpc.Grpc; 29 import io.grpc.InternalChannelz; 30 import io.grpc.SecurityLevel; 31 import io.grpc.alts.internal.Handshaker.HandshakerResult; 32 import io.grpc.alts.internal.TsiFrameProtector.Consumer; 33 import io.grpc.alts.internal.TsiPeer.Property; 34 import io.grpc.netty.GrpcHttp2ConnectionHandler; 35 import io.netty.buffer.ByteBuf; 36 import io.netty.buffer.ByteBufAllocator; 37 import io.netty.buffer.CompositeByteBuf; 38 import io.netty.buffer.Unpooled; 39 import io.netty.channel.ChannelDuplexHandler; 40 import io.netty.channel.ChannelFuture; 41 import io.netty.channel.ChannelFutureListener; 42 import io.netty.channel.ChannelHandler; 43 import io.netty.channel.ChannelHandlerContext; 44 import io.netty.channel.ChannelPromise; 45 import io.netty.channel.embedded.EmbeddedChannel; 46 import io.netty.handler.codec.http2.DefaultHttp2Connection; 47 import io.netty.handler.codec.http2.DefaultHttp2ConnectionDecoder; 48 import io.netty.handler.codec.http2.DefaultHttp2ConnectionEncoder; 49 import io.netty.handler.codec.http2.DefaultHttp2FrameReader; 50 import io.netty.handler.codec.http2.DefaultHttp2FrameWriter; 51 import io.netty.handler.codec.http2.Http2Connection; 52 import io.netty.handler.codec.http2.Http2ConnectionDecoder; 53 import io.netty.handler.codec.http2.Http2ConnectionEncoder; 54 import io.netty.handler.codec.http2.Http2FrameReader; 55 import io.netty.handler.codec.http2.Http2FrameWriter; 56 import io.netty.handler.codec.http2.Http2Settings; 57 import io.netty.util.ReferenceCountUtil; 58 import io.netty.util.ReferenceCounted; 59 import java.nio.ByteBuffer; 60 import java.security.GeneralSecurityException; 61 import java.util.ArrayList; 62 import java.util.Arrays; 63 import java.util.Collections; 64 import java.util.List; 65 import java.util.concurrent.Future; 66 import java.util.concurrent.LinkedBlockingQueue; 67 import java.util.concurrent.TimeUnit; 68 import java.util.concurrent.atomic.AtomicInteger; 69 import java.util.concurrent.atomic.AtomicReference; 70 import org.junit.After; 71 import org.junit.Before; 72 import org.junit.Test; 73 import org.junit.runner.RunWith; 74 import org.junit.runners.JUnit4; 75 76 /** Tests for {@link AltsProtocolNegotiator}. */ 77 @RunWith(JUnit4.class) 78 public class AltsProtocolNegotiatorTest { 79 private final CapturingGrpcHttp2ConnectionHandler grpcHandler = capturingGrpcHandler(); 80 81 private final List<ReferenceCounted> references = new ArrayList<>(); 82 private final LinkedBlockingQueue<InterceptingProtector> protectors = new LinkedBlockingQueue<>(); 83 84 private EmbeddedChannel channel; 85 private Throwable caughtException; 86 87 private volatile TsiHandshakeHandler.TsiHandshakeCompletionEvent tsiEvent; 88 private ChannelHandler handler; 89 90 private TsiPeer mockedTsiPeer = new TsiPeer(Collections.<Property<?>>emptyList()); 91 private AltsAuthContext mockedAltsContext = 92 new AltsAuthContext( 93 HandshakerResult.newBuilder() 94 .setPeerRpcVersions(RpcProtocolVersionsUtil.getRpcProtocolVersions()) 95 .build()); 96 private final TsiHandshaker mockHandshaker = 97 new DelegatingTsiHandshaker(FakeTsiHandshaker.newFakeHandshakerServer()) { 98 @Override 99 public TsiPeer extractPeer() throws GeneralSecurityException { 100 return mockedTsiPeer; 101 } 102 103 @Override 104 public Object extractPeerObject() throws GeneralSecurityException { 105 return mockedAltsContext; 106 } 107 }; 108 private final NettyTsiHandshaker serverHandshaker = new NettyTsiHandshaker(mockHandshaker); 109 110 @Before 111 public void setup() throws Exception { 112 ChannelHandler userEventHandler = 113 new ChannelDuplexHandler() { 114 @Override 115 public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { 116 if (evt instanceof TsiHandshakeHandler.TsiHandshakeCompletionEvent) { 117 tsiEvent = (TsiHandshakeHandler.TsiHandshakeCompletionEvent) evt; 118 } else { 119 super.userEventTriggered(ctx, evt); 120 } 121 } 122 }; 123 124 ChannelHandler uncaughtExceptionHandler = 125 new ChannelDuplexHandler() { 126 @Override 127 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { 128 caughtException = cause; 129 super.exceptionCaught(ctx, cause); 130 } 131 }; 132 133 TsiHandshakerFactory handshakerFactory = 134 new DelegatingTsiHandshakerFactory(FakeTsiHandshaker.clientHandshakerFactory()) { 135 @Override 136 public TsiHandshaker newHandshaker() { 137 return new DelegatingTsiHandshaker(super.newHandshaker()) { 138 @Override 139 public TsiPeer extractPeer() throws GeneralSecurityException { 140 return mockedTsiPeer; 141 } 142 143 @Override 144 public Object extractPeerObject() throws GeneralSecurityException { 145 return mockedAltsContext; 146 } 147 }; 148 } 149 }; 150 handler = AltsProtocolNegotiator.create(handshakerFactory).newHandler(grpcHandler); 151 channel = new EmbeddedChannel(uncaughtExceptionHandler, handler, userEventHandler); 152 } 153 154 @After 155 public void teardown() throws Exception { 156 if (channel != null) { 157 @SuppressWarnings("unused") // go/futurereturn-lsc 158 Future<?> possiblyIgnoredError = channel.close(); 159 } 160 161 for (ReferenceCounted reference : references) { 162 ReferenceCountUtil.safeRelease(reference); 163 } 164 } 165 166 @Test 167 public void handshakeShouldBeSuccessful() throws Exception { 168 doHandshake(); 169 } 170 171 @Test 172 @SuppressWarnings("unchecked") // List cast 173 public void protectShouldRoundtrip() throws Exception { 174 // Write the message 1 character at a time. The message should be buffered 175 // and not interfere with the handshake. 176 final AtomicInteger writeCount = new AtomicInteger(); 177 String message = "hello"; 178 for (int ix = 0; ix < message.length(); ++ix) { 179 ByteBuf in = Unpooled.copiedBuffer(message, ix, 1, UTF_8); 180 @SuppressWarnings("unused") // go/futurereturn-lsc 181 Future<?> possiblyIgnoredError = 182 channel 183 .write(in) 184 .addListener( 185 new ChannelFutureListener() { 186 @Override 187 public void operationComplete(ChannelFuture future) throws Exception { 188 if (future.isSuccess()) { 189 writeCount.incrementAndGet(); 190 } 191 } 192 }); 193 } 194 channel.flush(); 195 196 // Now do the handshake. The buffered message will automatically be protected 197 // and sent. 198 doHandshake(); 199 200 // Capture the protected data written to the wire. 201 assertEquals(1, channel.outboundMessages().size()); 202 ByteBuf protectedData = channel.<ByteBuf>readOutbound(); 203 assertEquals(message.length(), writeCount.get()); 204 205 // Read the protected message at the server and verify it matches the original message. 206 TsiFrameProtector serverProtector = serverHandshaker.createFrameProtector(channel.alloc()); 207 List<ByteBuf> unprotected = new ArrayList<>(); 208 serverProtector.unprotect(protectedData, (List<Object>) (List<?>) unprotected, channel.alloc()); 209 // We try our best to remove the HTTP2 handler as soon as possible, but just by constructing it 210 // a settings frame is written (and an HTTP2 preface). This is hard coded into Netty, so we 211 // have to remove it here. See {@code Http2ConnectionHandler.PrefaceDecode.sendPreface}. 212 int settingsFrameLength = 9; 213 214 CompositeByteBuf unprotectedAll = 215 new CompositeByteBuf(channel.alloc(), false, unprotected.size() + 1, unprotected); 216 ByteBuf unprotectedData = unprotectedAll.slice(settingsFrameLength, message.length()); 217 assertEquals(message, unprotectedData.toString(UTF_8)); 218 219 // Protect the same message at the server. 220 final AtomicReference<ByteBuf> newlyProtectedData = new AtomicReference<>(); 221 serverProtector.protectFlush( 222 Collections.singletonList(unprotectedData), 223 new Consumer<ByteBuf>() { 224 @Override 225 public void accept(ByteBuf buf) { 226 newlyProtectedData.set(buf); 227 } 228 }, 229 channel.alloc()); 230 231 // Read the protected message at the client and verify that it matches the original message. 232 channel.writeInbound(newlyProtectedData.get()); 233 assertEquals(1, channel.inboundMessages().size()); 234 assertEquals(message, channel.<ByteBuf>readInbound().toString(UTF_8)); 235 } 236 237 @Test 238 public void unprotectLargeIncomingFrame() throws Exception { 239 240 // We use a server frameprotector with twice the standard frame size. 241 int serverFrameSize = 4096 * 2; 242 // This should fit into one frame. 243 byte[] unprotectedBytes = new byte[serverFrameSize - 500]; 244 Arrays.fill(unprotectedBytes, (byte) 7); 245 ByteBuf unprotectedData = Unpooled.wrappedBuffer(unprotectedBytes); 246 unprotectedData.writerIndex(unprotectedBytes.length); 247 248 // Perform handshake. 249 doHandshake(); 250 251 // Protect the message on the server. 252 TsiFrameProtector serverProtector = 253 serverHandshaker.createFrameProtector(serverFrameSize, channel.alloc()); 254 serverProtector.protectFlush( 255 Collections.singletonList(unprotectedData), 256 new Consumer<ByteBuf>() { 257 @Override 258 public void accept(ByteBuf buf) { 259 channel.writeInbound(buf); 260 } 261 }, 262 channel.alloc()); 263 channel.flushInbound(); 264 265 // Read the protected message at the client and verify that it matches the original message. 266 assertEquals(1, channel.inboundMessages().size()); 267 268 ByteBuf receivedData1 = channel.<ByteBuf>readInbound(); 269 int receivedLen1 = receivedData1.readableBytes(); 270 byte[] receivedBytes = new byte[receivedLen1]; 271 receivedData1.readBytes(receivedBytes, 0, receivedLen1); 272 273 assertThat(unprotectedBytes.length).isEqualTo(receivedBytes.length); 274 assertThat(unprotectedBytes).isEqualTo(receivedBytes); 275 } 276 277 @Test 278 public void flushShouldFailAllPromises() throws Exception { 279 doHandshake(); 280 281 channel 282 .pipeline() 283 .addFirst( 284 new ChannelDuplexHandler() { 285 @Override 286 public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) 287 throws Exception { 288 throw new Exception("Fake exception"); 289 } 290 }); 291 292 // Write the message 1 character at a time. 293 String message = "hello"; 294 final AtomicInteger failures = new AtomicInteger(); 295 for (int ix = 0; ix < message.length(); ++ix) { 296 ByteBuf in = Unpooled.copiedBuffer(message, ix, 1, UTF_8); 297 @SuppressWarnings("unused") // go/futurereturn-lsc 298 Future<?> possiblyIgnoredError = 299 channel 300 .write(in) 301 .addListener( 302 new ChannelFutureListener() { 303 @Override 304 public void operationComplete(ChannelFuture future) throws Exception { 305 if (!future.isSuccess()) { 306 failures.incrementAndGet(); 307 } 308 } 309 }); 310 } 311 channel.flush(); 312 313 // Verify that the promises fail. 314 assertEquals(message.length(), failures.get()); 315 } 316 317 @Test 318 public void doNotFlushEmptyBuffer() throws Exception { 319 doHandshake(); 320 assertEquals(1, protectors.size()); 321 InterceptingProtector protector = protectors.poll(); 322 323 String message = "hello"; 324 ByteBuf in = Unpooled.copiedBuffer(message, UTF_8); 325 326 assertEquals(0, protector.flushes.get()); 327 Future<?> done = channel.write(in); 328 channel.flush(); 329 done.get(5, TimeUnit.SECONDS); 330 assertEquals(1, protector.flushes.get()); 331 332 done = channel.write(Unpooled.EMPTY_BUFFER); 333 channel.flush(); 334 done.get(5, TimeUnit.SECONDS); 335 assertEquals(1, protector.flushes.get()); 336 } 337 338 @Test 339 public void peerPropagated() throws Exception { 340 doHandshake(); 341 342 assertThat(grpcHandler.attrs.get(AltsProtocolNegotiator.getTsiPeerAttributeKey())) 343 .isEqualTo(mockedTsiPeer); 344 assertThat(grpcHandler.attrs.get(AltsProtocolNegotiator.getAltsAuthContextAttributeKey())) 345 .isEqualTo(mockedAltsContext); 346 assertThat(grpcHandler.attrs.get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR).toString()) 347 .isEqualTo("embedded"); 348 assertThat(grpcHandler.attrs.get(CallCredentials.ATTR_SECURITY_LEVEL)) 349 .isEqualTo(SecurityLevel.PRIVACY_AND_INTEGRITY); 350 } 351 352 private void doHandshake() throws Exception { 353 // Capture the client frame and add to the server. 354 assertEquals(1, channel.outboundMessages().size()); 355 ByteBuf clientFrame = channel.<ByteBuf>readOutbound(); 356 assertTrue(serverHandshaker.processBytesFromPeer(clientFrame)); 357 358 // Get the server response handshake frames. 359 ByteBuf serverFrame = channel.alloc().buffer(); 360 serverHandshaker.getBytesToSendToPeer(serverFrame); 361 channel.writeInbound(serverFrame); 362 363 // Capture the next client frame and add to the server. 364 assertEquals(1, channel.outboundMessages().size()); 365 clientFrame = channel.<ByteBuf>readOutbound(); 366 assertTrue(serverHandshaker.processBytesFromPeer(clientFrame)); 367 368 // Get the server response handshake frames. 369 serverFrame = channel.alloc().buffer(); 370 serverHandshaker.getBytesToSendToPeer(serverFrame); 371 channel.writeInbound(serverFrame); 372 373 // Ensure that both sides have confirmed that the handshake has completed. 374 assertFalse(serverHandshaker.isInProgress()); 375 376 if (caughtException != null) { 377 throw new RuntimeException(caughtException); 378 } 379 assertNotNull(tsiEvent); 380 } 381 382 private CapturingGrpcHttp2ConnectionHandler capturingGrpcHandler() { 383 // Netty Boilerplate. We don't really need any of this, but there is a tight coupling 384 // between a Http2ConnectionHandler and its dependencies. 385 Http2Connection connection = new DefaultHttp2Connection(true); 386 Http2FrameWriter frameWriter = new DefaultHttp2FrameWriter(); 387 Http2FrameReader frameReader = new DefaultHttp2FrameReader(false); 388 DefaultHttp2ConnectionEncoder encoder = 389 new DefaultHttp2ConnectionEncoder(connection, frameWriter); 390 DefaultHttp2ConnectionDecoder decoder = 391 new DefaultHttp2ConnectionDecoder(connection, encoder, frameReader); 392 393 return new CapturingGrpcHttp2ConnectionHandler(decoder, encoder, new Http2Settings()); 394 } 395 396 private final class CapturingGrpcHttp2ConnectionHandler extends GrpcHttp2ConnectionHandler { 397 private Attributes attrs; 398 399 private CapturingGrpcHttp2ConnectionHandler( 400 Http2ConnectionDecoder decoder, 401 Http2ConnectionEncoder encoder, 402 Http2Settings initialSettings) { 403 super(null, decoder, encoder, initialSettings); 404 } 405 406 @Override 407 public void handleProtocolNegotiationCompleted( 408 Attributes attrs, InternalChannelz.Security securityInfo) { 409 // If we are added to the pipeline, we need to remove ourselves. The HTTP2 handler 410 channel.pipeline().remove(this); 411 this.attrs = attrs; 412 } 413 } 414 415 private static class DelegatingTsiHandshakerFactory implements TsiHandshakerFactory { 416 417 private TsiHandshakerFactory delegate; 418 419 DelegatingTsiHandshakerFactory(TsiHandshakerFactory delegate) { 420 this.delegate = delegate; 421 } 422 423 @Override 424 public TsiHandshaker newHandshaker() { 425 return delegate.newHandshaker(); 426 } 427 } 428 429 private class DelegatingTsiHandshaker implements TsiHandshaker { 430 431 private final TsiHandshaker delegate; 432 433 DelegatingTsiHandshaker(TsiHandshaker delegate) { 434 this.delegate = delegate; 435 } 436 437 @Override 438 public void getBytesToSendToPeer(ByteBuffer bytes) throws GeneralSecurityException { 439 delegate.getBytesToSendToPeer(bytes); 440 } 441 442 @Override 443 public boolean processBytesFromPeer(ByteBuffer bytes) throws GeneralSecurityException { 444 return delegate.processBytesFromPeer(bytes); 445 } 446 447 @Override 448 public boolean isInProgress() { 449 return delegate.isInProgress(); 450 } 451 452 @Override 453 public TsiPeer extractPeer() throws GeneralSecurityException { 454 return delegate.extractPeer(); 455 } 456 457 @Override 458 public Object extractPeerObject() throws GeneralSecurityException { 459 return delegate.extractPeerObject(); 460 } 461 462 @Override 463 public TsiFrameProtector createFrameProtector(ByteBufAllocator alloc) { 464 InterceptingProtector protector = 465 new InterceptingProtector(delegate.createFrameProtector(alloc)); 466 protectors.add(protector); 467 return protector; 468 } 469 470 @Override 471 public TsiFrameProtector createFrameProtector(int maxFrameSize, ByteBufAllocator alloc) { 472 InterceptingProtector protector = 473 new InterceptingProtector(delegate.createFrameProtector(maxFrameSize, alloc)); 474 protectors.add(protector); 475 return protector; 476 } 477 } 478 479 private static class InterceptingProtector implements TsiFrameProtector { 480 private final TsiFrameProtector delegate; 481 final AtomicInteger flushes = new AtomicInteger(); 482 483 InterceptingProtector(TsiFrameProtector delegate) { 484 this.delegate = delegate; 485 } 486 487 @Override 488 public void protectFlush( 489 List<ByteBuf> unprotectedBufs, Consumer<ByteBuf> ctxWrite, ByteBufAllocator alloc) 490 throws GeneralSecurityException { 491 flushes.incrementAndGet(); 492 delegate.protectFlush(unprotectedBufs, ctxWrite, alloc); 493 } 494 495 @Override 496 public void unprotect(ByteBuf in, List<Object> out, ByteBufAllocator alloc) 497 throws GeneralSecurityException { 498 delegate.unprotect(in, out, alloc); 499 } 500 501 @Override 502 public void destroy() { 503 delegate.destroy(); 504 } 505 } 506 } 507