Home | History | Annotate | Download | only in netty
/*
 * Copyright 2014 The gRPC Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
     16 
     17 package io.grpc.netty;
     18 
     19 import static com.google.common.base.Charsets.US_ASCII;
     20 import static io.grpc.netty.NettyTestUtil.messageFrame;
     21 import static org.junit.Assert.assertEquals;
     22 import static org.junit.Assert.assertFalse;
     23 import static org.junit.Assert.assertNull;
     24 import static org.junit.Assert.assertTrue;
     25 import static org.mockito.Matchers.any;
     26 import static org.mockito.Matchers.anyBoolean;
     27 import static org.mockito.Mockito.doAnswer;
     28 import static org.mockito.Mockito.never;
     29 import static org.mockito.Mockito.verify;
     30 import static org.mockito.Mockito.when;
     31 
     32 import io.grpc.internal.Stream;
     33 import io.grpc.internal.StreamListener;
     34 import io.grpc.netty.WriteQueue.QueuedCommand;
     35 import io.netty.buffer.UnpooledByteBufAllocator;
     36 import io.netty.channel.Channel;
     37 import io.netty.channel.ChannelHandlerContext;
     38 import io.netty.channel.ChannelPipeline;
     39 import io.netty.channel.ChannelPromise;
     40 import io.netty.channel.DefaultChannelPromise;
     41 import io.netty.channel.EventLoop;
     42 import io.netty.handler.codec.http2.Http2Stream;
     43 import java.io.ByteArrayInputStream;
     44 import java.io.IOException;
     45 import java.io.InputStream;
     46 import java.util.Queue;
     47 import org.junit.Before;
     48 import org.junit.Test;
     49 import org.mockito.Mock;
     50 import org.mockito.MockitoAnnotations;
     51 import org.mockito.invocation.InvocationOnMock;
     52 import org.mockito.stubbing.Answer;
     53 
     54 /**
     55  * Base class for Netty stream unit tests.
     56  */
     57 public abstract class NettyStreamTestBase<T extends Stream> {
     58   protected static final String MESSAGE = "hello world";
     59   protected static final int STREAM_ID = 1;
     60 
     61   @Mock
     62   protected Channel channel;
     63 
     64   @Mock
     65   private ChannelHandlerContext ctx;
     66 
     67   @Mock
     68   private ChannelPipeline pipeline;
     69 
     70   @Mock
     71   protected EventLoop eventLoop;
     72 
     73   // ChannelPromise has too many methods to implement; we stubbed all necessary methods of Future.
     74   @SuppressWarnings("DoNotMock")
     75   @Mock
     76   protected ChannelPromise promise;
     77 
     78   @Mock
     79   protected Http2Stream http2Stream;
     80 
     81   @Mock
     82   protected WriteQueue writeQueue;
     83 
     84   protected T stream;
     85 
     86   /** Set up for test. */
     87   @Before
     88   public void setUp() {
     89     MockitoAnnotations.initMocks(this);
     90 
     91     when(channel.alloc()).thenReturn(UnpooledByteBufAllocator.DEFAULT);
     92     when(channel.pipeline()).thenReturn(pipeline);
     93     when(channel.eventLoop()).thenReturn(eventLoop);
     94     when(channel.newPromise()).thenReturn(new DefaultChannelPromise(channel));
     95     when(channel.voidPromise()).thenReturn(new DefaultChannelPromise(channel));
     96     ChannelPromise completedPromise = new DefaultChannelPromise(channel)
     97         .setSuccess();
     98     when(channel.write(any())).thenReturn(completedPromise);
     99     when(channel.writeAndFlush(any())).thenReturn(completedPromise);
    100     when(writeQueue.enqueue(any(QueuedCommand.class), anyBoolean())).thenReturn(completedPromise);
    101     when(pipeline.firstContext()).thenReturn(ctx);
    102     when(eventLoop.inEventLoop()).thenReturn(true);
    103     when(http2Stream.id()).thenReturn(STREAM_ID);
    104 
    105     doAnswer(new Answer<Void>() {
    106       @Override
    107       public Void answer(InvocationOnMock invocation) throws Throwable {
    108         Runnable runnable = (Runnable) invocation.getArguments()[0];
    109         runnable.run();
    110         return null;
    111       }
    112     }).when(eventLoop).execute(any(Runnable.class));
    113 
    114     stream = createStream();
    115   }
    116 
    117   @Test
    118   public void inboundMessageShouldCallListener() throws Exception {
    119     stream.request(1);
    120 
    121     if (stream instanceof NettyServerStream) {
    122       ((NettyServerStream) stream).transportState()
    123           .inboundDataReceived(messageFrame(MESSAGE), false);
    124     } else {
    125       ((NettyClientStream) stream).transportState()
    126           .transportDataReceived(messageFrame(MESSAGE), false);
    127     }
    128 
    129     InputStream message = listenerMessageQueue().poll();
    130 
    131     // Verify that inbound flow control window update has been disabled for the stream.
    132     assertEquals(MESSAGE, NettyTestUtil.toString(message));
    133     assertNull("no additional message expected", listenerMessageQueue().poll());
    134   }
    135 
    136   @Test
    137   public void shouldBeImmediatelyReadyForData() {
    138     assertTrue(stream.isReady());
    139   }
    140 
    141   @Test
    142   public void closedShouldNotBeReady() throws IOException {
    143     assertTrue(stream.isReady());
    144     closeStream();
    145     assertFalse(stream.isReady());
    146   }
    147 
    148   @Test
    149   public void notifiedOnReadyAfterWriteCompletes() throws IOException {
    150     sendHeadersIfServer();
    151     assertTrue(stream.isReady());
    152     byte[] msg = largeMessage();
    153     // The channel.write future is set up to automatically complete, indicating that the write is
    154     // done.
    155     stream.writeMessage(new ByteArrayInputStream(msg));
    156     stream.flush();
    157     assertTrue(stream.isReady());
    158     verify(listener()).onReady();
    159   }
    160 
    161   @Test
    162   public void shouldBeReadyForDataAfterWritingSmallMessage() throws IOException {
    163     sendHeadersIfServer();
    164     // Make sure the writes don't complete so we "back up"
    165     ChannelPromise uncompletedPromise = new DefaultChannelPromise(channel);
    166     when(writeQueue.enqueue(any(QueuedCommand.class), anyBoolean())).thenReturn(uncompletedPromise);
    167 
    168     assertTrue(stream.isReady());
    169     byte[] msg = smallMessage();
    170     stream.writeMessage(new ByteArrayInputStream(msg));
    171     stream.flush();
    172     assertTrue(stream.isReady());
    173     verify(listener(), never()).onReady();
    174   }
    175 
    176   @Test
    177   public void shouldNotBeReadyForDataAfterWritingLargeMessage() throws IOException {
    178     sendHeadersIfServer();
    179     // Make sure the writes don't complete so we "back up"
    180     ChannelPromise uncompletedPromise = new DefaultChannelPromise(channel);
    181     when(writeQueue.enqueue(any(QueuedCommand.class), anyBoolean())).thenReturn(uncompletedPromise);
    182 
    183     assertTrue(stream.isReady());
    184     byte[] msg = largeMessage();
    185     stream.writeMessage(new ByteArrayInputStream(msg));
    186     stream.flush();
    187     assertFalse(stream.isReady());
    188     verify(listener(), never()).onReady();
    189   }
    190 
    191   protected byte[] smallMessage() {
    192     return MESSAGE.getBytes(US_ASCII);
    193   }
    194 
    195   protected byte[] largeMessage() {
    196     byte[] smallMessage = smallMessage();
    197     int size = smallMessage.length * 10 * 1024;
    198     byte[] largeMessage = new byte[size];
    199     for (int ix = 0; ix < size; ix += smallMessage.length) {
    200       System.arraycopy(smallMessage, 0, largeMessage, ix, smallMessage.length);
    201     }
    202     return largeMessage;
    203   }
    204 
    205   protected abstract T createStream();
    206 
    207   protected abstract void sendHeadersIfServer();
    208 
    209   protected abstract StreamListener listener();
    210 
    211   protected abstract Queue<InputStream> listenerMessageQueue();
    212 
    213   protected abstract void closeStream();
    214 }
    215