1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/call_inliner.h" 17 18 #include <deque> 19 20 #include "tensorflow/compiler/xla/service/call_graph.h" 21 #include "tensorflow/core/lib/core/errors.h" 22 23 namespace xla { 24 namespace { 25 26 // Traverses the callee computation, inlining cloned nodes into the caller 27 // computation and connecting them to producers/consumers appropriately. 28 // When the traversal has completed, the provided call instruction is entriely 29 // replaced in the caller's graph. 30 class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault { 31 public: 32 // call is the call operation -- it will be replaced with the body of the 33 // called computation. 34 explicit SubcomputationInsertionVisitor(HloInstruction* call) 35 : call_(call), outer_(call->parent()) { 36 CHECK_EQ(HloOpcode::kCall, call_->opcode()); 37 } 38 39 // Resolves the operands to the HLO instruction in the inlined (caller) graph, 40 // and clones the HLO instruction into that graph with the new operands. 41 // If the instruction is a call, it is added to the work queue. 42 Status DefaultAction(HloInstruction* hlo) override { 43 TF_RET_CHECK(hlo->opcode() != HloOpcode::kCall); 44 std::vector<HloInstruction*> new_operands; 45 for (HloInstruction* operand : hlo->operands()) { 46 TF_ASSIGN_OR_RETURN(HloInstruction * new_operand, Resolve(operand)); 47 new_operands.push_back(new_operand); 48 } 49 VLOG(1) << "Cloning HLO and adding to caller: " << hlo->ToString(); 50 auto new_hlo = hlo->CloneWithNewOperands(hlo->shape(), new_operands); 51 HloInstruction* new_hlo_pointer = 52 outer_->AddInstruction(std::move(new_hlo)); 53 TF_RETURN_IF_ERROR(NoteMapping(hlo, new_hlo_pointer)); 54 55 // Account for control edges. 56 for (HloInstruction* control_predecessor : hlo->control_predecessors()) { 57 TF_ASSIGN_OR_RETURN(HloInstruction * new_control_predecessor, 58 Resolve(control_predecessor)); 59 TF_RETURN_IF_ERROR( 60 new_control_predecessor->AddControlDependencyTo(new_hlo_pointer)); 61 } 62 63 return Status::OK(); 64 } 65 66 // Does not create new nodes for the parameter; rather, notes the mapping from 67 // the subcomputation parameter node to the call operands in the caller 68 // computation. 69 Status HandleParameter(HloInstruction* parameter) override { 70 TF_RETURN_IF_ERROR(NoteMapping( 71 parameter, call_->mutable_operand(parameter->parameter_number()))); 72 return Status::OK(); 73 } 74 75 // Wires the consumers of the call to instead point at the newly created root, 76 // replacing the call operation in the caller computation. 77 Status FinishVisit(HloInstruction* root) override { 78 TF_ASSIGN_OR_RETURN(HloInstruction * new_root, Resolve(root)); 79 VLOG(1) << "Replacing all uses of " << call_->ToString() 80 << " with new root " << new_root->ToString(); 81 call_->ClearCalledComputations(); 82 return outer_->ReplaceInstruction(call_, new_root); 83 } 84 85 CallInliner::InlinedInstructionMap ConsumeInstructionMap() { 86 return std::move(subcomputation_hlo_to_new_hlo_); 87 } 88 89 private: 90 // Resolves the callee subcomputation_hlo to the new (inline) HLO in the 91 // caller computation, or returns a NotFound error if that subcomputation HLO 92 // has not been mapped. 93 StatusOr<HloInstruction*> Resolve(HloInstruction* subcomputation_hlo) { 94 auto it = subcomputation_hlo_to_new_hlo_.find(subcomputation_hlo); 95 if (it == subcomputation_hlo_to_new_hlo_.end()) { 96 return NotFound( 97 "Could not find mapping from subcomputation HLO %s to a cloned HLO.", 98 subcomputation_hlo->ToString().c_str()); 99 } 100 return it->second; 101 } 102 103 // Notes that the given subcomputation_hlo in the callee has been mapped to 104 // the (inline) new_hlo in the caller computation. 105 // 106 // Returns an error status if the subcomputation_hlo is mapped more than 107 // once. 108 Status NoteMapping(HloInstruction* subcomputation_hlo, 109 HloInstruction* new_hlo) { 110 auto result = subcomputation_hlo_to_new_hlo_.insert( 111 std::make_pair(subcomputation_hlo, new_hlo)); 112 TF_RET_CHECK(result.second) 113 << "A mapping for the subcomputation HLO is already present."; 114 return Status::OK(); 115 } 116 117 HloInstruction* call_; 118 HloComputation* outer_; 119 CallInliner::InlinedInstructionMap subcomputation_hlo_to_new_hlo_; 120 }; 121 122 } // namespace 123 124 /* static */ StatusOr<CallInliner::InlinedInstructionMap> CallInliner::Inline( 125 HloInstruction* call) { 126 TF_RET_CHECK(call->opcode() == HloOpcode::kCall) 127 << "Instruction was not a call op: " << call->opcode(); 128 const auto& callees = call->called_computations(); 129 TF_RET_CHECK(callees.size() == 1); 130 HloComputation* callee = callees[0]; 131 // We visit the callee, cloning its body into its caller. 132 SubcomputationInsertionVisitor visitor(call); 133 TF_RETURN_IF_ERROR(callee->Accept(&visitor)); 134 return visitor.ConsumeInstructionMap(); 135 } 136 137 StatusOr<bool> CallInliner::Run(HloModule* module) { 138 std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module); 139 // Because call graph nodes are visited in post-order (callees before callers) 140 // we'll always inline kCalls into their callers in the appropriate order. 141 bool did_mutate = false; 142 TF_RETURN_IF_ERROR( 143 call_graph->VisitNodes([&](const CallGraphNode& node) -> Status { 144 for (const CallSite& callsite : node.caller_callsites()) { 145 VLOG(1) << "Visiting callsite: " << callsite.ToString(); 146 if (callsite.instruction()->opcode() == HloOpcode::kCall) { 147 HloInstruction* call = callsite.instruction(); 148 TF_RETURN_IF_ERROR(Inline(call).status()); 149 did_mutate = true; 150 } 151 } 152 return Status::OK(); 153 })); 154 return did_mutate; 155 } 156 157 } // namespace xla 158