/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.android.tradefed.invoker.shard;

import com.android.annotations.VisibleForTesting;
import com.android.tradefed.config.IConfiguration;
import com.android.tradefed.invoker.IInvocationContext;
import com.android.tradefed.invoker.IRescheduler;
import com.android.tradefed.log.LogUtil.CLog;
import com.android.tradefed.testtype.IBuildReceiver;
import com.android.tradefed.testtype.IDeviceTest;
import com.android.tradefed.testtype.IInvocationContextReceiver;
import com.android.tradefed.testtype.IMultiDeviceTest;
import com.android.tradefed.testtype.IRemoteTest;
import com.android.tradefed.testtype.IRuntimeHintProvider;
import com.android.tradefed.testtype.IShardableTest;
import com.android.tradefed.testtype.IStrictShardableTest;
import com.android.tradefed.testtype.suite.ITestSuite;
import com.android.tradefed.testtype.suite.ModuleMerger;
import com.android.tradefed.util.TimeUtil;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;

/** Sharding strategy to create strict shards that do not report together. */
public class StrictShardHelper extends ShardHelper {

    /**
     * Minimum gap (in ms) between the longest and shortest shard estimates before {@link
     * #topBottom(List, int)} moves any test around.
     */
    private static final long MIN_REBALANCE_DELTA_MS = 60 * 60 * 1000L;

    /** Fraction of the shards (the longest ones) that give tests away during rebalancing. */
    private static final double TOP_SHARD_RATIO = 0.3;

    /** Base fraction of a giving shard's tests that is moved; divided by the shard's rank. */
    private static final float GIVE_RATIO = 0.2f;

    /**
     * {@inheritDoc}
     *
     * <p>When no shard-index is set, the decision is delegated to the parent implementation.
     * Otherwise the tests of {@code config} are replaced in place by the subset selected for the
     * requested shard, and {@code false} is returned.
     *
     * @throws RuntimeException if a shard-index is set while shard-count is null.
     */
    @Override
    public boolean shardConfig(
            IConfiguration config, IInvocationContext context, IRescheduler rescheduler) {
        Integer shardCount = config.getCommandOptions().getShardCount();
        Integer shardIndex = config.getCommandOptions().getShardIndex();

        if (shardIndex == null) {
            return super.shardConfig(config, context, rescheduler);
        }
        if (shardCount == null) {
            throw new RuntimeException("shard-count is null while shard-index is " + shardIndex);
        }

        // Split tests in place, without actually sharding.
        if (!config.getCommandOptions().shouldUseTfSharding()) {
            // TODO: remove when IStrictShardableTest is removed.
            updateConfigIfSharded(config, shardCount, shardIndex);
        } else {
            List<IRemoteTest> listAllTests = getAllTests(config, shardCount, context);
            // We cannot shuffle to get better average results, so reorder deterministically.
            normalizeDistribution(listAllTests, shardCount);
            List<IRemoteTest> splitList;
            if (shardCount == 1) {
                // not sharded
                splitList = listAllTests;
            } else {
                splitList = splitTests(listAllTests, shardCount).get(shardIndex);
            }
            aggregateSuiteModules(splitList);
            config.setTests(splitList);
        }
        return false;
    }

    /**
     * Replace the tests of the config by the shard each {@link IStrictShardableTest} returns for
     * the given index. Tests that are not {@link IStrictShardableTest} are kept whole, and only in
     * shard 0.
     *
     * @param config the {@link IConfiguration} whose tests are replaced.
     * @param shardCount the total number of shards.
     * @param shardIndex the index of the shard being built.
     */
    // TODO: Retire IStrictShardableTest for IShardableTest and have TF balance the list of tests.
    private void updateConfigIfSharded(IConfiguration config, int shardCount, int shardIndex) {
        List<IRemoteTest> testShards = new ArrayList<>();
        for (IRemoteTest test : config.getTests()) {
            if (!(test instanceof IStrictShardableTest)) {
                CLog.w(
                        "%s is not shardable; the whole test will run in shard 0",
                        test.getClass().getName());
                if (shardIndex == 0) {
                    testShards.add(test);
                }
                continue;
            }
            IRemoteTest testShard =
                    ((IStrictShardableTest) test).getTestShard(shardCount, shardIndex);
            testShards.add(testShard);
        }
        config.setTests(testShards);
    }

    /**
     * Helper to return the full list of {@link IRemoteTest} based on {@link IShardableTest} split.
     *
     * @param config the {@link IConfiguration} describing the invocation.
     * @param shardCount the shard count hint to be provided to some tests.
     * @param context the {@link IInvocationContext} of the parent invocation.
     * @return the list of all {@link IRemoteTest}.
     */
    private List<IRemoteTest> getAllTests(
            IConfiguration config, Integer shardCount, IInvocationContext context) {
        List<IRemoteTest> allTests = new ArrayList<>();
        for (IRemoteTest test : config.getTests()) {
            if (test instanceof IShardableTest) {
                // Inject current information to help with sharding
                if (test instanceof IBuildReceiver) {
                    ((IBuildReceiver) test).setBuild(context.getBuildInfos().get(0));
                }
                if (test instanceof IDeviceTest) {
                    ((IDeviceTest) test).setDevice(context.getDevices().get(0));
                }
                if (test instanceof IMultiDeviceTest) {
                    ((IMultiDeviceTest) test).setDeviceInfos(context.getDeviceBuildMap());
                }
                if (test instanceof IInvocationContextReceiver) {
                    ((IInvocationContextReceiver) test).setInvocationContext(context);
                }

                // Handling of the ITestSuite is a special case, we do not allow pool of tests
                // since each shard needs to be independent.
                if (test instanceof ITestSuite) {
                    ((ITestSuite) test).setShouldMakeDynamicModule(false);
                }

                Collection<IRemoteTest> subTests = ((IShardableTest) test).split(shardCount);
                if (subTests == null) {
                    // test did not shard so we add it as is.
                    allTests.add(test);
                } else {
                    allTests.addAll(subTests);
                }
            } else {
                // if test is not shardable we add it as is.
                allTests.add(test);
            }
        }
        return allTests;
    }

    /**
     * Split the list of tests to run however the implementation see fit. Sharding needs to be
     * consistent. It is acceptable to return an empty list if no tests can be run in the shard.
     *
     * <p>Implement this in order to provide a test suite specific sharding. The default
     * implementation attempts to balance the number of IRemoteTest per shards as much as possible
     * as a first step, then use a minor criteria or run-hint to adjust the lists a bit more.
     *
     * @param fullList the initial full list of {@link IRemoteTest} containing all the tests that
     *     need to run.
     * @param shardCount the total number of shard that need to run.
     * @return a list of list {@link IRemoteTest}s that have been assigned to each shard. The list
     *     size will be the shardCount.
     */
    @VisibleForTesting
    protected List<List<IRemoteTest>> splitTests(List<IRemoteTest> fullList, int shardCount) {
        // We are using Math.ceil to avoid the last shard having too much extra.
        int numPerShard = (int) Math.ceil(fullList.size() / (float) shardCount);

        boolean needsCorrection = false;
        float correctionRatio = 0f;
        if (fullList.size() > shardCount) {
            // In some cases because of the Math.ceil, some combination might run out of tests
            // before the last shard, in that case we populate a correction to rebalance the tests.
            needsCorrection = (numPerShard * (shardCount - 1)) > fullList.size();
            correctionRatio = numPerShard - ((fullList.size() / (float) shardCount));
        }
        // Recalculate the number of tests per shard with the correction taken into account.
        numPerShard = (int) Math.floor(numPerShard - correctionRatio);
        // Based on the parameters, distribute the tests across shards.
        List<List<IRemoteTest>> shards =
                balancedDistrib(fullList, shardCount, numPerShard, needsCorrection);
        // Do last minute rebalancing
        topBottom(shards, shardCount);
        return shards;
    }

    /**
     * Cut the full list in {@code shardCount} consecutive slices of {@code numPerShard} tests, the
     * last shard receiving what remains. When a correction is requested, the tests beyond the last
     * full slice are held back and then handed out one per shard, starting from shard 0.
     *
     * @param fullList the list of all the tests to distribute.
     * @param shardCount the number of shards to create.
     * @param numPerShard the number of tests put in each shard before the last one.
     * @param needsCorrection whether the surplus of the last shard is spread over the shards.
     * @return one list of tests per shard; shards past the end of {@code fullList} are empty.
     */
    private List<List<IRemoteTest>> balancedDistrib(
            List<IRemoteTest> fullList, int shardCount, int numPerShard, boolean needsCorrection) {
        List<List<IRemoteTest>> shards = new ArrayList<>();
        List<IRemoteTest> correctionList = new ArrayList<>();
        int correctionSize = 0;

        // Generate all the shards
        for (int i = 0; i < shardCount; i++) {
            List<IRemoteTest> shardList;
            if (i >= fullList.size()) {
                // Return empty list when we don't have enough tests for all the shards.
                shardList = new ArrayList<IRemoteTest>();
                shards.add(shardList);
                continue;
            }

            if (i == shardCount - 1) {
                // last shard take everything remaining except the correction:
                if (needsCorrection) {
                    // We omit the size of the correction needed.
                    correctionSize = fullList.size() - (numPerShard + (i * numPerShard));
                    correctionList =
                            fullList.subList(fullList.size() - correctionSize, fullList.size());
                }
                shardList = fullList.subList(i * numPerShard, fullList.size() - correctionSize);
                // Copy the view so each shard owns an independent, mutable list.
                shards.add(new ArrayList<>(shardList));
                continue;
            }
            shardList = fullList.subList(i * numPerShard, numPerShard + (i * numPerShard));
            shards.add(new ArrayList<>(shardList));
        }

        // If we have correction omitted tests, disperse them on each shard, at this point the
        // number of tests in correction is ensured to be below the number of shards.
        for (int i = 0; i < shardCount; i++) {
            if (i < correctionList.size()) {
                shards.get(i).add(correctionList.get(i));
            } else {
                break;
            }
        }
        return shards;
    }

    /**
     * Move around predictably the tests in order to have a better uniformization of the tests in
     * each shard.
     *
     * <p>For {@code shardCount} rounds, every test found while stepping through the list by
     * {@code shardCount - 1} is moved to the end of the list. The list is modified in place.
     *
     * @param listAllTests the list of tests to reorder in place.
     * @param shardCount the total number of shards.
     */
    private void normalizeDistribution(List<IRemoteTest> listAllTests, int shardCount) {
        if (shardCount < 2) {
            // With a single shard the step would be 0 and the inner loop would never advance;
            // there is nothing to distribute anyway.
            return;
        }
        final int numRound = shardCount;
        final int distance = shardCount - 1;
        for (int i = 0; i < numRound; i++) {
            for (int j = 0; j < listAllTests.size(); j = j + distance) {
                // Push the test at the end
                IRemoteTest push = listAllTests.remove(j);
                listAllTests.add(push);
            }
        }
    }

    /**
     * Special handling for suite from {@link ITestSuite}. We aggregate the tests in the same shard
     * in order to optimize target_preparation step.
     *
     * @param tests the {@link List} of {@link IRemoteTest} for that shard. Merged entries are
     *     removed from this list.
     */
    private void aggregateSuiteModules(List<IRemoteTest> tests) {
        // Iterate over a snapshot since entries are removed from 'tests' while merging.
        List<IRemoteTest> dupList = new ArrayList<>(tests);
        for (int i = 0; i < dupList.size(); i++) {
            if (dupList.get(i) instanceof ITestSuite) {
                // We iterate the other tests to see if we can find another from the same module.
                for (int j = i + 1; j < dupList.size(); j++) {
                    // If the test was not already merged
                    if (tests.contains(dupList.get(j))) {
                        if (dupList.get(j) instanceof ITestSuite) {
                            if (ModuleMerger.arePartOfSameSuite(
                                    (ITestSuite) dupList.get(i), (ITestSuite) dupList.get(j))) {
                                ModuleMerger.mergeSplittedITestSuite(
                                        (ITestSuite) dupList.get(i), (ITestSuite) dupList.get(j));
                                tests.remove(dupList.get(j));
                            }
                        }
                    }
                }
            }
        }
    }

    /**
     * Last minute rebalancing based on the runtime hints: the shards with the largest aggregated
     * hint give some of their first tests to the shards with the smallest one.
     *
     * <p>Nothing is done when there are fewer than 4 shards, or when the gap between the largest
     * and the smallest aggregated hint is under {@link #MIN_REBALANCE_DELTA_MS}.
     *
     * @param allShards the list of tests of each shard, modified in place.
     * @param shardCount the total number of shards.
     */
    private void topBottom(List<List<IRemoteTest>> allShards, int shardCount) {
        // We only attempt this when the number of shard is pretty high
        if (shardCount < 4) {
            return;
        }
        // Generate approximate RuntimeHint for each shard
        int index = 0;
        List<SortShardObj> shardTimes = new ArrayList<>();
        CLog.d("============================");
        for (List<IRemoteTest> shard : allShards) {
            long aggTime = 0L;
            CLog.d("++++++++++++++++++ SHARD %s +++++++++++++++", index);
            for (IRemoteTest test : shard) {
                // Tests without a hint contribute nothing to the estimate.
                if (test instanceof IRuntimeHintProvider) {
                    aggTime += ((IRuntimeHintProvider) test).getRuntimeHint();
                }
            }
            CLog.d("Shard %s approximate time: %s", index, TimeUtil.formatElapsedTime(aggTime));
            shardTimes.add(new SortShardObj(index, aggTime));
            index++;
            CLog.d("+++++++++++++++++++++++++++++++++++++++++++");
        }
        CLog.d("============================");

        // Sorted from the largest aggregated time to the smallest.
        Collections.sort(shardTimes);
        if ((shardTimes.get(0).mAggTime - shardTimes.get(shardTimes.size() - 1).mAggTime)
                < MIN_REBALANCE_DELTA_MS) {
            return;
        }

        // take 30% top shard (10 shard = top 3 shards)
        for (int i = 0; i < (shardCount * TOP_SHARD_RATIO); i++) {
            CLog.d(
                    "Top shard %s is index %s with %s",
                    i,
                    shardTimes.get(i).mIndex,
                    TimeUtil.formatElapsedTime(shardTimes.get(i).mAggTime));
            int give = shardTimes.get(i).mIndex;
            // The i-th largest shard is paired with the i-th smallest one.
            int receive = shardTimes.get(shardTimes.size() - 1 - i).mIndex;
            CLog.d("Giving from shard %s to shard %s", give, receive);
            // The bound is re-evaluated on every iteration against the shrinking giving shard.
            for (int j = 0; j < (allShards.get(give).size() * (GIVE_RATIO / (i + 1))); j++) {
                IRemoteTest givetest = allShards.get(give).remove(0);
                allShards.get(receive).add(givetest);
            }
        }
    }

    /**
     * Object holder for shard, their index and their aggregated execution time. Natural ordering
     * is by decreasing aggregated time.
     */
    private static class SortShardObj implements Comparable<SortShardObj> {
        public final int mIndex;
        public final long mAggTime;

        public SortShardObj(int index, long aggTime) {
            mIndex = index;
            mAggTime = aggTime;
        }

        @Override
        public int compareTo(SortShardObj obj) {
            // Reversed operands: larger aggregated time sorts first.
            return Long.compare(obj.mAggTime, mAggTime);
        }
    }
}