1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h" 17 #include "tensorflow/compiler/xla/service/tuple_util.h" 18 #include "tensorflow/compiler/xla/service/while_util.h" 19 #include "tensorflow/compiler/xla/util.h" 20 #include "tensorflow/core/lib/gtl/flatmap.h" 21 #include "tensorflow/core/lib/gtl/flatset.h" 22 #include "tensorflow/core/lib/gtl/inlined_vector.h" 23 24 namespace xla { 25 26 using tensorflow::gtl::FlatMap; 27 using tensorflow::gtl::FlatSet; 28 using tensorflow::gtl::InlinedVector; 29 30 // Copies `to_hoist` to the computation containing `while_instr`, hoisting its 31 // operands as needed. All of its transitive operands are expected to be either 32 // in `hoisted_instructions` or `unhoisted_invariant_instructions`. This 33 // function hoists the operands in `unhoisted_invariant_instructions` and moves 34 // them into `hoisted_instructions`. 35 static void CreateLoopInvariantCopy( 36 FlatMap<HloInstruction*, HloInstruction*>* hoisted_instructions, 37 FlatSet<HloInstruction*>* unhoisted_invariant_instructions, 38 HloInstruction* while_instr, HloInstruction* to_hoist) { 39 HloComputation* parent_of_while = while_instr->parent(); 40 HloComputation* while_body = while_instr->while_body(); 41 42 struct DFSFrame { 43 HloInstruction* instruction; 44 int64 operand_index; 45 }; 46 47 InlinedVector<DFSFrame, 8> dfs_stack; 48 dfs_stack.push_back({to_hoist, 0}); 49 50 HloInstruction* while_body_param = while_body->parameter_instruction(0); 51 HloInstruction* while_operand = while_instr->mutable_operand(0); 52 53 do { 54 DFSFrame* frame = &dfs_stack.back(); 55 if (frame->operand_index == frame->instruction->operand_count()) { 56 HloInstruction* old_instruction = frame->instruction; 57 58 // All of the operands for old_instruction have been cloned, so it is 59 // time to clone old_instruction itself. 60 61 auto get_new_operand = [&](HloInstruction* old_operand) { 62 return old_operand == while_body_param 63 ? while_operand 64 : FindOrDie(*hoisted_instructions, old_operand); 65 }; 66 67 InlinedVector<HloInstruction*, 4> new_operands; 68 c_transform(old_instruction->operands(), std::back_inserter(new_operands), 69 get_new_operand); 70 71 HloInstruction* new_instruction = 72 parent_of_while->AddInstruction(old_instruction->CloneWithNewOperands( 73 old_instruction->shape(), new_operands)); 74 75 InsertOrDie(hoisted_instructions, old_instruction, new_instruction); 76 77 // Approximately half of the instructions that would normally be present 78 // in unhoisted_invariant_instructions are constants. We save a bit of 79 // compile time by not putting these in the hashtable. 80 CHECK_EQ(unhoisted_invariant_instructions->erase(old_instruction), 81 to_hoist != old_instruction && 82 old_instruction->opcode() != HloOpcode::kConstant); 83 dfs_stack.pop_back(); 84 continue; 85 } 86 87 HloInstruction* next_operand = 88 frame->instruction->mutable_operand(frame->operand_index++); 89 if (hoisted_instructions->count(next_operand) || 90 next_operand == while_body_param) { 91 continue; 92 } 93 94 dfs_stack.push_back({next_operand, 0}); 95 } while (!dfs_stack.empty()); 96 } 97 98 // Returns true if `instruction` is worth hoisting only if it lets us hoist some 99 // instruction using it. The rationale is that hoisting these instructions will 100 // prevent simplification and fusion in the while body. 101 static bool NotWorthHoistingIndividually(const HloInstruction& instruction) { 102 switch (instruction.opcode()) { 103 default: 104 return false; 105 106 case HloOpcode::kBitcast: 107 case HloOpcode::kBroadcast: 108 case HloOpcode::kConstant: 109 case HloOpcode::kReverse: 110 case HloOpcode::kSlice: 111 case HloOpcode::kTuple: 112 return true; 113 114 case HloOpcode::kTranspose: 115 return ShapeUtil::TransposeIsBitcast( 116 /*input_shape=*/instruction.operand(0)->shape(), 117 /*output_shape=*/instruction.shape(), instruction.dimensions()); 118 119 case HloOpcode::kReshape: 120 return ShapeUtil::ReshapeIsBitcast( 121 /*input_shape=*/instruction.operand(0)->shape(), 122 /*output_shape=*/instruction.shape()); 123 } 124 } 125 126 // Populates `gte_set` with the GetTupleElement instructions in `while_body` 127 // that access elements in the parameter tuple that don't change across 128 // iterations. Assumes `while_body` is the body computation of the while loop 129 // in question. 130 static void GatherInvariantGTEs(HloComputation* while_body, 131 FlatSet<HloInstruction*>* gte_set) { 132 const HloInstruction::InstructionVector root_operands = 133 while_body->root_instruction()->operands(); 134 for (int i = 0; i < root_operands.size(); i++) { 135 HloInstruction* instr = root_operands[i]; 136 if (instr->opcode() == HloOpcode::kGetTupleElement && 137 instr->tuple_index() == i && 138 instr->operand(0) == while_body->parameter_instruction(0) && 139 ShapeUtil::IsArray(instr->shape())) { 140 InsertOrDie(gte_set, instr); 141 } 142 } 143 } 144 145 static StatusOr<bool> TryHoistingInvariantInstructionsFromWhileBody( 146 HloInstruction* while_instr) { 147 auto print_no_metadata = HloPrintOptions{}.set_print_metadata(false); 148 149 if (!ShapeUtil::IsTuple(while_instr->shape())) { 150 // This restriction leaves one interesting pattern on the table: 151 // 152 // while_body(f32[1024, 1024] %param) { 153 // %value = expensive_op(%param) 154 // outfeed(%value) 155 // ROOT = %param 156 // } 157 // 158 // If we see that pattern in the while, instead of generalizing this 159 // algorithm to work with non-tuples, we should instead add a pass that 160 // canonicalizes while loops like the above to use a tuple state. 161 return false; 162 } 163 164 string while_instr_name = while_instr->ToString(print_no_metadata); 165 VLOG(2) << "Trying to hoist from " << while_instr_name; 166 167 HloComputation* while_body = while_instr->while_body(); 168 169 // Maps instructions in the while body to instructions hoisted outside the 170 // while that compute the same value. 171 FlatMap<HloInstruction*, HloInstruction*> hoisted_instructions; 172 173 // Contains instructions that can be legally hoisted, but were deemed to be 174 // unprofitable to be hoisted alone by NotWorthHoistingIndividually. When we 175 // hoist an instruction in this set, we move it from 176 // unhoisted_invariant_instructions to hoisted_instructions. 177 FlatSet<HloInstruction*> unhoisted_invariant_instructions; 178 179 // Invariant GTE's axiomatically satisfy the constraints for 180 // unhoisted_invariant_instructions -- they can be legally hoisted, but there 181 // is no benefit to hoisting them unless something that uses it is also 182 // hoisted. 183 GatherInvariantGTEs(while_body, &unhoisted_invariant_instructions); 184 185 if (unhoisted_invariant_instructions.empty()) { 186 // There are no obviously loop invariant elements in the state being 187 // threaded through the while loop so give up. In theory this precondition 188 // is too strong -- we could have code that e.g. permutes the elements in 189 // the while state but uses a select to pick the same value on every 190 // iteration. 191 return false; 192 } 193 194 // instructions_to_replace[i] is hoisted into a loop invariant instruction 195 // replacement_instructions[i]. 196 std::vector<HloInstruction*> instructions_to_replace; 197 std::vector<HloInstruction*> replacement_instructions; 198 199 for (auto* instruction : while_body->MakeInstructionPostOrder()) { 200 if (instruction->HasSideEffect() || 201 instruction->opcode() == HloOpcode::kParameter || 202 !instruction->control_predecessors().empty() || 203 !instruction->control_successors().empty()) { 204 continue; 205 } 206 207 auto is_invariant = [&](HloInstruction* op) { 208 return hoisted_instructions.find(op) != hoisted_instructions.end() || 209 unhoisted_invariant_instructions.count(op) || 210 op->opcode() == HloOpcode::kConstant; 211 }; 212 213 if (!c_all_of(instruction->operands(), is_invariant)) { 214 continue; 215 } 216 217 if (NotWorthHoistingIndividually(*instruction)) { 218 VLOG(2) << "Adding " << instruction->ToString(print_no_metadata) 219 << " to unhoisted invariant set."; 220 // Approximately half of the instructions that reach this point are 221 // constants. We save a bit of compile time by not putting these in the 222 // hashtable. 223 if (instruction->opcode() != HloOpcode::kConstant) { 224 InsertOrDie(&unhoisted_invariant_instructions, instruction); 225 } 226 continue; 227 } 228 229 VLOG(2) << "Hoisting " << instruction->ToString(print_no_metadata); 230 231 CreateLoopInvariantCopy(&hoisted_instructions, 232 &unhoisted_invariant_instructions, while_instr, 233 instruction); 234 235 instructions_to_replace.push_back(instruction); 236 replacement_instructions.push_back( 237 FindOrDie(hoisted_instructions, instruction)); 238 } 239 240 if (instructions_to_replace.empty()) { 241 return false; 242 } 243 244 TF_ASSIGN_OR_RETURN( 245 WhileUtil::MakeInstructionsLiveInResult live_in_instructions_result, 246 WhileUtil::MakeInstructionsLiveIn(while_instr, replacement_instructions)); 247 248 HloComputation* new_while_body = 249 live_in_instructions_result.new_while_instr->while_body(); 250 251 for (int i = 0; i < instructions_to_replace.size(); i++) { 252 HloInstruction* instruction_to_replace_in_new_while = 253 FindOrDie(live_in_instructions_result.while_body_instruction_map, 254 instructions_to_replace[i]); 255 TF_RETURN_IF_ERROR(new_while_body->ReplaceInstruction( 256 instruction_to_replace_in_new_while, 257 live_in_instructions_result.while_body_live_in_values[i])); 258 } 259 260 VLOG(1) << "Hoisted " << instructions_to_replace.size() 261 << " instructions from " << while_instr_name; 262 263 return true; 264 } 265 266 StatusOr<bool> WhileLoopInvariantCodeMotion::Run(HloModule* module) { 267 bool changed = false; 268 std::vector<HloInstruction*> while_instrs; 269 for (auto* comp : module->computations()) { 270 c_copy_if(comp->instructions(), std::back_inserter(while_instrs), 271 [](const HloInstruction* instr) { 272 return instr->opcode() == HloOpcode::kWhile; 273 }); 274 } 275 276 for (HloInstruction* while_instr : while_instrs) { 277 // Right now we only hoist computations from the while body, but 278 // TryHoistingInvariantInstructionsFromWhileBody can be generalized to 279 // optimize the condition computation too, if needed. 280 // 281 // The transform we do here is a pessmization for while loops that execute 282 // zero times*, but at this time we expect those to be rare. If this 283 // becomes a problem we can consider using the conditional HLO to avoid 284 // doing extra work for while loops with zero trip count. 285 // 286 // * We delete while loops that have a zero trip count, so this would have 287 // to be a while loop with a somewhat opaque condition expression. 288 289 TF_ASSIGN_OR_RETURN( 290 bool result, 291 TryHoistingInvariantInstructionsFromWhileBody(while_instr)); 292 changed |= result; 293 } 294 return changed; 295 } 296 } // namespace xla 297