/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/bfloat16_conversion_folding.h"

#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

class BFloat16ConversionFoldingVisitor : public DfsHloVisitorWithDefault {
 public:
  explicit BFloat16ConversionFoldingVisitor(
      HloComputation* computation, const BFloat16Support* bfloat16_support)
      : computation_(computation), bfloat16_support_(bfloat16_support) {}

  Status DefaultAction(HloInstruction* hlo) override;

  static bool Run(HloComputation* computation,
                  const BFloat16Support* bfloat16_support) {
    BFloat16ConversionFoldingVisitor visitor(computation, bfloat16_support);
    TF_CHECK_OK(computation->Accept(&visitor));
    return visitor.changed_;
  }

 private:
  // Checks if the HLO has a BF16 -> F32 conversion as input, or an
  // F32 -> BF16 conversion as output, and folds them into the HLO itself if
  // feasible.
  Status TryFoldBF16Conversions(HloInstruction* hlo);

  // Folds the F32 -> BF16 conversions from the HLO's output.
  //
  // Precondition: all of the HLO's users are F32 -> BF16 conversions.
  Status FoldOutputConversions(HloInstruction* hlo);

  // Folds the BF16 -> F32 conversion operand to the HLO.
  //
  // Precondition: the operand at operand_index is a BF16 -> F32 conversion.
  Status FoldOperandConversion(HloInstruction* hlo, int64 operand_index);

  HloComputation* computation_;
  const BFloat16Support* bfloat16_support_;
  bool changed_ = false;
};

Status BFloat16ConversionFoldingVisitor::FoldOutputConversions(
    HloInstruction* hlo) {
  std::vector<HloInstruction*> materialized_users = hlo->users();
  hlo->mutable_shape()->set_element_type(BF16);
  for (auto user : materialized_users) {
    CHECK_EQ(user->opcode(), HloOpcode::kConvert);
    TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(hlo));
    changed_ = true;
  }
  return Status::OK();
}
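
// Example of the rewrite performed by FoldOutputConversions above. The
// instruction names and shapes are illustrative only, not taken from any
// real module:
//
//   %mul  = f32[4]{0} multiply(...)
//   %cvt0 = bf16[4]{0} convert(%mul)
//   %cvt1 = bf16[4]{0} convert(%mul)
//
// becomes
//
//   %mul = bf16[4]{0} multiply(...)
//
// with all users of %cvt0 and %cvt1 rewired to %mul; the now-dead converts
// are left behind for a later DCE pass to remove.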

Status BFloat16ConversionFoldingVisitor::FoldOperandConversion(
    HloInstruction* hlo, int64 operand_index) {
  // The operand is a convert from BF16 to F32.
  auto operand = hlo->mutable_operand(operand_index);
  CHECK_EQ(operand->opcode(), HloOpcode::kConvert);
  TF_RETURN_IF_ERROR(
      hlo->ReplaceOperandWith(operand_index, operand->mutable_operand(0)));
  changed_ = true;
  return Status::OK();
}

Status BFloat16ConversionFoldingVisitor::TryFoldBF16Conversions(
    HloInstruction* hlo) {
  std::vector<int64> bf16_to_f32_operands;
  bool has_other_f32_operands = false;
  for (int64 i = 0; i < hlo->operands().size(); ++i) {
    auto operand = hlo->operand(i);
    if (operand->shape().element_type() == F32) {
      if (operand->opcode() == HloOpcode::kConvert &&
          operand->operand(0)->shape().element_type() == BF16 &&
          bfloat16_support_->SupportsBF16Operand(*hlo, i)) {
        // Operand is a convert from BF16 to F32 and we support BF16 input
        // directly in the current HLO at the operand index.
        bf16_to_f32_operands.push_back(i);
      } else {
        has_other_f32_operands = true;
      }
      continue;
    }
  }

  bool fold_output_conversion = hlo->user_count() > 0 &&
                                hlo->shape().element_type() == F32 &&
                                bfloat16_support_->SupportsBF16Output(*hlo) &&
                                hlo != computation_->root_instruction();
  if (fold_output_conversion) {
    for (auto user : hlo->users()) {
      if (user->opcode() == HloOpcode::kConvert &&
          user->shape().element_type() == BF16) {
        continue;
      }
      // We should not change the output type if any user is not a conversion
      // from F32 to BF16.
      fold_output_conversion = false;
      break;
    }
  }

  if (!bfloat16_support_->SupportsMixedPrecisions(*hlo)) {
    if (has_other_f32_operands ||
        (!fold_output_conversion && hlo->shape().element_type() == F32)) {
      // Some of the operands/output will remain F32, but we cannot use mixed
      // precisions, so we cannot do anything here.
      return Status::OK();
    }
  }

  if (fold_output_conversion) {
    TF_RETURN_IF_ERROR(FoldOutputConversions(hlo));
  }

  for (int64 i : bf16_to_f32_operands) {
    TF_RETURN_IF_ERROR(FoldOperandConversion(hlo, i));
  }
  return Status::OK();
}
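
// Worked example for TryFoldBF16Conversions above. Instruction names and
// shapes are illustrative only:
//
//   %cvt = f32[8]{0} convert(bf16[8]{0} %x)
//   %add = f32[8]{0} add(%cvt, f32[8]{0} %y)
//
// If the backend supports a BF16 operand 0 for add but does not support mixed
// precisions, nothing is folded: %y and the output would remain F32 next to a
// BF16 operand. If mixed precisions are supported, the operand conversion is
// folded and %add consumes %x directly.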

Status BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) {
  // Do not fold BF16 conversions for instructions related to tuples, entry and
  // exit of a computation, fusion, convert, and control flow.
  if (hlo->opcode() == HloOpcode::kTuple ||            //
      hlo->opcode() == HloOpcode::kGetTupleElement ||  //
      hlo->opcode() == HloOpcode::kInfeed ||           //
      hlo->opcode() == HloOpcode::kOutfeed ||          //
      hlo->opcode() == HloOpcode::kConstant ||         //
      hlo->opcode() == HloOpcode::kParameter ||        //
      hlo->opcode() == HloOpcode::kFusion ||           //
      hlo->opcode() == HloOpcode::kConvert ||          //
      hlo->opcode() == HloOpcode::kCall ||             //
      hlo->opcode() == HloOpcode::kCustomCall ||       //
      hlo->opcode() == HloOpcode::kWhile ||            //
      hlo->opcode() == HloOpcode::kConditional) {
    return Status::OK();
  }
  if (hlo == computation_->root_instruction() &&
      !bfloat16_support_->SupportsMixedPrecisions(*hlo)) {
    // If hlo is the root instruction, we cannot change its output type, so
    // folding can only happen when it supports mixed precision, allowing us
    // to change its operands.
    return Status::OK();
  }
  return TryFoldBF16Conversions(hlo);
}

StatusOr<bool> BFloat16ConversionFolding::Run(HloModule* module) {
  XLA_VLOG_LINES(
      2, "BFloat16ConversionFolding::Run(), before:\n" + module->ToString());
  bool changed = false;
  for (auto* comp : module->MakeNonfusionComputations()) {
    if (BFloat16ConversionFoldingVisitor::Run(comp, bfloat16_support_)) {
      changed = true;
    }
  }
  XLA_VLOG_LINES(
      2, "BFloat16ConversionFolding::Run(), after:\n" + module->ToString());
  return changed;
}

}  // namespace xla
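
// Typical usage of this pass (a sketch; the surrounding pipeline setup is
// assumed and is not part of this file):
//
//   HloPassPipeline pipeline("bf16-conversion-folding");
//   pipeline.AddPass<BFloat16ConversionFolding>(bfloat16_support);
//   TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));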