1 /* 2 * Copyright 2014 The gRPC Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package io.grpc.netty; 18 19 import static com.google.common.base.Charsets.US_ASCII; 20 import static io.grpc.netty.NettyTestUtil.messageFrame; 21 import static org.junit.Assert.assertEquals; 22 import static org.junit.Assert.assertFalse; 23 import static org.junit.Assert.assertNull; 24 import static org.junit.Assert.assertTrue; 25 import static org.mockito.Matchers.any; 26 import static org.mockito.Matchers.anyBoolean; 27 import static org.mockito.Mockito.doAnswer; 28 import static org.mockito.Mockito.never; 29 import static org.mockito.Mockito.verify; 30 import static org.mockito.Mockito.when; 31 32 import io.grpc.internal.Stream; 33 import io.grpc.internal.StreamListener; 34 import io.grpc.netty.WriteQueue.QueuedCommand; 35 import io.netty.buffer.UnpooledByteBufAllocator; 36 import io.netty.channel.Channel; 37 import io.netty.channel.ChannelHandlerContext; 38 import io.netty.channel.ChannelPipeline; 39 import io.netty.channel.ChannelPromise; 40 import io.netty.channel.DefaultChannelPromise; 41 import io.netty.channel.EventLoop; 42 import io.netty.handler.codec.http2.Http2Stream; 43 import java.io.ByteArrayInputStream; 44 import java.io.IOException; 45 import java.io.InputStream; 46 import java.util.Queue; 47 import org.junit.Before; 48 import org.junit.Test; 49 import org.mockito.Mock; 50 import org.mockito.MockitoAnnotations; 51 import 
org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;

/**
 * Base class for Netty stream unit tests.
 *
 * <p>Subclasses provide the stream under test through {@link #createStream()}, plus access to its
 * listener and inbound message queue. {@link #setUp()} stubs the mocked Netty collaborators so
 * that tasks submitted to the event loop run inline and every write reports success immediately.
 * Individual tests may re-stub the write queue to simulate writes that never complete.
 */
public abstract class NettyStreamTestBase<T extends Stream> {
  protected static final String MESSAGE = "hello world";
  protected static final int STREAM_ID = 1;

  @Mock
  protected Channel channel;

  @Mock
  private ChannelHandlerContext ctx;

  @Mock
  private ChannelPipeline pipeline;

  @Mock
  protected EventLoop eventLoop;

  // ChannelPromise has too many methods to implement; we stubbed all necessary methods of Future.
  @SuppressWarnings("DoNotMock")
  @Mock
  protected ChannelPromise promise;

  @Mock
  protected Http2Stream http2Stream;

  @Mock
  protected WriteQueue writeQueue;

  protected T stream;

  /**
   * Set up for test: initializes the mocks and stubs the channel, pipeline, event loop and write
   * queue, then creates the stream under test via {@link #createStream()}.
   */
  @Before
  public void setUp() {
    MockitoAnnotations.initMocks(this);

    when(channel.alloc()).thenReturn(UnpooledByteBufAllocator.DEFAULT);
    when(channel.pipeline()).thenReturn(pipeline);
    when(channel.eventLoop()).thenReturn(eventLoop);
    when(channel.newPromise()).thenReturn(new DefaultChannelPromise(channel));
    when(channel.voidPromise()).thenReturn(new DefaultChannelPromise(channel));
    // Writes report success immediately unless a test re-stubs writeQueue.enqueue.
    ChannelPromise completedPromise = new DefaultChannelPromise(channel)
        .setSuccess();
    when(channel.write(any())).thenReturn(completedPromise);
    when(channel.writeAndFlush(any())).thenReturn(completedPromise);
    when(writeQueue.enqueue(any(QueuedCommand.class), anyBoolean())).thenReturn(completedPromise);
    when(pipeline.firstContext()).thenReturn(ctx);
    when(eventLoop.inEventLoop()).thenReturn(true);
    when(http2Stream.id()).thenReturn(STREAM_ID);

    // Run anything submitted to the event loop synchronously on the test thread.
    doAnswer(new Answer<Void>() {
      @Override
      public Void answer(InvocationOnMock invocation) throws Throwable {
        Runnable runnable = (Runnable) invocation.getArguments()[0];
        runnable.run();
        return null;
      }
    }).when(eventLoop).execute(any(Runnable.class));

    stream = createStream();
  }

  @Test
  public void inboundMessageShouldCallListener() throws Exception {
    stream.request(1);

    if (stream instanceof NettyServerStream) {
      ((NettyServerStream) stream).transportState()
          .inboundDataReceived(messageFrame(MESSAGE), false);
    } else {
      ((NettyClientStream) stream).transportState()
          .transportDataReceived(messageFrame(MESSAGE), false);
    }

    // Exactly one message, carrying MESSAGE, should have been handed to the listener.
    InputStream message = listenerMessageQueue().poll();
    assertTrue("expected an inbound message to be delivered", message != null);
    assertEquals(MESSAGE, NettyTestUtil.toString(message));
    assertNull("no additional message expected", listenerMessageQueue().poll());
  }

  @Test
  public void shouldBeImmediatelyReadyForData() {
    assertTrue(stream.isReady());
  }

  @Test
  public void closedShouldNotBeReady() throws IOException {
    assertTrue(stream.isReady());
    closeStream();
    assertFalse(stream.isReady());
  }

  @Test
  public void notifiedOnReadyAfterWriteCompletes() throws IOException {
    sendHeadersIfServer();
    assertTrue(stream.isReady());
    byte[] msg = largeMessage();
    // setUp() stubs writeQueue.enqueue (and channel.write) to return an already-successful
    // promise, so the write is reported as done as soon as it is issued.
    stream.writeMessage(new ByteArrayInputStream(msg));
    stream.flush();
    assertTrue(stream.isReady());
    verify(listener()).onReady();
  }

  @Test
  public void shouldBeReadyForDataAfterWritingSmallMessage() throws IOException {
    sendHeadersIfServer();
    stubWritesNeverComplete();

    assertTrue(stream.isReady());
    byte[] msg = smallMessage();
    stream.writeMessage(new ByteArrayInputStream(msg));
    stream.flush();
    assertTrue(stream.isReady());
    verify(listener(), never()).onReady();
  }

  @Test
  public void shouldNotBeReadyForDataAfterWritingLargeMessage() throws IOException {
    sendHeadersIfServer();
    stubWritesNeverComplete();

    assertTrue(stream.isReady());
    byte[] msg = largeMessage();
    stream.writeMessage(new ByteArrayInputStream(msg));
    stream.flush();
    assertFalse(stream.isReady());
    verify(listener(), never()).onReady();
  }

  /**
   * Re-stubs the write queue so every enqueued command returns a promise that is never completed,
   * making outbound writes "back up" for the readiness tests.
   */
  private void stubWritesNeverComplete() {
    ChannelPromise uncompletedPromise = new DefaultChannelPromise(channel);
    when(writeQueue.enqueue(any(QueuedCommand.class), anyBoolean())).thenReturn(uncompletedPromise);
  }

  /** Returns {@link #MESSAGE} encoded as US-ASCII bytes. */
  protected byte[] smallMessage() {
    return MESSAGE.getBytes(US_ASCII);
  }

  /**
   * Returns a payload of {@code smallMessage().length * 10 * 1024} bytes made by repeating
   * {@link #smallMessage()}. The readiness tests expect a payload this size to make the stream
   * report not-ready while writes are pending.
   */
  protected byte[] largeMessage() {
    byte[] chunk = smallMessage();
    int size = chunk.length * 10 * 1024;
    byte[] largeMessage = new byte[size];
    // size is an exact multiple of chunk.length, so every copy fits.
    for (int offset = 0; offset < size; offset += chunk.length) {
      System.arraycopy(chunk, 0, largeMessage, offset, chunk.length);
    }
    return largeMessage;
  }

  /**
   * Creates the stream under test. Called at the end of {@link #setUp()}, after all mocks have
   * been stubbed.
   */
  protected abstract T createStream();

  /** Sends response headers when the stream under test is a server stream; called before writes. */
  protected abstract void sendHeadersIfServer();

  /** Returns the listener attached to the stream, used to verify {@code onReady()} callbacks. */
  protected abstract StreamListener listener();

  /** Returns the queue into which inbound messages delivered to the listener are collected. */
  protected abstract Queue<InputStream> listenerMessageQueue();

  /** Closes the stream under test; afterwards it is expected to report not-ready. */
  protected abstract void closeStream();
}