Home | History | Annotate | Download | only in util
      1 /*
      2  * Copyright 2016 The gRPC Authors
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *     http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 package io.grpc.util;
     18 
     19 import static com.google.common.truth.Truth.assertThat;
     20 import static io.grpc.ConnectivityState.CONNECTING;
     21 import static io.grpc.ConnectivityState.IDLE;
     22 import static io.grpc.ConnectivityState.READY;
     23 import static io.grpc.ConnectivityState.SHUTDOWN;
     24 import static io.grpc.ConnectivityState.TRANSIENT_FAILURE;
     25 import static io.grpc.util.RoundRobinLoadBalancerFactory.RoundRobinLoadBalancer.STATE_INFO;
     26 import static org.junit.Assert.assertEquals;
     27 import static org.junit.Assert.assertFalse;
     28 import static org.junit.Assert.assertNotNull;
     29 import static org.junit.Assert.assertNotSame;
     30 import static org.junit.Assert.assertNull;
     31 import static org.junit.Assert.assertSame;
     32 import static org.junit.Assert.assertTrue;
     33 import static org.mockito.Matchers.any;
     34 import static org.mockito.Matchers.eq;
     35 import static org.mockito.Matchers.isA;
     36 import static org.mockito.Mockito.atLeast;
     37 import static org.mockito.Mockito.doAnswer;
     38 import static org.mockito.Mockito.doReturn;
     39 import static org.mockito.Mockito.inOrder;
     40 import static org.mockito.Mockito.mock;
     41 import static org.mockito.Mockito.never;
     42 import static org.mockito.Mockito.times;
     43 import static org.mockito.Mockito.verify;
     44 import static org.mockito.Mockito.verifyNoMoreInteractions;
     45 import static org.mockito.Mockito.when;
     46 
     47 import com.google.common.collect.Lists;
     48 import com.google.common.collect.Maps;
     49 import io.grpc.Attributes;
     50 import io.grpc.ConnectivityState;
     51 import io.grpc.ConnectivityStateInfo;
     52 import io.grpc.EquivalentAddressGroup;
     53 import io.grpc.LoadBalancer;
     54 import io.grpc.LoadBalancer.Helper;
     55 import io.grpc.LoadBalancer.PickResult;
     56 import io.grpc.LoadBalancer.PickSubchannelArgs;
     57 import io.grpc.LoadBalancer.Subchannel;
     58 import io.grpc.LoadBalancer.SubchannelPicker;
     59 import io.grpc.Metadata;
     60 import io.grpc.Metadata.Key;
     61 import io.grpc.Status;
     62 import io.grpc.internal.GrpcAttributes;
     63 import io.grpc.util.RoundRobinLoadBalancerFactory.EmptyPicker;
     64 import io.grpc.util.RoundRobinLoadBalancerFactory.ReadyPicker;
     65 import io.grpc.util.RoundRobinLoadBalancerFactory.Ref;
     66 import io.grpc.util.RoundRobinLoadBalancerFactory.RoundRobinLoadBalancer;
     67 import io.grpc.util.RoundRobinLoadBalancerFactory.RoundRobinLoadBalancer.StickinessState;
     68 import java.net.SocketAddress;
     69 import java.util.ArrayList;
     70 import java.util.Arrays;
     71 import java.util.Collections;
     72 import java.util.HashMap;
     73 import java.util.Iterator;
     74 import java.util.List;
     75 import java.util.Map;
     76 import org.junit.After;
     77 import org.junit.Before;
     78 import org.junit.Test;
     79 import org.junit.runner.RunWith;
     80 import org.junit.runners.JUnit4;
     81 import org.mockito.ArgumentCaptor;
     82 import org.mockito.Captor;
     83 import org.mockito.InOrder;
     84 import org.mockito.Mock;
     85 import org.mockito.MockitoAnnotations;
     86 import org.mockito.invocation.InvocationOnMock;
     87 import org.mockito.stubbing.Answer;
     88 
     89 /** Unit test for {@link RoundRobinLoadBalancerFactory}. */
     90 @RunWith(JUnit4.class)
     91 public class RoundRobinLoadBalancerTest {
     92   private RoundRobinLoadBalancer loadBalancer;
     93   private List<EquivalentAddressGroup> servers = Lists.newArrayList();
     94   private Map<EquivalentAddressGroup, Subchannel> subchannels = Maps.newLinkedHashMap();
     95   private static final Attributes.Key<String> MAJOR_KEY = Attributes.Key.create("major-key");
     96   private Attributes affinity = Attributes.newBuilder().set(MAJOR_KEY, "I got the keys").build();
     97 
     98   @Captor
     99   private ArgumentCaptor<SubchannelPicker> pickerCaptor;
    100   @Captor
    101   private ArgumentCaptor<ConnectivityState> stateCaptor;
    102   @Captor
    103   private ArgumentCaptor<EquivalentAddressGroup> eagCaptor;
    104   @Mock
    105   private Helper mockHelper;
    106 
    107   @Mock // This LoadBalancer doesn't use any of the arg fields, as verified in tearDown().
    108   private PickSubchannelArgs mockArgs;
    109 
    110   @Before
    111   public void setUp() {
    112     MockitoAnnotations.initMocks(this);
    113 
    114     for (int i = 0; i < 3; i++) {
    115       SocketAddress addr = new FakeSocketAddress("server" + i);
    116       EquivalentAddressGroup eag = new EquivalentAddressGroup(addr);
    117       servers.add(eag);
    118       Subchannel sc = mock(Subchannel.class);
    119       when(sc.getAddresses()).thenReturn(eag);
    120       subchannels.put(eag, sc);
    121     }
    122 
    123     when(mockHelper.createSubchannel(any(EquivalentAddressGroup.class), any(Attributes.class)))
    124         .then(new Answer<Subchannel>() {
    125           @Override
    126           public Subchannel answer(InvocationOnMock invocation) throws Throwable {
    127             Object[] args = invocation.getArguments();
    128             Subchannel subchannel = subchannels.get(args[0]);
    129             when(subchannel.getAttributes()).thenReturn((Attributes) args[1]);
    130             return subchannel;
    131           }
    132         });
    133 
    134     loadBalancer = (RoundRobinLoadBalancer) RoundRobinLoadBalancerFactory.getInstance()
    135         .newLoadBalancer(mockHelper);
    136   }
    137 
    138   @After
    139   public void tearDown() throws Exception {
    140     verifyNoMoreInteractions(mockArgs);
    141   }
    142 
    143   @Test
    144   public void pickAfterResolved() throws Exception {
    145     final Subchannel readySubchannel = subchannels.values().iterator().next();
    146     loadBalancer.handleResolvedAddressGroups(servers, affinity);
    147     loadBalancer.handleSubchannelState(readySubchannel, ConnectivityStateInfo.forNonError(READY));
    148 
    149     verify(mockHelper, times(3)).createSubchannel(eagCaptor.capture(),
    150         any(Attributes.class));
    151 
    152     assertThat(eagCaptor.getAllValues()).containsAllIn(subchannels.keySet());
    153     for (Subchannel subchannel : subchannels.values()) {
    154       verify(subchannel).requestConnection();
    155       verify(subchannel, never()).shutdown();
    156     }
    157 
    158     verify(mockHelper, times(2))
    159         .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture());
    160 
    161     assertEquals(CONNECTING, stateCaptor.getAllValues().get(0));
    162     assertEquals(READY, stateCaptor.getAllValues().get(1));
    163     assertThat(getList(pickerCaptor.getValue())).containsExactly(readySubchannel);
    164 
    165     verifyNoMoreInteractions(mockHelper);
    166   }
    167 
    168   @Test
    169   public void pickAfterResolvedUpdatedHosts() throws Exception {
    170     Subchannel removedSubchannel = mock(Subchannel.class);
    171     Subchannel oldSubchannel = mock(Subchannel.class);
    172     Subchannel newSubchannel = mock(Subchannel.class);
    173 
    174     for (Subchannel subchannel : Lists.newArrayList(removedSubchannel, oldSubchannel,
    175         newSubchannel)) {
    176       when(subchannel.getAttributes()).thenReturn(Attributes.newBuilder().set(STATE_INFO,
    177           new Ref<ConnectivityStateInfo>(
    178               ConnectivityStateInfo.forNonError(READY))).build());
    179     }
    180 
    181     FakeSocketAddress removedAddr = new FakeSocketAddress("removed");
    182     FakeSocketAddress oldAddr = new FakeSocketAddress("old");
    183     FakeSocketAddress newAddr = new FakeSocketAddress("new");
    184 
    185     final Map<EquivalentAddressGroup, Subchannel> subchannels2 = Maps.newHashMap();
    186     subchannels2.put(new EquivalentAddressGroup(removedAddr), removedSubchannel);
    187     subchannels2.put(new EquivalentAddressGroup(oldAddr), oldSubchannel);
    188 
    189     List<EquivalentAddressGroup> currentServers =
    190         Lists.newArrayList(
    191             new EquivalentAddressGroup(removedAddr),
    192             new EquivalentAddressGroup(oldAddr));
    193 
    194     doAnswer(new Answer<Subchannel>() {
    195       @Override
    196       public Subchannel answer(InvocationOnMock invocation) throws Throwable {
    197         Object[] args = invocation.getArguments();
    198         return subchannels2.get(args[0]);
    199       }
    200     }).when(mockHelper).createSubchannel(any(EquivalentAddressGroup.class), any(Attributes.class));
    201 
    202     loadBalancer.handleResolvedAddressGroups(currentServers, affinity);
    203 
    204     InOrder inOrder = inOrder(mockHelper);
    205 
    206     inOrder.verify(mockHelper).updateBalancingState(eq(READY), pickerCaptor.capture());
    207     SubchannelPicker picker = pickerCaptor.getValue();
    208     assertThat(getList(picker)).containsExactly(removedSubchannel, oldSubchannel);
    209 
    210     verify(removedSubchannel, times(1)).requestConnection();
    211     verify(oldSubchannel, times(1)).requestConnection();
    212 
    213     assertThat(loadBalancer.getSubchannels()).containsExactly(removedSubchannel,
    214         oldSubchannel);
    215 
    216     subchannels2.clear();
    217     subchannels2.put(new EquivalentAddressGroup(oldAddr), oldSubchannel);
    218     subchannels2.put(new EquivalentAddressGroup(newAddr), newSubchannel);
    219 
    220     List<EquivalentAddressGroup> latestServers =
    221         Lists.newArrayList(
    222             new EquivalentAddressGroup(oldAddr),
    223             new EquivalentAddressGroup(newAddr));
    224 
    225     loadBalancer.handleResolvedAddressGroups(latestServers, affinity);
    226 
    227     verify(newSubchannel, times(1)).requestConnection();
    228     verify(removedSubchannel, times(1)).shutdown();
    229 
    230     loadBalancer.handleSubchannelState(removedSubchannel,
    231             ConnectivityStateInfo.forNonError(SHUTDOWN));
    232 
    233     assertThat(loadBalancer.getSubchannels()).containsExactly(oldSubchannel,
    234         newSubchannel);
    235 
    236     verify(mockHelper, times(3)).createSubchannel(any(EquivalentAddressGroup.class),
    237         any(Attributes.class));
    238     inOrder.verify(mockHelper).updateBalancingState(eq(READY), pickerCaptor.capture());
    239 
    240     picker = pickerCaptor.getValue();
    241     assertThat(getList(picker)).containsExactly(oldSubchannel, newSubchannel);
    242 
    243     // test going from non-empty to empty
    244     loadBalancer.handleResolvedAddressGroups(Collections.<EquivalentAddressGroup>emptyList(),
    245             affinity);
    246 
    247     inOrder.verify(mockHelper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture());
    248     assertEquals(PickResult.withNoResult(), pickerCaptor.getValue().pickSubchannel(mockArgs));
    249 
    250     verifyNoMoreInteractions(mockHelper);
    251   }
    252 
    253   @Test
    254   public void pickAfterStateChange() throws Exception {
    255     InOrder inOrder = inOrder(mockHelper);
    256     loadBalancer.handleResolvedAddressGroups(servers, Attributes.EMPTY);
    257     Subchannel subchannel = loadBalancer.getSubchannels().iterator().next();
    258     Ref<ConnectivityStateInfo> subchannelStateInfo = subchannel.getAttributes().get(
    259         STATE_INFO);
    260 
    261     inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class));
    262     assertThat(subchannelStateInfo.value).isEqualTo(ConnectivityStateInfo.forNonError(IDLE));
    263 
    264     loadBalancer.handleSubchannelState(subchannel,
    265         ConnectivityStateInfo.forNonError(READY));
    266     inOrder.verify(mockHelper).updateBalancingState(eq(READY), pickerCaptor.capture());
    267     assertThat(pickerCaptor.getValue()).isInstanceOf(ReadyPicker.class);
    268     assertThat(subchannelStateInfo.value).isEqualTo(
    269         ConnectivityStateInfo.forNonError(READY));
    270 
    271     Status error = Status.UNKNOWN.withDescription("\\_()_//");
    272     loadBalancer.handleSubchannelState(subchannel,
    273         ConnectivityStateInfo.forTransientFailure(error));
    274     assertThat(subchannelStateInfo.value).isEqualTo(
    275         ConnectivityStateInfo.forTransientFailure(error));
    276     inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture());
    277     assertThat(pickerCaptor.getValue()).isInstanceOf(EmptyPicker.class);
    278 
    279     loadBalancer.handleSubchannelState(subchannel,
    280         ConnectivityStateInfo.forNonError(IDLE));
    281     assertThat(subchannelStateInfo.value).isEqualTo(
    282         ConnectivityStateInfo.forNonError(IDLE));
    283 
    284     verify(subchannel, times(2)).requestConnection();
    285     verify(mockHelper, times(3)).createSubchannel(any(EquivalentAddressGroup.class),
    286         any(Attributes.class));
    287     verifyNoMoreInteractions(mockHelper);
    288   }
    289 
    290   private Subchannel nextSubchannel(Subchannel current, List<Subchannel> allSubChannels) {
    291     return allSubChannels.get((allSubChannels.indexOf(current) + 1) % allSubChannels.size());
    292   }
    293 
    294   @Test
    295   public void pickerRoundRobin() throws Exception {
    296     Subchannel subchannel = mock(Subchannel.class);
    297     Subchannel subchannel1 = mock(Subchannel.class);
    298     Subchannel subchannel2 = mock(Subchannel.class);
    299 
    300     ReadyPicker picker = new ReadyPicker(Collections.unmodifiableList(
    301         Lists.<Subchannel>newArrayList(subchannel, subchannel1, subchannel2)),
    302         0 /* startIndex */, null /* stickinessState */);
    303 
    304     assertThat(picker.getList()).containsExactly(subchannel, subchannel1, subchannel2);
    305 
    306     assertEquals(subchannel, picker.pickSubchannel(mockArgs).getSubchannel());
    307     assertEquals(subchannel1, picker.pickSubchannel(mockArgs).getSubchannel());
    308     assertEquals(subchannel2, picker.pickSubchannel(mockArgs).getSubchannel());
    309     assertEquals(subchannel, picker.pickSubchannel(mockArgs).getSubchannel());
    310   }
    311 
    312   @Test
    313   public void pickerEmptyList() throws Exception {
    314     SubchannelPicker picker = new EmptyPicker(Status.UNKNOWN);
    315 
    316     assertEquals(null, picker.pickSubchannel(mockArgs).getSubchannel());
    317     assertEquals(Status.UNKNOWN,
    318         picker.pickSubchannel(mockArgs).getStatus());
    319   }
    320 
    321   @Test
    322   public void nameResolutionErrorWithNoChannels() throws Exception {
    323     Status error = Status.NOT_FOUND.withDescription("nameResolutionError");
    324     loadBalancer.handleNameResolutionError(error);
    325     verify(mockHelper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture());
    326     LoadBalancer.PickResult pickResult = pickerCaptor.getValue().pickSubchannel(mockArgs);
    327     assertNull(pickResult.getSubchannel());
    328     assertEquals(error, pickResult.getStatus());
    329     verifyNoMoreInteractions(mockHelper);
    330   }
    331 
    332   @Test
    333   public void nameResolutionErrorWithActiveChannels() throws Exception {
    334     final Subchannel readySubchannel = subchannels.values().iterator().next();
    335     loadBalancer.handleResolvedAddressGroups(servers, affinity);
    336     loadBalancer.handleSubchannelState(readySubchannel, ConnectivityStateInfo.forNonError(READY));
    337     loadBalancer.handleNameResolutionError(Status.NOT_FOUND.withDescription("nameResolutionError"));
    338 
    339     verify(mockHelper, times(3)).createSubchannel(any(EquivalentAddressGroup.class),
    340         any(Attributes.class));
    341     verify(mockHelper, times(3))
    342         .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture());
    343 
    344     Iterator<ConnectivityState> stateIterator = stateCaptor.getAllValues().iterator();
    345     assertEquals(CONNECTING, stateIterator.next());
    346     assertEquals(READY, stateIterator.next());
    347     assertEquals(TRANSIENT_FAILURE, stateIterator.next());
    348 
    349     LoadBalancer.PickResult pickResult = pickerCaptor.getValue().pickSubchannel(mockArgs);
    350     assertEquals(readySubchannel, pickResult.getSubchannel());
    351     assertEquals(Status.OK.getCode(), pickResult.getStatus().getCode());
    352 
    353     LoadBalancer.PickResult pickResult2 = pickerCaptor.getValue().pickSubchannel(mockArgs);
    354     assertEquals(readySubchannel, pickResult2.getSubchannel());
    355     verifyNoMoreInteractions(mockHelper);
    356   }
    357 
    358   @Test
    359   public void subchannelStateIsolation() throws Exception {
    360     Iterator<Subchannel> subchannelIterator = subchannels.values().iterator();
    361     Subchannel sc1 = subchannelIterator.next();
    362     Subchannel sc2 = subchannelIterator.next();
    363     Subchannel sc3 = subchannelIterator.next();
    364 
    365     loadBalancer.handleResolvedAddressGroups(servers, Attributes.EMPTY);
    366     verify(sc1, times(1)).requestConnection();
    367     verify(sc2, times(1)).requestConnection();
    368     verify(sc3, times(1)).requestConnection();
    369 
    370     loadBalancer.handleSubchannelState(sc1, ConnectivityStateInfo.forNonError(READY));
    371     loadBalancer.handleSubchannelState(sc2, ConnectivityStateInfo.forNonError(READY));
    372     loadBalancer.handleSubchannelState(sc3, ConnectivityStateInfo.forNonError(READY));
    373     loadBalancer.handleSubchannelState(sc2, ConnectivityStateInfo.forNonError(IDLE));
    374     loadBalancer
    375         .handleSubchannelState(sc3, ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE));
    376 
    377     verify(mockHelper, times(6))
    378         .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture());
    379     Iterator<ConnectivityState> stateIterator = stateCaptor.getAllValues().iterator();
    380     Iterator<SubchannelPicker> pickers = pickerCaptor.getAllValues().iterator();
    381     // The picker is incrementally updated as subchannels become READY
    382     assertEquals(CONNECTING, stateIterator.next());
    383     assertThat(pickers.next()).isInstanceOf(EmptyPicker.class);
    384     assertEquals(READY, stateIterator.next());
    385     assertThat(getList(pickers.next())).containsExactly(sc1);
    386     assertEquals(READY, stateIterator.next());
    387     assertThat(getList(pickers.next())).containsExactly(sc1, sc2);
    388     assertEquals(READY, stateIterator.next());
    389     assertThat(getList(pickers.next())).containsExactly(sc1, sc2, sc3);
    390     // The IDLE subchannel is dropped from the picker, but a reconnection is requested
    391     assertEquals(READY, stateIterator.next());
    392     assertThat(getList(pickers.next())).containsExactly(sc1, sc3);
    393     verify(sc2, times(2)).requestConnection();
    394     // The failing subchannel is dropped from the picker, with no requested reconnect
    395     assertEquals(READY, stateIterator.next());
    396     assertThat(getList(pickers.next())).containsExactly(sc1);
    397     verify(sc3, times(1)).requestConnection();
    398     assertThat(stateIterator.hasNext()).isFalse();
    399     assertThat(pickers.hasNext()).isFalse();
    400   }
    401 
    402   @Test
    403   public void noStickinessEnabled_withStickyHeader() {
    404     loadBalancer.handleResolvedAddressGroups(servers, Attributes.EMPTY);
    405     for (Subchannel subchannel : subchannels.values()) {
    406       loadBalancer.handleSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY));
    407     }
    408     verify(mockHelper, times(4))
    409         .updateBalancingState(any(ConnectivityState.class), pickerCaptor.capture());
    410     SubchannelPicker picker = pickerCaptor.getValue();
    411 
    412     Key<String> stickinessKey = Key.of("my-sticky-key", Metadata.ASCII_STRING_MARSHALLER);
    413     Metadata headerWithStickinessValue = new Metadata();
    414     headerWithStickinessValue.put(stickinessKey, "my-sticky-value");
    415     doReturn(headerWithStickinessValue).when(mockArgs).getHeaders();
    416 
    417     List<Subchannel> allSubchannels = getList(picker);
    418     Subchannel sc1 = picker.pickSubchannel(mockArgs).getSubchannel();
    419     Subchannel sc2 = picker.pickSubchannel(mockArgs).getSubchannel();
    420     Subchannel sc3 = picker.pickSubchannel(mockArgs).getSubchannel();
    421     Subchannel sc4 = picker.pickSubchannel(mockArgs).getSubchannel();
    422 
    423     assertEquals(nextSubchannel(sc1, allSubchannels), sc2);
    424     assertEquals(nextSubchannel(sc2, allSubchannels), sc3);
    425     assertEquals(nextSubchannel(sc3, allSubchannels), sc1);
    426     assertEquals(sc4, sc1);
    427 
    428     assertNull(loadBalancer.getStickinessMapForTest());
    429   }
    430 
    431   @Test
    432   public void stickinessEnabled_withoutStickyHeader() {
    433     Map<String, Object> serviceConfig = new HashMap<String, Object>();
    434     serviceConfig.put("stickinessMetadataKey", "my-sticky-key");
    435     Attributes attributes = Attributes.newBuilder()
    436         .set(GrpcAttributes.NAME_RESOLVER_SERVICE_CONFIG, serviceConfig).build();
    437     loadBalancer.handleResolvedAddressGroups(servers, attributes);
    438     for (Subchannel subchannel : subchannels.values()) {
    439       loadBalancer.handleSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY));
    440     }
    441     verify(mockHelper, times(4))
    442         .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture());
    443     SubchannelPicker picker = pickerCaptor.getValue();
    444 
    445     doReturn(new Metadata()).when(mockArgs).getHeaders();
    446 
    447     List<Subchannel> allSubchannels = getList(picker);
    448 
    449     Subchannel sc1 = picker.pickSubchannel(mockArgs).getSubchannel();
    450     Subchannel sc2 = picker.pickSubchannel(mockArgs).getSubchannel();
    451     Subchannel sc3 = picker.pickSubchannel(mockArgs).getSubchannel();
    452     Subchannel sc4 = picker.pickSubchannel(mockArgs).getSubchannel();
    453 
    454     assertEquals(nextSubchannel(sc1, allSubchannels), sc2);
    455     assertEquals(nextSubchannel(sc2, allSubchannels), sc3);
    456     assertEquals(nextSubchannel(sc3, allSubchannels), sc1);
    457     assertEquals(sc4, sc1);
    458     verify(mockArgs, times(4)).getHeaders();
    459     assertNotNull(loadBalancer.getStickinessMapForTest());
    460     assertThat(loadBalancer.getStickinessMapForTest()).isEmpty();
    461   }
    462 
    463   @Test
    464   public void stickinessEnabled_withStickyHeader() {
    465     Map<String, Object> serviceConfig = new HashMap<String, Object>();
    466     serviceConfig.put("stickinessMetadataKey", "my-sticky-key");
    467     Attributes attributes = Attributes.newBuilder()
    468         .set(GrpcAttributes.NAME_RESOLVER_SERVICE_CONFIG, serviceConfig).build();
    469     loadBalancer.handleResolvedAddressGroups(servers, attributes);
    470     for (Subchannel subchannel : subchannels.values()) {
    471       loadBalancer.handleSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY));
    472     }
    473     verify(mockHelper, times(4))
    474         .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture());
    475     SubchannelPicker picker = pickerCaptor.getValue();
    476 
    477     Key<String> stickinessKey = Key.of("my-sticky-key", Metadata.ASCII_STRING_MARSHALLER);
    478     Metadata headerWithStickinessValue = new Metadata();
    479     headerWithStickinessValue.put(stickinessKey, "my-sticky-value");
    480     doReturn(headerWithStickinessValue).when(mockArgs).getHeaders();
    481 
    482     Subchannel sc1 = picker.pickSubchannel(mockArgs).getSubchannel();
    483     assertEquals(sc1, picker.pickSubchannel(mockArgs).getSubchannel());
    484     assertEquals(sc1, picker.pickSubchannel(mockArgs).getSubchannel());
    485     assertEquals(sc1, picker.pickSubchannel(mockArgs).getSubchannel());
    486     assertEquals(sc1, picker.pickSubchannel(mockArgs).getSubchannel());
    487 
    488     verify(mockArgs, atLeast(4)).getHeaders();
    489     assertNotNull(loadBalancer.getStickinessMapForTest());
    490     assertThat(loadBalancer.getStickinessMapForTest()).hasSize(1);
    491   }
    492 
    493   @Test
    494   public void stickinessEnabled_withDifferentStickyHeaders() {
    495     Map<String, Object> serviceConfig = new HashMap<String, Object>();
    496     serviceConfig.put("stickinessMetadataKey", "my-sticky-key");
    497     Attributes attributes = Attributes.newBuilder()
    498         .set(GrpcAttributes.NAME_RESOLVER_SERVICE_CONFIG, serviceConfig).build();
    499     loadBalancer.handleResolvedAddressGroups(servers, attributes);
    500     for (Subchannel subchannel : subchannels.values()) {
    501       loadBalancer.handleSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY));
    502     }
    503     verify(mockHelper, times(4))
    504         .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture());
    505     SubchannelPicker picker = pickerCaptor.getValue();
    506 
    507     Key<String> stickinessKey = Key.of("my-sticky-key", Metadata.ASCII_STRING_MARSHALLER);
    508     Metadata headerWithStickinessValue1 = new Metadata();
    509     headerWithStickinessValue1.put(stickinessKey, "my-sticky-value");
    510 
    511     Metadata headerWithStickinessValue2 = new Metadata();
    512     headerWithStickinessValue2.put(stickinessKey, "my-sticky-value2");
    513 
    514     List<Subchannel> allSubchannels = getList(picker);
    515 
    516     doReturn(headerWithStickinessValue1).when(mockArgs).getHeaders();
    517     Subchannel sc1a = picker.pickSubchannel(mockArgs).getSubchannel();
    518 
    519     doReturn(headerWithStickinessValue2).when(mockArgs).getHeaders();
    520     Subchannel sc2a = picker.pickSubchannel(mockArgs).getSubchannel();
    521 
    522     doReturn(headerWithStickinessValue1).when(mockArgs).getHeaders();
    523     Subchannel sc1b = picker.pickSubchannel(mockArgs).getSubchannel();
    524 
    525     doReturn(headerWithStickinessValue2).when(mockArgs).getHeaders();
    526     Subchannel sc2b = picker.pickSubchannel(mockArgs).getSubchannel();
    527 
    528     assertEquals(sc1a, sc1b);
    529     assertEquals(sc2a, sc2b);
    530     assertEquals(nextSubchannel(sc1a, allSubchannels), sc2a);
    531     assertEquals(nextSubchannel(sc1b, allSubchannels), sc2b);
    532 
    533     verify(mockArgs, atLeast(4)).getHeaders();
    534     assertNotNull(loadBalancer.getStickinessMapForTest());
    535     assertThat(loadBalancer.getStickinessMapForTest()).hasSize(2);
    536   }
    537 
    538   @Test
    539   public void stickiness_goToTransientFailure_pick_backToReady() {
    540     Map<String, Object> serviceConfig = new HashMap<String, Object>();
    541     serviceConfig.put("stickinessMetadataKey", "my-sticky-key");
    542     Attributes attributes = Attributes.newBuilder()
    543         .set(GrpcAttributes.NAME_RESOLVER_SERVICE_CONFIG, serviceConfig).build();
    544     loadBalancer.handleResolvedAddressGroups(servers, attributes);
    545     for (Subchannel subchannel : subchannels.values()) {
    546       loadBalancer.handleSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY));
    547     }
    548     verify(mockHelper, times(4))
    549         .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture());
    550     SubchannelPicker picker = pickerCaptor.getValue();
    551 
    552     Key<String> stickinessKey = Key.of("my-sticky-key", Metadata.ASCII_STRING_MARSHALLER);
    553     Metadata headerWithStickinessValue = new Metadata();
    554     headerWithStickinessValue.put(stickinessKey, "my-sticky-value");
    555     doReturn(headerWithStickinessValue).when(mockArgs).getHeaders();
    556 
    557     // first pick
    558     Subchannel sc1 = picker.pickSubchannel(mockArgs).getSubchannel();
    559 
    560     // go to transient failure
    561     loadBalancer
    562         .handleSubchannelState(sc1, ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE));
    563 
    564     verify(mockHelper, times(5))
    565         .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture());
    566     picker = pickerCaptor.getValue();
    567 
    568     // second pick
    569     Subchannel sc2 = picker.pickSubchannel(mockArgs).getSubchannel();
    570 
    571     // go back to ready
    572     loadBalancer.handleSubchannelState(sc1, ConnectivityStateInfo.forNonError(READY));
    573 
    574     verify(mockHelper, times(6))
    575         .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture());
    576     picker = pickerCaptor.getValue();
    577 
    578     // third pick
    579     Subchannel sc3 = picker.pickSubchannel(mockArgs).getSubchannel();
    580     assertEquals(sc2, sc3);
    581     verify(mockArgs, atLeast(3)).getHeaders();
    582     assertNotNull(loadBalancer.getStickinessMapForTest());
    583     assertThat(loadBalancer.getStickinessMapForTest()).hasSize(1);
    584   }
    585 
    586   @Test
    587   public void stickiness_goToTransientFailure_backToReady_pick() {
    588     Map<String, Object> serviceConfig = new HashMap<String, Object>();
    589     serviceConfig.put("stickinessMetadataKey", "my-sticky-key");
    590     Attributes attributes = Attributes.newBuilder()
    591         .set(GrpcAttributes.NAME_RESOLVER_SERVICE_CONFIG, serviceConfig).build();
    592     loadBalancer.handleResolvedAddressGroups(servers, attributes);
    593     for (Subchannel subchannel : subchannels.values()) {
    594       loadBalancer.handleSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY));
    595     }
    596     verify(mockHelper, times(4))
    597         .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture());
    598     SubchannelPicker picker = pickerCaptor.getValue();
    599 
    600     Key<String> stickinessKey = Key.of("my-sticky-key", Metadata.ASCII_STRING_MARSHALLER);
    601     Metadata headerWithStickinessValue1 = new Metadata();
    602     headerWithStickinessValue1.put(stickinessKey, "my-sticky-value");
    603     doReturn(headerWithStickinessValue1).when(mockArgs).getHeaders();
    604 
    605     // first pick
    606     Subchannel sc1 = picker.pickSubchannel(mockArgs).getSubchannel();
    607 
    608     // go to transient failure
    609     loadBalancer
    610         .handleSubchannelState(sc1, ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE));
    611 
    612     Metadata headerWithStickinessValue2 = new Metadata();
    613     headerWithStickinessValue2.put(stickinessKey, "my-sticky-value2");
    614     doReturn(headerWithStickinessValue2).when(mockArgs).getHeaders();
    615     verify(mockHelper, times(5))
    616         .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture());
    617     picker = pickerCaptor.getValue();
    618 
    619     // second pick with a different stickiness value
    620     Subchannel sc2 = picker.pickSubchannel(mockArgs).getSubchannel();
    621 
    622     // go back to ready
    623     loadBalancer.handleSubchannelState(sc1, ConnectivityStateInfo.forNonError(READY));
    624 
    625     doReturn(headerWithStickinessValue1).when(mockArgs).getHeaders();
    626     verify(mockHelper, times(6))
    627         .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture());
    628     picker = pickerCaptor.getValue();
    629 
    630     // third pick with my-sticky-value1
    631     Subchannel sc3 = picker.pickSubchannel(mockArgs).getSubchannel();
    632     assertEquals(sc1, sc3);
    633 
    634     verify(mockArgs, atLeast(3)).getHeaders();
    635     assertNotNull(loadBalancer.getStickinessMapForTest());
    636     assertThat(loadBalancer.getStickinessMapForTest()).hasSize(2);
    637   }
    638 
    639   @Test
    640   public void stickiness_oneSubchannelShutdown() {
    641     Map<String, Object> serviceConfig = new HashMap<String, Object>();
    642     serviceConfig.put("stickinessMetadataKey", "my-sticky-key");
    643     Attributes attributes = Attributes.newBuilder()
    644         .set(GrpcAttributes.NAME_RESOLVER_SERVICE_CONFIG, serviceConfig).build();
    645     loadBalancer.handleResolvedAddressGroups(servers, attributes);
    646     for (Subchannel subchannel : subchannels.values()) {
    647       loadBalancer.handleSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY));
    648     }
    649     verify(mockHelper, times(4))
    650         .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture());
    651     SubchannelPicker picker = pickerCaptor.getValue();
    652 
    653     Key<String> stickinessKey = Key.of("my-sticky-key", Metadata.ASCII_STRING_MARSHALLER);
    654     Metadata headerWithStickinessValue = new Metadata();
    655     headerWithStickinessValue.put(stickinessKey, "my-sticky-value");
    656     doReturn(headerWithStickinessValue).when(mockArgs).getHeaders();
    657 
    658     List<Subchannel> allSubchannels = Lists.newArrayList(getList(picker));
    659 
    660     Subchannel sc1 = picker.pickSubchannel(mockArgs).getSubchannel();
    661 
    662     // shutdown channel directly
    663     loadBalancer
    664         .handleSubchannelState(sc1, ConnectivityStateInfo.forNonError(ConnectivityState.SHUTDOWN));
    665 
    666     assertNull(loadBalancer.getStickinessMapForTest().get("my-sticky-value").value);
    667 
    668     assertEquals(nextSubchannel(sc1, allSubchannels),
    669                  picker.pickSubchannel(mockArgs).getSubchannel());
    670     assertThat(loadBalancer.getStickinessMapForTest()).hasSize(1);
    671     verify(mockArgs, atLeast(2)).getHeaders();
    672 
    673     Subchannel sc2 = picker.pickSubchannel(mockArgs).getSubchannel();
    674 
    675     assertEquals(sc2, loadBalancer.getStickinessMapForTest().get("my-sticky-value").value);
    676 
    677     // shutdown channel via name resolver change
    678     List<EquivalentAddressGroup> newServers = new ArrayList<>(servers);
    679     newServers.remove(sc2.getAddresses());
    680 
    681     loadBalancer.handleResolvedAddressGroups(newServers, attributes);
    682 
    683     verify(sc2, times(1)).shutdown();
    684 
    685     loadBalancer.handleSubchannelState(sc2, ConnectivityStateInfo.forNonError(SHUTDOWN));
    686 
    687     assertNull(loadBalancer.getStickinessMapForTest().get("my-sticky-value").value);
    688 
    689     assertEquals(nextSubchannel(sc2, allSubchannels),
    690             picker.pickSubchannel(mockArgs).getSubchannel());
    691     assertThat(loadBalancer.getStickinessMapForTest()).hasSize(1);
    692     verify(mockArgs, atLeast(2)).getHeaders();
    693   }
    694 
    695   @Test
    696   public void stickiness_resolveTwice_metadataKeyChanged() {
    697     Map<String, Object> serviceConfig1 = new HashMap<String, Object>();
    698     serviceConfig1.put("stickinessMetadataKey", "my-sticky-key1");
    699     Attributes attributes1 = Attributes.newBuilder()
    700         .set(GrpcAttributes.NAME_RESOLVER_SERVICE_CONFIG, serviceConfig1).build();
    701     loadBalancer.handleResolvedAddressGroups(servers, attributes1);
    702     Map<String, ?> stickinessMap1 = loadBalancer.getStickinessMapForTest();
    703 
    704     Map<String, Object> serviceConfig2 = new HashMap<String, Object>();
    705     serviceConfig2.put("stickinessMetadataKey", "my-sticky-key2");
    706     Attributes attributes2 = Attributes.newBuilder()
    707         .set(GrpcAttributes.NAME_RESOLVER_SERVICE_CONFIG, serviceConfig2).build();
    708     loadBalancer.handleResolvedAddressGroups(servers, attributes2);
    709     Map<String, ?> stickinessMap2 = loadBalancer.getStickinessMapForTest();
    710 
    711     assertNotSame(stickinessMap1, stickinessMap2);
    712   }
    713 
    714   @Test
    715   public void stickiness_resolveTwice_metadataKeyUnChanged() {
    716     Map<String, Object> serviceConfig1 = new HashMap<String, Object>();
    717     serviceConfig1.put("stickinessMetadataKey", "my-sticky-key1");
    718     Attributes attributes1 = Attributes.newBuilder()
    719         .set(GrpcAttributes.NAME_RESOLVER_SERVICE_CONFIG, serviceConfig1).build();
    720     loadBalancer.handleResolvedAddressGroups(servers, attributes1);
    721     Map<String, ?> stickinessMap1 = loadBalancer.getStickinessMapForTest();
    722 
    723     loadBalancer.handleResolvedAddressGroups(servers, attributes1);
    724     Map<String, ?> stickinessMap2 = loadBalancer.getStickinessMapForTest();
    725 
    726     assertSame(stickinessMap1, stickinessMap2);
    727   }
    728 
    729   @Test(expected = IllegalArgumentException.class)
    730   public void readyPicker_emptyList() {
    731     // ready picker list must be non-empty
    732     new ReadyPicker(Collections.<Subchannel>emptyList(), 0, null);
    733   }
    734 
    735   @Test
    736   public void internalPickerComparisons() {
    737     EmptyPicker emptyOk1 = new EmptyPicker(Status.OK);
    738     EmptyPicker emptyOk2 = new EmptyPicker(Status.OK.withDescription("different OK"));
    739     EmptyPicker emptyErr = new EmptyPicker(Status.UNKNOWN.withDescription("\\_()_//"));
    740 
    741     Iterator<Subchannel> subchannelIterator = subchannels.values().iterator();
    742     Subchannel sc1 = subchannelIterator.next();
    743     Subchannel sc2 = subchannelIterator.next();
    744     StickinessState stickinessState = new StickinessState("stick-key");
    745     ReadyPicker ready1 = new ReadyPicker(Arrays.asList(sc1, sc2), 0, null);
    746     ReadyPicker ready2 = new ReadyPicker(Arrays.asList(sc1), 0, null);
    747     ReadyPicker ready3 = new ReadyPicker(Arrays.asList(sc2, sc1), 1, null);
    748     ReadyPicker ready4 = new ReadyPicker(Arrays.asList(sc1, sc2), 1, stickinessState);
    749     ReadyPicker ready5 = new ReadyPicker(Arrays.asList(sc2, sc1), 0, stickinessState);
    750 
    751     assertTrue(emptyOk1.isEquivalentTo(emptyOk2));
    752     assertFalse(emptyOk1.isEquivalentTo(emptyErr));
    753     assertFalse(ready1.isEquivalentTo(ready2));
    754     assertTrue(ready1.isEquivalentTo(ready3));
    755     assertFalse(ready3.isEquivalentTo(ready4));
    756     assertTrue(ready4.isEquivalentTo(ready5));
    757     assertFalse(emptyOk1.isEquivalentTo(ready1));
    758     assertFalse(ready1.isEquivalentTo(emptyOk1));
    759   }
    760 
    761 
    762   private static List<Subchannel> getList(SubchannelPicker picker) {
    763     return picker instanceof ReadyPicker ? ((ReadyPicker) picker).getList() :
    764         Collections.<Subchannel>emptyList();
    765   }
    766 
    767   private static class FakeSocketAddress extends SocketAddress {
    768     final String name;
    769 
    770     FakeSocketAddress(String name) {
    771       this.name = name;
    772     }
    773 
    774     @Override
    775     public String toString() {
    776       return "FakeSocketAddress-" + name;
    777     }
    778   }
    779 }
    780