1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/instruction_fusion.h" 17 18 #include <algorithm> 19 #include <list> 20 #include <memory> 21 #include <numeric> 22 #include <vector> 23 24 #include "tensorflow/compiler/xla/map_util.h" 25 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 26 #include "tensorflow/core/lib/core/errors.h" 27 #include "tensorflow/core/lib/gtl/flatmap.h" 28 #include "tensorflow/core/platform/logging.h" 29 30 namespace xla { 31 /*static*/ bool InstructionFusion::IsExpensive( 32 const HloInstruction& instruction) { 33 switch (instruction.opcode()) { 34 // Cheap instructions. 35 case HloOpcode::kAdd: 36 case HloOpcode::kAnd: 37 case HloOpcode::kBitcast: 38 case HloOpcode::kBitcastConvert: 39 case HloOpcode::kBroadcast: 40 case HloOpcode::kCeil: 41 case HloOpcode::kClamp: 42 case HloOpcode::kComplex: 43 case HloOpcode::kConcatenate: 44 case HloOpcode::kConstant: 45 case HloOpcode::kConvert: 46 case HloOpcode::kCopy: 47 case HloOpcode::kDynamicSlice: 48 case HloOpcode::kDynamicUpdateSlice: 49 case HloOpcode::kEq: 50 case HloOpcode::kFloor: 51 case HloOpcode::kGe: 52 case HloOpcode::kGetTupleElement: 53 case HloOpcode::kGt: 54 case HloOpcode::kImag: 55 case HloOpcode::kInfeed: 56 case HloOpcode::kIsFinite: 57 case HloOpcode::kLe: 58 case HloOpcode::kLt: 59 case HloOpcode::kMaximum: 60 case HloOpcode::kMinimum: 61 case HloOpcode::kMultiply: 62 case HloOpcode::kNe: 63 case HloOpcode::kNegate: 64 case HloOpcode::kNot: 65 case HloOpcode::kOr: 66 case HloOpcode::kOutfeed: 67 case HloOpcode::kPad: 68 case HloOpcode::kReal: 69 case HloOpcode::kReducePrecision: 70 case HloOpcode::kReshape: 71 case HloOpcode::kReverse: 72 case HloOpcode::kRoundNearestAfz: 73 case HloOpcode::kSelect: 74 case HloOpcode::kShiftLeft: 75 case HloOpcode::kShiftRightArithmetic: 76 case HloOpcode::kShiftRightLogical: 77 case HloOpcode::kSlice: 78 case HloOpcode::kSubtract: 79 case HloOpcode::kTranspose: 80 case HloOpcode::kTuple: 81 return false; 82 83 // Cheap instructions for reals, but expensive for complex. 84 case HloOpcode::kAbs: 85 case HloOpcode::kCos: 86 case HloOpcode::kSign: 87 case HloOpcode::kSin: 88 return ShapeUtil::ElementIsComplex(instruction.shape()); 89 90 // Expensive instructions. 91 case HloOpcode::kAtan2: 92 case HloOpcode::kBatchNormGrad: 93 case HloOpcode::kBatchNormInference: 94 case HloOpcode::kBatchNormTraining: 95 case HloOpcode::kCall: 96 case HloOpcode::kConditional: 97 case HloOpcode::kConvolution: 98 case HloOpcode::kCrossReplicaSum: 99 case HloOpcode::kCustomCall: 100 case HloOpcode::kDivide: 101 case HloOpcode::kDot: 102 case HloOpcode::kExp: 103 case HloOpcode::kFft: 104 case HloOpcode::kFusion: 105 case HloOpcode::kGather: 106 case HloOpcode::kHostCompute: 107 case HloOpcode::kLog: 108 case HloOpcode::kMap: 109 case HloOpcode::kParameter: 110 case HloOpcode::kPower: 111 case HloOpcode::kRecv: 112 case HloOpcode::kRecvDone: 113 case HloOpcode::kReduce: 114 case HloOpcode::kReduceWindow: 115 case HloOpcode::kRemainder: 116 case HloOpcode::kRng: 117 case HloOpcode::kSelectAndScatter: 118 case HloOpcode::kSend: 119 case HloOpcode::kSendDone: 120 case HloOpcode::kSort: 121 case HloOpcode::kTanh: 122 case HloOpcode::kTrace: 123 case HloOpcode::kWhile: 124 return true; 125 } 126 127 return false; 128 } 129 130 // An "effectively unary" operation is one that has one "large" 131 // input with the others being negligible in terms of memory usage. 132 // We use "has a smaller true rank than the output" as a heuristic 133 // for "negligible" memory usage. 134 bool InstructionFusion::EffectivelyUnary(HloInstruction* hlo) { 135 int64 output_rank = 0; 136 ShapeUtil::ForEachSubshape( 137 hlo->shape(), 138 [&output_rank](const Shape& subshape, const ShapeIndex& shape_index) { 139 if (ShapeUtil::IsArray(subshape)) { 140 output_rank = std::max(output_rank, ShapeUtil::TrueRank(subshape)); 141 } 142 }); 143 return std::count_if(hlo->operands().begin(), hlo->operands().end(), 144 [output_rank](HloInstruction* operand) { 145 if (operand->opcode() == HloOpcode::kBroadcast) { 146 return false; 147 } 148 if (operand->opcode() == HloOpcode::kConstant && 149 ShapeUtil::IsEffectiveScalar(operand->shape())) { 150 return false; 151 } 152 return ShapeUtil::TrueRank(operand->shape()) >= 153 output_rank; 154 }) <= 1; 155 } 156 157 bool InstructionFusion::CanFuseOnAllPaths( 158 const HloReachabilityMap& reachability_map, HloInstruction* producer, 159 HloInstruction* consumer, DoNotFuseSet* do_not_fuse) { 160 auto could_fuse_on_all_paths = [&] { 161 // First check to see if we have already marked this producer as infeasible 162 // to fuse into consumer. 163 if (do_not_fuse->count(producer) > 0) { 164 return false; 165 } 166 // Make sure it is possible for producer and consumer to exist in a fusion 167 // node. 168 if (!producer->IsFusable() || !consumer->IsFusable()) { 169 return false; 170 } 171 // We do an upward walk of the graph from consumer towards all paths which 172 // lead to producer to find any unfusable paths. 173 for (int64 i = 0, e = consumer->operand_count(); i < e; ++i) { 174 auto* consumer_operand = consumer->mutable_operand(i); 175 if (consumer_operand == producer) { 176 // This is the base case: our upward crawl ends but we need to make sure 177 // that fusion from consumer can happen. 178 if (!ShouldFuse(consumer, i)) { 179 return false; 180 } 181 } else if (reachability_map.IsReachable(producer, consumer_operand)) { 182 // The reachability map told us that consumer_operand is a node on the 183 // path to producer. We need to further investigate from 184 // consumer_operand. 185 186 // First check if we have already ruled out fusing producer into 187 // consumer_operand. 188 if (do_not_fuse->count(consumer_operand) > 0) { 189 return false; 190 } 191 // Make sure it is possible for consumer_operand to exist in a fusion 192 // node. 193 if (!consumer_operand->IsFusable()) { 194 return false; 195 } 196 // The producer is reachable from consumer_operand which means we need 197 // to be able to fuse consumer_operand into consumer in order for 198 // producer to be fusable into consumer on all paths. 199 if (!ShouldFuse(consumer, i)) { 200 return false; 201 } 202 // Perform the recursive step: make sure producer can be fused into 203 // consumer_operand on all paths. 204 if (!CanFuseOnAllPaths(reachability_map, producer, consumer_operand, 205 do_not_fuse)) { 206 return false; 207 } 208 } 209 } 210 return true; 211 }; 212 if (could_fuse_on_all_paths()) { 213 return true; 214 } 215 // We couldn't fuse on all paths, record this result. 216 do_not_fuse->insert(producer); 217 return false; 218 } 219 220 StatusOr<bool> InstructionFusion::Run(HloModule* module) { 221 VLOG(2) << "Before instruction fusion:"; 222 XLA_VLOG_LINES(2, module->ToString()); 223 224 bool changed = false; 225 module_ = module; 226 for (auto* computation : module->MakeNonfusionComputations()) { 227 CHECK(!computation->IsFusionComputation()); 228 computation_ = computation; 229 230 // We want to be able to remove arbitrary instructions from the post order 231 // and also compare positions of instructions in the post order. To make 232 // this possible, create vector of instructions in post order and create a 233 // map from HloInstruction* to the instruction's index in the vector. An 234 // instruction is "removed" from the vector by setting it's element to 235 // nullptr. 236 std::list<HloInstruction*> post_order_list = 237 computation_->MakeInstructionPostOrder(); 238 std::vector<HloInstruction*> post_order(post_order_list.begin(), 239 post_order_list.end()); 240 241 tensorflow::gtl::FlatMap<HloInstruction*, int> post_order_index; 242 for (size_t i = 0; i < post_order.size(); ++i) { 243 InsertOrDie(&post_order_index, post_order[i], i); 244 } 245 246 DoNotFuseSet do_not_fuse; 247 auto reachability = computation->ComputeReachability(); 248 249 auto cheap_to_duplicate = [this](HloInstruction* producer) { 250 if (producer->opcode() == HloOpcode::kBroadcast) { 251 return true; 252 } 253 if (producer->opcode() == HloOpcode::kConstant && 254 ShapeUtil::IsEffectiveScalar(producer->shape())) { 255 return true; 256 } 257 if (EffectivelyUnary(producer)) { 258 return true; 259 } 260 return false; 261 }; 262 263 for (HloInstruction* consumer : post_order) { 264 for (HloInstruction* producer : consumer->operands()) { 265 if (cheap_to_duplicate(producer)) { 266 continue; 267 } 268 if (CanFuseOnAllPaths(*reachability, producer, consumer, 269 &do_not_fuse)) { 270 CHECK_EQ(do_not_fuse.count(producer), 0); 271 } else { 272 CHECK_GT(do_not_fuse.count(producer), 0); 273 } 274 } 275 } 276 277 // Instruction fusion effectively fuses edges in the computation graph 278 // (producer instruction -> consumer instruction) so we iterate over all 279 // edges. When we fuse an edge, we create a copy of the producer inside the 280 // fusion instruction. 281 while (!post_order.empty()) { 282 // We want to iterate in reverse post order, so remove from the back of 283 // the vector. 284 HloInstruction* instruction = post_order.back(); 285 post_order.pop_back(); 286 287 // Instructions are "removed" from the post order by nulling out the 288 // element in the vector, so if the pointer is null, continue to the next 289 // instruction in the sort. 290 if (instruction == nullptr) { 291 continue; 292 } 293 294 // Remove instruction from the index map to ensure the vector and map stay 295 // consistent. 296 post_order_index.erase(instruction); 297 298 if (!instruction->IsFusable() && 299 instruction->opcode() != HloOpcode::kFusion) { 300 continue; 301 } 302 303 // Consider each operand of this instruction for fusion into this 304 // instruction. We want to consider the operands in a particular order to 305 // avoid created duplicate instruction clones in the fusion instruction. 306 // For example, consider the following expression: 307 // 308 // A = ... 309 // B = op(A) 310 // C = op(A, B) 311 // 312 // If we are considering the operands of C for fusion into C. We might 313 // fuse A or B first. If we fuse A first, we get: 314 // 315 // A = ... 316 // B = op(A) 317 // C_fusion = { A' = ... 318 // C' = op(A', B) } 319 // 320 // Where A' and C' are clones of A and C, respectively. Now only B is an 321 // operand of the fusion instruction C_fusion, so then we fuse B: 322 // 323 // A = ... 324 // B = op(A) 325 // C_fusion = { A' = ... 326 // B' = op(A) 327 // C' = op(A', B') } 328 // 329 // Now A is an operand of C_fusion again, so we then fuse A (again!): 330 // 331 // A = ... 332 // B = op(A) 333 // C_fusion = { A' = ... 334 // A" = .. 335 // B' = op(A") 336 // C' = op(A', B') } 337 // 338 // We prevent this duplication by considering the operands in the reverse 339 // order they appear in the instruction post order. In the example, this 340 // ensures that B will be considered before A. 341 // 342 // We store the original indices of the operands to pass to ShouldFuse. 343 std::vector<int64> sorted_operand_numbers(instruction->operands().size()); 344 std::iota(std::begin(sorted_operand_numbers), 345 std::end(sorted_operand_numbers), 0); 346 std::sort( 347 sorted_operand_numbers.begin(), sorted_operand_numbers.end(), 348 [&](int64 i, int64 j) { 349 // Instructions with higher indices in the post order come 350 // first. 351 return ( 352 FindOrDie(post_order_index, instruction->mutable_operand(i)) > 353 FindOrDie(post_order_index, instruction->mutable_operand(j))); 354 }); 355 356 for (int64 i : sorted_operand_numbers) { 357 HloInstruction* operand = instruction->mutable_operand(i); 358 359 if (!operand->IsFusable()) { 360 continue; 361 } 362 if (!ShouldFuse(instruction, i)) { 363 continue; 364 } 365 if (do_not_fuse.count(operand) > 0) { 366 continue; 367 } 368 HloInstruction* fusion_instruction = Fuse(operand, instruction); 369 370 // Fusing an instruction into a fusion instruction can change the 371 // operand set of the fusion instruction. For simplicity just push the 372 // instruction to the top of the post_order and reconsider it for 373 // further fusion in the next iteration of the outer loop. 374 post_order.push_back(fusion_instruction); 375 InsertOrDie(&post_order_index, fusion_instruction, 376 post_order.size() - 1); 377 changed = true; 378 379 if (operand->user_count() == 0) { 380 // Operand is now dead. Remove from post order by setting it's 381 // location to nullptr. 382 post_order[FindOrDie(post_order_index, operand)] = nullptr; 383 post_order_index.erase(operand); 384 385 // Remove from computation. 386 TF_RETURN_IF_ERROR(computation_->RemoveInstruction(operand)); 387 } 388 break; 389 } 390 } 391 } 392 393 VLOG(2) << "After instruction fusion:"; 394 XLA_VLOG_LINES(2, module->ToString()); 395 396 return changed; 397 } 398 399 HloInstruction* InstructionFusion::Fuse(HloInstruction* producer, 400 HloInstruction* consumer) { 401 HloInstruction* fusion_instruction; 402 403 VLOG(2) << "Fusing " << producer->ToString() << " into " 404 << consumer->ToString(); 405 auto kind = ChooseKind(producer, consumer); 406 if (consumer->opcode() == HloOpcode::kFusion) { 407 fusion_instruction = consumer; 408 if (kind != fusion_instruction->fusion_kind()) { 409 fusion_instruction->set_fusion_kind(kind); 410 } 411 } else { 412 fusion_instruction = computation_->AddInstruction( 413 HloInstruction::CreateFusion(consumer->shape(), kind, consumer)); 414 TF_CHECK_OK(computation_->ReplaceInstruction(consumer, fusion_instruction)); 415 } 416 417 fusion_instruction->FuseInstruction(producer); 418 return fusion_instruction; 419 } 420 421 bool InstructionFusion::ShouldFuse(HloInstruction* consumer, 422 int64 operand_index) { 423 HloInstruction* producer = consumer->mutable_operand(operand_index); 424 // Cost condition: don't duplicate expensive instructions. 425 if (FusionWouldDuplicate(*producer, *consumer) && 426 (is_expensive_(*producer) || !may_duplicate_)) { 427 return false; 428 } 429 430 if (consumer->opcode() == HloOpcode::kFusion && 431 consumer->fusion_kind() != HloInstruction::FusionKind::kLoop && 432 consumer->fusion_kind() != HloInstruction::FusionKind::kInput && 433 consumer->fusion_kind() != HloInstruction::FusionKind::kOutput) { 434 return false; 435 } 436 437 if (producer->CouldBeBitcast() && 438 // We can't fuse parameters anyhow, so we leave the user unfused to become 439 // a bitcast. If the operand is not a parameter, we would break a 440 // potential fusion to make it a bitcast, which is not so clear a win. 441 producer->operand(0)->opcode() == HloOpcode::kParameter) { 442 return false; 443 } 444 445 return true; 446 } 447 448 HloInstruction::FusionKind InstructionFusion::ChooseKind( 449 const HloInstruction* producer, const HloInstruction* consumer) { 450 return HloInstruction::FusionKind::kLoop; 451 } 452 453 } // namespace xla 454