Home | History | Annotate | Download | only in internal
      1 /*
      2  * Copyright 2018 The gRPC Authors
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *     http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 package io.grpc.alts.internal;
     18 
     19 import static com.google.common.truth.Truth.assertThat;
     20 import static java.nio.charset.StandardCharsets.UTF_8;
     21 import static org.junit.Assert.assertEquals;
     22 import static org.junit.Assert.assertFalse;
     23 import static org.junit.Assert.assertNotNull;
     24 import static org.junit.Assert.assertTrue;
     25 
     26 import io.grpc.Attributes;
     27 import io.grpc.CallCredentials;
     28 import io.grpc.Grpc;
     29 import io.grpc.InternalChannelz;
     30 import io.grpc.SecurityLevel;
     31 import io.grpc.alts.internal.Handshaker.HandshakerResult;
     32 import io.grpc.alts.internal.TsiFrameProtector.Consumer;
     33 import io.grpc.alts.internal.TsiPeer.Property;
     34 import io.grpc.netty.GrpcHttp2ConnectionHandler;
     35 import io.netty.buffer.ByteBuf;
     36 import io.netty.buffer.ByteBufAllocator;
     37 import io.netty.buffer.CompositeByteBuf;
     38 import io.netty.buffer.Unpooled;
     39 import io.netty.channel.ChannelDuplexHandler;
     40 import io.netty.channel.ChannelFuture;
     41 import io.netty.channel.ChannelFutureListener;
     42 import io.netty.channel.ChannelHandler;
     43 import io.netty.channel.ChannelHandlerContext;
     44 import io.netty.channel.ChannelPromise;
     45 import io.netty.channel.embedded.EmbeddedChannel;
     46 import io.netty.handler.codec.http2.DefaultHttp2Connection;
     47 import io.netty.handler.codec.http2.DefaultHttp2ConnectionDecoder;
     48 import io.netty.handler.codec.http2.DefaultHttp2ConnectionEncoder;
     49 import io.netty.handler.codec.http2.DefaultHttp2FrameReader;
     50 import io.netty.handler.codec.http2.DefaultHttp2FrameWriter;
     51 import io.netty.handler.codec.http2.Http2Connection;
     52 import io.netty.handler.codec.http2.Http2ConnectionDecoder;
     53 import io.netty.handler.codec.http2.Http2ConnectionEncoder;
     54 import io.netty.handler.codec.http2.Http2FrameReader;
     55 import io.netty.handler.codec.http2.Http2FrameWriter;
     56 import io.netty.handler.codec.http2.Http2Settings;
     57 import io.netty.util.ReferenceCountUtil;
     58 import io.netty.util.ReferenceCounted;
     59 import java.nio.ByteBuffer;
     60 import java.security.GeneralSecurityException;
     61 import java.util.ArrayList;
     62 import java.util.Arrays;
     63 import java.util.Collections;
     64 import java.util.List;
     65 import java.util.concurrent.Future;
     66 import java.util.concurrent.LinkedBlockingQueue;
     67 import java.util.concurrent.TimeUnit;
     68 import java.util.concurrent.atomic.AtomicInteger;
     69 import java.util.concurrent.atomic.AtomicReference;
     70 import org.junit.After;
     71 import org.junit.Before;
     72 import org.junit.Test;
     73 import org.junit.runner.RunWith;
     74 import org.junit.runners.JUnit4;
     75 
     76 /** Tests for {@link AltsProtocolNegotiator}. */
     77 @RunWith(JUnit4.class)
     78 public class AltsProtocolNegotiatorTest {
     79   private final CapturingGrpcHttp2ConnectionHandler grpcHandler = capturingGrpcHandler();
     80 
     81   private final List<ReferenceCounted> references = new ArrayList<>();
     82   private final LinkedBlockingQueue<InterceptingProtector> protectors = new LinkedBlockingQueue<>();
     83 
     84   private EmbeddedChannel channel;
     85   private Throwable caughtException;
     86 
     87   private volatile TsiHandshakeHandler.TsiHandshakeCompletionEvent tsiEvent;
     88   private ChannelHandler handler;
     89 
     90   private TsiPeer mockedTsiPeer = new TsiPeer(Collections.<Property<?>>emptyList());
     91   private AltsAuthContext mockedAltsContext =
     92       new AltsAuthContext(
     93           HandshakerResult.newBuilder()
     94               .setPeerRpcVersions(RpcProtocolVersionsUtil.getRpcProtocolVersions())
     95               .build());
     96   private final TsiHandshaker mockHandshaker =
     97       new DelegatingTsiHandshaker(FakeTsiHandshaker.newFakeHandshakerServer()) {
     98         @Override
     99         public TsiPeer extractPeer() throws GeneralSecurityException {
    100           return mockedTsiPeer;
    101         }
    102 
    103         @Override
    104         public Object extractPeerObject() throws GeneralSecurityException {
    105           return mockedAltsContext;
    106         }
    107       };
    108   private final NettyTsiHandshaker serverHandshaker = new NettyTsiHandshaker(mockHandshaker);
    109 
    110   @Before
    111   public void setup() throws Exception {
    112     ChannelHandler userEventHandler =
    113         new ChannelDuplexHandler() {
    114           @Override
    115           public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
    116             if (evt instanceof TsiHandshakeHandler.TsiHandshakeCompletionEvent) {
    117               tsiEvent = (TsiHandshakeHandler.TsiHandshakeCompletionEvent) evt;
    118             } else {
    119               super.userEventTriggered(ctx, evt);
    120             }
    121           }
    122         };
    123 
    124     ChannelHandler uncaughtExceptionHandler =
    125         new ChannelDuplexHandler() {
    126           @Override
    127           public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
    128             caughtException = cause;
    129             super.exceptionCaught(ctx, cause);
    130           }
    131         };
    132 
    133     TsiHandshakerFactory handshakerFactory =
    134         new DelegatingTsiHandshakerFactory(FakeTsiHandshaker.clientHandshakerFactory()) {
    135           @Override
    136           public TsiHandshaker newHandshaker() {
    137             return new DelegatingTsiHandshaker(super.newHandshaker()) {
    138               @Override
    139               public TsiPeer extractPeer() throws GeneralSecurityException {
    140                 return mockedTsiPeer;
    141               }
    142 
    143               @Override
    144               public Object extractPeerObject() throws GeneralSecurityException {
    145                 return mockedAltsContext;
    146               }
    147             };
    148           }
    149         };
    150     handler = AltsProtocolNegotiator.create(handshakerFactory).newHandler(grpcHandler);
    151     channel = new EmbeddedChannel(uncaughtExceptionHandler, handler, userEventHandler);
    152   }
    153 
    154   @After
    155   public void teardown() throws Exception {
    156     if (channel != null) {
    157       @SuppressWarnings("unused") // go/futurereturn-lsc
    158       Future<?> possiblyIgnoredError = channel.close();
    159     }
    160 
    161     for (ReferenceCounted reference : references) {
    162       ReferenceCountUtil.safeRelease(reference);
    163     }
    164   }
    165 
  /**
   * Runs the full fake handshake via {@link #doHandshake()}; the assertions inside that helper are
   * the whole check, so this test passes as long as the helper completes without failing.
   */
  @Test
  public void handshakeShouldBeSuccessful() throws Exception {
    doHandshake();
  }
    170 
    171   @Test
    172   @SuppressWarnings("unchecked") // List cast
    173   public void protectShouldRoundtrip() throws Exception {
    174     // Write the message 1 character at a time. The message should be buffered
    175     // and not interfere with the handshake.
    176     final AtomicInteger writeCount = new AtomicInteger();
    177     String message = "hello";
    178     for (int ix = 0; ix < message.length(); ++ix) {
    179       ByteBuf in = Unpooled.copiedBuffer(message, ix, 1, UTF_8);
    180       @SuppressWarnings("unused") // go/futurereturn-lsc
    181       Future<?> possiblyIgnoredError =
    182           channel
    183               .write(in)
    184               .addListener(
    185                   new ChannelFutureListener() {
    186                     @Override
    187                     public void operationComplete(ChannelFuture future) throws Exception {
    188                       if (future.isSuccess()) {
    189                         writeCount.incrementAndGet();
    190                       }
    191                     }
    192                   });
    193     }
    194     channel.flush();
    195 
    196     // Now do the handshake. The buffered message will automatically be protected
    197     // and sent.
    198     doHandshake();
    199 
    200     // Capture the protected data written to the wire.
    201     assertEquals(1, channel.outboundMessages().size());
    202     ByteBuf protectedData = channel.<ByteBuf>readOutbound();
    203     assertEquals(message.length(), writeCount.get());
    204 
    205     // Read the protected message at the server and verify it matches the original message.
    206     TsiFrameProtector serverProtector = serverHandshaker.createFrameProtector(channel.alloc());
    207     List<ByteBuf> unprotected = new ArrayList<>();
    208     serverProtector.unprotect(protectedData, (List<Object>) (List<?>) unprotected, channel.alloc());
    209     // We try our best to remove the HTTP2 handler as soon as possible, but just by constructing it
    210     // a settings frame is written (and an HTTP2 preface).  This is hard coded into Netty, so we
    211     // have to remove it here.  See {@code Http2ConnectionHandler.PrefaceDecode.sendPreface}.
    212     int settingsFrameLength = 9;
    213 
    214     CompositeByteBuf unprotectedAll =
    215         new CompositeByteBuf(channel.alloc(), false, unprotected.size() + 1, unprotected);
    216     ByteBuf unprotectedData = unprotectedAll.slice(settingsFrameLength, message.length());
    217     assertEquals(message, unprotectedData.toString(UTF_8));
    218 
    219     // Protect the same message at the server.
    220     final AtomicReference<ByteBuf> newlyProtectedData = new AtomicReference<>();
    221     serverProtector.protectFlush(
    222         Collections.singletonList(unprotectedData),
    223         new Consumer<ByteBuf>() {
    224           @Override
    225           public void accept(ByteBuf buf) {
    226             newlyProtectedData.set(buf);
    227           }
    228         },
    229         channel.alloc());
    230 
    231     // Read the protected message at the client and verify that it matches the original message.
    232     channel.writeInbound(newlyProtectedData.get());
    233     assertEquals(1, channel.inboundMessages().size());
    234     assertEquals(message, channel.<ByteBuf>readInbound().toString(UTF_8));
    235   }
    236 
    237   @Test
    238   public void unprotectLargeIncomingFrame() throws Exception {
    239 
    240     // We use a server frameprotector with twice the standard frame size.
    241     int serverFrameSize = 4096 * 2;
    242     // This should fit into one frame.
    243     byte[] unprotectedBytes = new byte[serverFrameSize - 500];
    244     Arrays.fill(unprotectedBytes, (byte) 7);
    245     ByteBuf unprotectedData = Unpooled.wrappedBuffer(unprotectedBytes);
    246     unprotectedData.writerIndex(unprotectedBytes.length);
    247 
    248     // Perform handshake.
    249     doHandshake();
    250 
    251     // Protect the message on the server.
    252     TsiFrameProtector serverProtector =
    253         serverHandshaker.createFrameProtector(serverFrameSize, channel.alloc());
    254     serverProtector.protectFlush(
    255         Collections.singletonList(unprotectedData),
    256         new Consumer<ByteBuf>() {
    257           @Override
    258           public void accept(ByteBuf buf) {
    259             channel.writeInbound(buf);
    260           }
    261         },
    262         channel.alloc());
    263     channel.flushInbound();
    264 
    265     // Read the protected message at the client and verify that it matches the original message.
    266     assertEquals(1, channel.inboundMessages().size());
    267 
    268     ByteBuf receivedData1 = channel.<ByteBuf>readInbound();
    269     int receivedLen1 = receivedData1.readableBytes();
    270     byte[] receivedBytes = new byte[receivedLen1];
    271     receivedData1.readBytes(receivedBytes, 0, receivedLen1);
    272 
    273     assertThat(unprotectedBytes.length).isEqualTo(receivedBytes.length);
    274     assertThat(unprotectedBytes).isEqualTo(receivedBytes);
    275   }
    276 
    277   @Test
    278   public void flushShouldFailAllPromises() throws Exception {
    279     doHandshake();
    280 
    281     channel
    282         .pipeline()
    283         .addFirst(
    284             new ChannelDuplexHandler() {
    285               @Override
    286               public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
    287                   throws Exception {
    288                 throw new Exception("Fake exception");
    289               }
    290             });
    291 
    292     // Write the message 1 character at a time.
    293     String message = "hello";
    294     final AtomicInteger failures = new AtomicInteger();
    295     for (int ix = 0; ix < message.length(); ++ix) {
    296       ByteBuf in = Unpooled.copiedBuffer(message, ix, 1, UTF_8);
    297       @SuppressWarnings("unused") // go/futurereturn-lsc
    298       Future<?> possiblyIgnoredError =
    299           channel
    300               .write(in)
    301               .addListener(
    302                   new ChannelFutureListener() {
    303                     @Override
    304                     public void operationComplete(ChannelFuture future) throws Exception {
    305                       if (!future.isSuccess()) {
    306                         failures.incrementAndGet();
    307                       }
    308                     }
    309                   });
    310     }
    311     channel.flush();
    312 
    313     // Verify that the promises fail.
    314     assertEquals(message.length(), failures.get());
    315   }
    316 
    317   @Test
    318   public void doNotFlushEmptyBuffer() throws Exception {
    319     doHandshake();
    320     assertEquals(1, protectors.size());
    321     InterceptingProtector protector = protectors.poll();
    322 
    323     String message = "hello";
    324     ByteBuf in = Unpooled.copiedBuffer(message, UTF_8);
    325 
    326     assertEquals(0, protector.flushes.get());
    327     Future<?> done = channel.write(in);
    328     channel.flush();
    329     done.get(5, TimeUnit.SECONDS);
    330     assertEquals(1, protector.flushes.get());
    331 
    332     done = channel.write(Unpooled.EMPTY_BUFFER);
    333     channel.flush();
    334     done.get(5, TimeUnit.SECONDS);
    335     assertEquals(1, protector.flushes.get());
    336   }
    337 
    338   @Test
    339   public void peerPropagated() throws Exception {
    340     doHandshake();
    341 
    342     assertThat(grpcHandler.attrs.get(AltsProtocolNegotiator.getTsiPeerAttributeKey()))
    343         .isEqualTo(mockedTsiPeer);
    344     assertThat(grpcHandler.attrs.get(AltsProtocolNegotiator.getAltsAuthContextAttributeKey()))
    345         .isEqualTo(mockedAltsContext);
    346     assertThat(grpcHandler.attrs.get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR).toString())
    347         .isEqualTo("embedded");
    348     assertThat(grpcHandler.attrs.get(CallCredentials.ATTR_SECURITY_LEVEL))
    349         .isEqualTo(SecurityLevel.PRIVACY_AND_INTEGRITY);
    350   }
    351 
    352   private void doHandshake() throws Exception {
    353     // Capture the client frame and add to the server.
    354     assertEquals(1, channel.outboundMessages().size());
    355     ByteBuf clientFrame = channel.<ByteBuf>readOutbound();
    356     assertTrue(serverHandshaker.processBytesFromPeer(clientFrame));
    357 
    358     // Get the server response handshake frames.
    359     ByteBuf serverFrame = channel.alloc().buffer();
    360     serverHandshaker.getBytesToSendToPeer(serverFrame);
    361     channel.writeInbound(serverFrame);
    362 
    363     // Capture the next client frame and add to the server.
    364     assertEquals(1, channel.outboundMessages().size());
    365     clientFrame = channel.<ByteBuf>readOutbound();
    366     assertTrue(serverHandshaker.processBytesFromPeer(clientFrame));
    367 
    368     // Get the server response handshake frames.
    369     serverFrame = channel.alloc().buffer();
    370     serverHandshaker.getBytesToSendToPeer(serverFrame);
    371     channel.writeInbound(serverFrame);
    372 
    373     // Ensure that both sides have confirmed that the handshake has completed.
    374     assertFalse(serverHandshaker.isInProgress());
    375 
    376     if (caughtException != null) {
    377       throw new RuntimeException(caughtException);
    378     }
    379     assertNotNull(tsiEvent);
    380   }
    381 
    382   private CapturingGrpcHttp2ConnectionHandler capturingGrpcHandler() {
    383     // Netty Boilerplate.  We don't really need any of this, but there is a tight coupling
    384     // between a Http2ConnectionHandler and its dependencies.
    385     Http2Connection connection = new DefaultHttp2Connection(true);
    386     Http2FrameWriter frameWriter = new DefaultHttp2FrameWriter();
    387     Http2FrameReader frameReader = new DefaultHttp2FrameReader(false);
    388     DefaultHttp2ConnectionEncoder encoder =
    389         new DefaultHttp2ConnectionEncoder(connection, frameWriter);
    390     DefaultHttp2ConnectionDecoder decoder =
    391         new DefaultHttp2ConnectionDecoder(connection, encoder, frameReader);
    392 
    393     return new CapturingGrpcHttp2ConnectionHandler(decoder, encoder, new Http2Settings());
    394   }
    395 
  /**
   * A {@link GrpcHttp2ConnectionHandler} that records the attributes passed to it when protocol
   * negotiation completes, so tests can inspect them, and then removes itself from the pipeline.
   */
  private final class CapturingGrpcHttp2ConnectionHandler extends GrpcHttp2ConnectionHandler {
    /** Attributes received on negotiation completion; unset until that callback runs. */
    private Attributes attrs;

    private CapturingGrpcHttp2ConnectionHandler(
        Http2ConnectionDecoder decoder,
        Http2ConnectionEncoder encoder,
        Http2Settings initialSettings) {
      // The first superclass argument is deliberately null in this test fixture.
      super(null, decoder, encoder, initialSettings);
    }

    @Override
    public void handleProtocolNegotiationCompleted(
        Attributes attrs, InternalChannelz.Security securityInfo) {
      // If we are added to the pipeline, we need to remove ourselves.
      // NOTE(review): the original comment was cut off mid-sentence here; presumably the removal
      // keeps HTTP/2 traffic out of the frames the tests inspect (compare the settings-frame
      // workaround in protectShouldRoundtrip) - confirm.
      channel.pipeline().remove(this);
      this.attrs = attrs;
    }
  }
    414 
    415   private static class DelegatingTsiHandshakerFactory implements TsiHandshakerFactory {
    416 
    417     private TsiHandshakerFactory delegate;
    418 
    419     DelegatingTsiHandshakerFactory(TsiHandshakerFactory delegate) {
    420       this.delegate = delegate;
    421     }
    422 
    423     @Override
    424     public TsiHandshaker newHandshaker() {
    425       return delegate.newHandshaker();
    426     }
    427   }
    428 
    429   private class DelegatingTsiHandshaker implements TsiHandshaker {
    430 
    431     private final TsiHandshaker delegate;
    432 
    433     DelegatingTsiHandshaker(TsiHandshaker delegate) {
    434       this.delegate = delegate;
    435     }
    436 
    437     @Override
    438     public void getBytesToSendToPeer(ByteBuffer bytes) throws GeneralSecurityException {
    439       delegate.getBytesToSendToPeer(bytes);
    440     }
    441 
    442     @Override
    443     public boolean processBytesFromPeer(ByteBuffer bytes) throws GeneralSecurityException {
    444       return delegate.processBytesFromPeer(bytes);
    445     }
    446 
    447     @Override
    448     public boolean isInProgress() {
    449       return delegate.isInProgress();
    450     }
    451 
    452     @Override
    453     public TsiPeer extractPeer() throws GeneralSecurityException {
    454       return delegate.extractPeer();
    455     }
    456 
    457     @Override
    458     public Object extractPeerObject() throws GeneralSecurityException {
    459       return delegate.extractPeerObject();
    460     }
    461 
    462     @Override
    463     public TsiFrameProtector createFrameProtector(ByteBufAllocator alloc) {
    464       InterceptingProtector protector =
    465           new InterceptingProtector(delegate.createFrameProtector(alloc));
    466       protectors.add(protector);
    467       return protector;
    468     }
    469 
    470     @Override
    471     public TsiFrameProtector createFrameProtector(int maxFrameSize, ByteBufAllocator alloc) {
    472       InterceptingProtector protector =
    473           new InterceptingProtector(delegate.createFrameProtector(maxFrameSize, alloc));
    474       protectors.add(protector);
    475       return protector;
    476     }
    477   }
    478 
  /**
   * A {@link TsiFrameProtector} that forwards every call to another protector and counts how many
   * times {@code protectFlush} has been invoked.
   */
  private static class InterceptingProtector implements TsiFrameProtector {
    private final TsiFrameProtector delegate;
    /** Number of {@code protectFlush} calls made on this protector so far. */
    final AtomicInteger flushes = new AtomicInteger();

    InterceptingProtector(TsiFrameProtector delegate) {
      this.delegate = delegate;
    }

    @Override
    public void protectFlush(
        List<ByteBuf> unprotectedBufs, Consumer<ByteBuf> ctxWrite, ByteBufAllocator alloc)
        throws GeneralSecurityException {
      // Count the call before delegating, so it is recorded even if the delegate throws.
      flushes.incrementAndGet();
      delegate.protectFlush(unprotectedBufs, ctxWrite, alloc);
    }

    @Override
    public void unprotect(ByteBuf in, List<Object> out, ByteBufAllocator alloc)
        throws GeneralSecurityException {
      delegate.unprotect(in, out, alloc);
    }

    @Override
    public void destroy() {
      delegate.destroy();
    }
  }
    506 }
    507