1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/inliner.h" 17 18 #include <memory> 19 #include <string> 20 21 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" 22 #include "tensorflow/compiler/xla/service/hlo_computation.h" 23 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 24 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 25 #include "tensorflow/compiler/xla/service/hlo_query.h" 26 #include "tensorflow/compiler/xla/status_macros.h" 27 #include "tensorflow/compiler/xla/types.h" 28 #include "tensorflow/core/lib/core/errors.h" 29 #include "tensorflow/core/lib/core/status.h" 30 #include "tensorflow/core/lib/gtl/array_slice.h" 31 #include "tensorflow/core/platform/logging.h" 32 33 namespace xla { 34 35 // InlinerVisitor traverses the HLO computation and inlines maps. 36 class InlinerVisitor : public DfsHloVisitorWithDefault { 37 public: 38 explicit InlinerVisitor(HloComputation* computation) 39 : computation_(computation) {} 40 41 // Default visitor action is to do nothing and return OK. 42 Status DefaultAction(HloInstruction* /*hlo_instruction*/) override { 43 return Status::OK(); 44 } 45 46 Status HandleMap(HloInstruction* map) override; 47 48 // Runs the visitor on a computation. 49 StatusOr<bool> Run(HloComputation* computation); 50 51 private: 52 // Current HloComputation instance the InlinerVisitor is traversing. 53 HloComputation* computation_; 54 55 // Whether algebraic simplification has occurred. 56 bool changed_ = false; 57 }; 58 59 StatusOr<bool> InlinerVisitor::Run(HloComputation* computation) { 60 changed_ = false; 61 computation_ = computation; 62 TF_RETURN_IF_ERROR(computation->root_instruction()->Accept(this)); 63 return changed_; 64 } 65 66 Status InlinerVisitor::HandleMap(HloInstruction* map) { 67 HloComputation* function = map->to_apply(); 68 HloInstruction& root = *function->root_instruction(); 69 // TODO(b/29249531): Add DCE pass to remove unused HloComputations. 70 // Only inlining functions that are simply a single operation until a better 71 // profitability model for inlining is defined. 72 if (hlo_query::AllOperandsAreParameters(root)) { 73 if (root.opcode() == HloOpcode::kFusion || 74 root.opcode() == HloOpcode::kParameter || 75 root.opcode() == HloOpcode::kTrace) { 76 // Cloning not supported for these instructions. 77 return Status::OK(); 78 } 79 VLOG(10) << "inlining map({X ... Y}, op) => : op(X ... Y) with function " 80 << root.ToShortString(); 81 // If the input is a constant then the shape of the constant could be 82 // different than the map shape. Hence, a broadcast is needed, else the 83 // cloned operand with new shape and operands work. 84 if (root.opcode() != HloOpcode::kConstant) { 85 std::vector<HloInstruction*> params; 86 for (int64 o = 0; o < root.operands().size(); o++) { 87 params.push_back(map->operands()[root.operand(o)->parameter_number()]); 88 } 89 HloInstruction* placed_instruction = computation_->AddInstruction( 90 root.CloneWithNewOperands(map->shape(), params)); 91 TF_RETURN_IF_ERROR( 92 computation_->ReplaceInstruction(map, placed_instruction)); 93 } else { 94 // The constant is in an embedded computation and needs to be recreated 95 // as part of the computation that the broadcast is inserted into. 96 HloInstruction* constant = computation_->AddInstruction(root.Clone()); 97 HloInstruction* placed_instruction = computation_->AddInstruction( 98 HloInstruction::CreateBroadcast(map->shape(), constant, {})); 99 TF_RETURN_IF_ERROR( 100 computation_->ReplaceInstruction(map, placed_instruction)); 101 } 102 changed_ = true; 103 return Status::OK(); 104 } 105 106 return Status::OK(); 107 } 108 109 StatusOr<bool> Inliner::Run(HloModule* module) { 110 InlinerVisitor visitor(/*computation=*/nullptr); 111 bool changed = false; 112 for (HloComputation* computation : module->computations()) { 113 TF_ASSIGN_OR_RETURN(bool computation_changed, visitor.Run(computation)); 114 changed |= computation_changed; 115 } 116 return changed; 117 } 118 119 } // namespace xla 120