Home | History | Annotate | Download | only in shard
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 package com.android.tradefed.invoker.shard;
     17 
     18 import com.android.annotations.VisibleForTesting;
     19 import com.android.tradefed.config.IConfiguration;
     20 import com.android.tradefed.invoker.IInvocationContext;
     21 import com.android.tradefed.invoker.IRescheduler;
     22 import com.android.tradefed.log.LogUtil.CLog;
     23 import com.android.tradefed.testtype.IBuildReceiver;
     24 import com.android.tradefed.testtype.IDeviceTest;
     25 import com.android.tradefed.testtype.IInvocationContextReceiver;
     26 import com.android.tradefed.testtype.IMultiDeviceTest;
     27 import com.android.tradefed.testtype.IRemoteTest;
     28 import com.android.tradefed.testtype.IRuntimeHintProvider;
     29 import com.android.tradefed.testtype.IShardableTest;
     30 import com.android.tradefed.testtype.IStrictShardableTest;
     31 import com.android.tradefed.testtype.suite.ITestSuite;
     32 import com.android.tradefed.testtype.suite.ModuleMerger;
     33 import com.android.tradefed.util.TimeUtil;
     34 
     35 import java.util.ArrayList;
     36 import java.util.Collection;
     37 import java.util.Collections;
     38 import java.util.List;
     39 
     40 /** Sharding strategy to create strict shards that do not report together, */
     41 public class StrictShardHelper extends ShardHelper {
     42 
     43     /** {@inheritDoc} */
     44     @Override
     45     public boolean shardConfig(
     46             IConfiguration config, IInvocationContext context, IRescheduler rescheduler) {
     47         Integer shardCount = config.getCommandOptions().getShardCount();
     48         Integer shardIndex = config.getCommandOptions().getShardIndex();
     49 
     50         if (shardIndex == null) {
     51             return super.shardConfig(config, context, rescheduler);
     52         }
     53         if (shardCount == null) {
     54             throw new RuntimeException("shard-count is null while shard-index is " + shardIndex);
     55         }
     56 
     57         // Split tests in place, without actually sharding.
     58         if (!config.getCommandOptions().shouldUseTfSharding()) {
     59             // TODO: remove when IStrictShardableTest is removed.
     60             updateConfigIfSharded(config, shardCount, shardIndex);
     61         } else {
     62             List<IRemoteTest> listAllTests = getAllTests(config, shardCount, context);
     63             // We cannot shuffle to get better average results
     64             normalizeDistribution(listAllTests, shardCount);
     65             List<IRemoteTest> splitList;
     66             if (shardCount == 1) {
     67                 // not sharded
     68                 splitList = listAllTests;
     69             } else {
     70                 splitList = splitTests(listAllTests, shardCount).get(shardIndex);
     71             }
     72             aggregateSuiteModules(splitList);
     73             config.setTests(splitList);
     74         }
     75         return false;
     76     }
     77 
     78     // TODO: Retire IStrictShardableTest for IShardableTest and have TF balance the list of tests.
     79     private void updateConfigIfSharded(IConfiguration config, int shardCount, int shardIndex) {
     80         List<IRemoteTest> testShards = new ArrayList<>();
     81         for (IRemoteTest test : config.getTests()) {
     82             if (!(test instanceof IStrictShardableTest)) {
     83                 CLog.w(
     84                         "%s is not shardable; the whole test will run in shard 0",
     85                         test.getClass().getName());
     86                 if (shardIndex == 0) {
     87                     testShards.add(test);
     88                 }
     89                 continue;
     90             }
     91             IRemoteTest testShard =
     92                     ((IStrictShardableTest) test).getTestShard(shardCount, shardIndex);
     93             testShards.add(testShard);
     94         }
     95         config.setTests(testShards);
     96     }
     97 
     98     /**
     99      * Helper to return the full list of {@link IRemoteTest} based on {@link IShardableTest} split.
    100      *
    101      * @param config the {@link IConfiguration} describing the invocation.
    102      * @param shardCount the shard count hint to be provided to some tests.
    103      * @param context the {@link IInvocationContext} of the parent invocation.
    104      * @return the list of all {@link IRemoteTest}.
    105      */
    106     private List<IRemoteTest> getAllTests(
    107             IConfiguration config, Integer shardCount, IInvocationContext context) {
    108         List<IRemoteTest> allTests = new ArrayList<>();
    109         for (IRemoteTest test : config.getTests()) {
    110             if (test instanceof IShardableTest) {
    111                 // Inject current information to help with sharding
    112                 if (test instanceof IBuildReceiver) {
    113                     ((IBuildReceiver) test).setBuild(context.getBuildInfos().get(0));
    114                 }
    115                 if (test instanceof IDeviceTest) {
    116                     ((IDeviceTest) test).setDevice(context.getDevices().get(0));
    117                 }
    118                 if (test instanceof IMultiDeviceTest) {
    119                     ((IMultiDeviceTest) test).setDeviceInfos(context.getDeviceBuildMap());
    120                 }
    121                 if (test instanceof IInvocationContextReceiver) {
    122                     ((IInvocationContextReceiver) test).setInvocationContext(context);
    123                 }
    124 
    125                 // Handling of the ITestSuite is a special case, we do not allow pool of tests
    126                 // since each shard needs to be independent.
    127                 if (test instanceof ITestSuite) {
    128                     ((ITestSuite) test).setShouldMakeDynamicModule(false);
    129                 }
    130 
    131                 Collection<IRemoteTest> subTests = ((IShardableTest) test).split(shardCount);
    132                 if (subTests == null) {
    133                     // test did not shard so we add it as is.
    134                     allTests.add(test);
    135                 } else {
    136                     allTests.addAll(subTests);
    137                 }
    138             } else {
    139                 // if test is not shardable we add it as is.
    140                 allTests.add(test);
    141             }
    142         }
    143         return allTests;
    144     }
    145 
    146     /**
    147      * Split the list of tests to run however the implementation see fit. Sharding needs to be
    148      * consistent. It is acceptable to return an empty list if no tests can be run in the shard.
    149      *
    150      * <p>Implement this in order to provide a test suite specific sharding. The default
    151      * implementation attempts to balance the number of IRemoteTest per shards as much as possible
    152      * as a first step, then use a minor criteria or run-hint to adjust the lists a bit more.
    153      *
    154      * @param fullList the initial full list of {@link IRemoteTest} containing all the tests that
    155      *     need to run.
    156      * @param shardCount the total number of shard that need to run.
    157      * @return a list of list {@link IRemoteTest}s that have been assigned to each shard. The list
    158      *     size will be the shardCount.
    159      */
    160     @VisibleForTesting
    161     protected List<List<IRemoteTest>> splitTests(List<IRemoteTest> fullList, int shardCount) {
    162         List<List<IRemoteTest>> shards = new ArrayList<>();
    163         // We are using Match.ceil to avoid the last shard having too much extra.
    164         int numPerShard = (int) Math.ceil(fullList.size() / (float) shardCount);
    165 
    166         boolean needsCorrection = false;
    167         float correctionRatio = 0f;
    168         if (fullList.size() > shardCount) {
    169             // In some cases because of the Math.ceil, some combination might run out of tests
    170             // before the last shard, in that case we populate a correction to rebalance the tests.
    171             needsCorrection = (numPerShard * (shardCount - 1)) > fullList.size();
    172             correctionRatio = numPerShard - ((fullList.size() / (float) shardCount));
    173         }
    174         // Recalculate the number of tests per shard with the correction taken into account.
    175         numPerShard = (int) Math.floor(numPerShard - correctionRatio);
    176         // Based of the parameters, distribute the tests accross shards.
    177         shards = balancedDistrib(fullList, shardCount, numPerShard, needsCorrection);
    178         // Do last minute rebalancing
    179         topBottom(shards, shardCount);
    180         return shards;
    181     }
    182 
    183     private List<List<IRemoteTest>> balancedDistrib(
    184             List<IRemoteTest> fullList, int shardCount, int numPerShard, boolean needsCorrection) {
    185         List<List<IRemoteTest>> shards = new ArrayList<>();
    186         List<IRemoteTest> correctionList = new ArrayList<>();
    187         int correctionSize = 0;
    188 
    189         // Generate all the shards
    190         for (int i = 0; i < shardCount; i++) {
    191             List<IRemoteTest> shardList;
    192             if (i >= fullList.size()) {
    193                 // Return empty list when we don't have enough tests for all the shards.
    194                 shardList = new ArrayList<IRemoteTest>();
    195                 shards.add(shardList);
    196                 continue;
    197             }
    198 
    199             if (i == shardCount - 1) {
    200                 // last shard take everything remaining except the correction:
    201                 if (needsCorrection) {
    202                     // We omit the size of the correction needed.
    203                     correctionSize = fullList.size() - (numPerShard + (i * numPerShard));
    204                     correctionList =
    205                             fullList.subList(fullList.size() - correctionSize, fullList.size());
    206                 }
    207                 shardList = fullList.subList(i * numPerShard, fullList.size() - correctionSize);
    208                 shards.add(new ArrayList<>(shardList));
    209                 continue;
    210             }
    211             shardList = fullList.subList(i * numPerShard, numPerShard + (i * numPerShard));
    212             shards.add(new ArrayList<>(shardList));
    213         }
    214 
    215         // If we have correction omitted tests, disperse them on each shard, at this point the
    216         // number of tests in correction is ensured to be bellow the number of shards.
    217         for (int i = 0; i < shardCount; i++) {
    218             if (i < correctionList.size()) {
    219                 shards.get(i).add(correctionList.get(i));
    220             } else {
    221                 break;
    222             }
    223         }
    224         return shards;
    225     }
    226 
    227     /**
    228      * Move around predictably the tests in order to have a better uniformization of the tests in
    229      * each shard.
    230      */
    231     private void normalizeDistribution(List<IRemoteTest> listAllTests, int shardCount) {
    232         final int numRound = shardCount;
    233         final int distance = shardCount - 1;
    234         for (int i = 0; i < numRound; i++) {
    235             for (int j = 0; j < listAllTests.size(); j = j + distance) {
    236                 // Push the test at the end
    237                 IRemoteTest push = listAllTests.remove(j);
    238                 listAllTests.add(push);
    239             }
    240         }
    241     }
    242 
    243     /**
    244      * Special handling for suite from {@link ITestSuite}. We aggregate the tests in the same shard
    245      * in order to optimize target_preparation step.
    246      *
    247      * @param tests the {@link List} of {@link IRemoteTest} for that shard.
    248      */
    249     private void aggregateSuiteModules(List<IRemoteTest> tests) {
    250         List<IRemoteTest> dupList = new ArrayList<>(tests);
    251         for (int i = 0; i < dupList.size(); i++) {
    252             if (dupList.get(i) instanceof ITestSuite) {
    253                 // We iterate the other tests to see if we can find another from the same module.
    254                 for (int j = i + 1; j < dupList.size(); j++) {
    255                     // If the test was not already merged
    256                     if (tests.contains(dupList.get(j))) {
    257                         if (dupList.get(j) instanceof ITestSuite) {
    258                             if (ModuleMerger.arePartOfSameSuite(
    259                                     (ITestSuite) dupList.get(i), (ITestSuite) dupList.get(j))) {
    260                                 ModuleMerger.mergeSplittedITestSuite(
    261                                         (ITestSuite) dupList.get(i), (ITestSuite) dupList.get(j));
    262                                 tests.remove(dupList.get(j));
    263                             }
    264                         }
    265                     }
    266                 }
    267             }
    268         }
    269     }
    270 
    271     private void topBottom(List<List<IRemoteTest>> allShards, int shardCount) {
    272         // We only attempt this when the number of shard is pretty high
    273         if (shardCount < 4) {
    274             return;
    275         }
    276         // Generate approximate RuntimeHint for each shard
    277         int index = 0;
    278         List<SortShardObj> shardTimes = new ArrayList<>();
    279         CLog.e("============================");
    280         for (List<IRemoteTest> shard : allShards) {
    281             long aggTime = 0l;
    282             CLog.e("++++++++++++++++++ SHARD %s +++++++++++++++", index);
    283             for (IRemoteTest test : shard) {
    284                 if (test instanceof IRuntimeHintProvider) {
    285                     aggTime += ((IRuntimeHintProvider) test).getRuntimeHint();
    286                 }
    287             }
    288             CLog.e("Shard %s approximate time: %s", index, TimeUtil.formatElapsedTime(aggTime));
    289             shardTimes.add(new SortShardObj(index, aggTime));
    290             index++;
    291             CLog.e("+++++++++++++++++++++++++++++++++++++++++++");
    292         }
    293         CLog.e("============================");
    294 
    295         Collections.sort(shardTimes);
    296         if ((shardTimes.get(0).mAggTime - shardTimes.get(shardTimes.size() - 1).mAggTime)
    297                 < 60 * 60 * 1000l) {
    298             return;
    299         }
    300 
    301         // take 30% top shard (10 shard = top 3 shards)
    302         for (int i = 0; i < (shardCount * 0.3); i++) {
    303             CLog.e(
    304                     "Top shard %s is index %s with %s",
    305                     i,
    306                     shardTimes.get(i).mIndex,
    307                     TimeUtil.formatElapsedTime(shardTimes.get(i).mAggTime));
    308             int give = shardTimes.get(i).mIndex;
    309             int receive = shardTimes.get(shardTimes.size() - 1 - i).mIndex;
    310             CLog.e("Giving from shard %s to shard %s", give, receive);
    311             for (int j = 0; j < (allShards.get(give).size() * (0.2f / (i + 1))); j++) {
    312                 IRemoteTest givetest = allShards.get(give).remove(0);
    313                 allShards.get(receive).add(givetest);
    314             }
    315         }
    316     }
    317 
    318     /** Object holder for shard, their index and their aggregated execution time. */
    319     private class SortShardObj implements Comparable<SortShardObj> {
    320         public final int mIndex;
    321         public final Long mAggTime;
    322 
    323         public SortShardObj(int index, long aggTime) {
    324             mIndex = index;
    325             mAggTime = aggTime;
    326         }
    327 
    328         @Override
    329         public int compareTo(SortShardObj obj) {
    330             return obj.mAggTime.compareTo(mAggTime);
    331         }
    332     }
    333 }
    334