1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/implicit_broadcast_remover.h" 17 18 #include <algorithm> 19 #include <memory> 20 #include <numeric> 21 #include <string> 22 #include <utility> 23 #include <vector> 24 25 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" 26 #include "tensorflow/compiler/xla/service/hlo_computation.h" 27 #include "tensorflow/compiler/xla/service/hlo_dce.h" 28 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 29 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 30 #include "tensorflow/compiler/xla/shape_util.h" 31 #include "tensorflow/compiler/xla/status_macros.h" 32 #include "tensorflow/compiler/xla/types.h" 33 #include "tensorflow/compiler/xla/util.h" 34 #include "tensorflow/core/lib/core/errors.h" 35 #include "tensorflow/core/lib/core/status.h" 36 #include "tensorflow/core/platform/logging.h" 37 #include "tensorflow/core/platform/types.h" 38 39 namespace xla { 40 41 namespace { 42 43 // Visitor for removing implicit broadcasts. 44 class ImplicitBroadcastVisitor : public DfsHloVisitorWithDefault { 45 public: 46 Status DefaultAction(HloInstruction* hlo_instruction) override { 47 return Status::OK(); 48 } 49 50 Status HandleElementwiseBinary(HloInstruction* hlo) override { 51 return ReplaceImplicitBroadcastOperands(hlo); 52 } 53 54 Status HandleClamp(HloInstruction* hlo) override { 55 // Clamp is the only element-wise ternary operation. 56 return ReplaceImplicitBroadcastOperands(hlo); 57 } 58 59 // Returns whether any modification has been made to any visited instruction. 60 bool changed() const { return changed_; } 61 62 private: 63 // Iterates through the operands of 'hlo' and replace any operands which are 64 // implicitly broadcast with the equivalent sequence of broadcast and reshape 65 // instructions. An operand is considered to be implicitly broadcast if the 66 // operand shape does have the same dimensions as the shape of 'hlo'. 67 Status ReplaceImplicitBroadcastOperands(HloInstruction* hlo) { 68 auto fadd = [hlo](std::unique_ptr<HloInstruction> x) { 69 return hlo->parent()->AddInstruction(std::move(x)); 70 }; 71 std::vector<HloInstruction*> operands; 72 bool operands_changed = false; 73 for (int i = 0; i < hlo->operand_count(); ++i) { 74 HloInstruction* operand = hlo->mutable_operand(i); 75 if (!ShapeUtil::SameDimensions(hlo->shape(), operand->shape())) { 76 HloInstruction* new_operand = hlo->parent()->AddInstruction( 77 HloInstruction::CreateBroadcastSequence(hlo->shape(), operand, 78 fadd)); 79 operands.push_back(new_operand); 80 operands_changed = true; 81 } else { 82 operands.push_back(operand); 83 } 84 } 85 if (operands_changed) { 86 // Create a new HLO instruction because the HloInstruction::Replace* 87 // methods check that the shape does not change with the replacement. 88 HloInstruction* new_hlo = hlo->parent()->AddInstruction( 89 hlo->CloneWithNewOperands(hlo->shape(), operands)); 90 TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_hlo)); 91 changed_ = true; 92 } 93 return Status::OK(); 94 } 95 96 bool changed_ = false; 97 }; 98 99 } // namespace 100 101 StatusOr<bool> ImplicitBroadcastRemover::Run(HloModule* module) { 102 VLOG(1) << "Removing implicit broadcast from module " << module->name(); 103 XLA_VLOG_LINES(2, 104 "Before removing implicit broadcasts:\n" + module->ToString()); 105 106 ImplicitBroadcastVisitor visitor; 107 for (HloComputation* computation : module->computations()) { 108 TF_RETURN_IF_ERROR(computation->Accept(&visitor)); 109 } 110 111 if (visitor.changed()) { 112 // HLO instructions with implicitly broadcast operands are cloned and left 113 // for dead. Remove them. 114 HloDCE dce; 115 TF_RETURN_IF_ERROR(dce.Run(module).status()); 116 } 117 118 XLA_VLOG_LINES(2, 119 "After removing implicit broadcasts:\n" + module->ToString()); 120 121 return visitor.changed(); 122 } 123 124 } // namespace xla 125