/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/bfloat16_conversion_folding.h"

#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

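// Illustrative example of the rewrite this visitor performs (a sketch; the op
// and shapes are hypothetical, and each fold only applies when the backend's
// BFloat16Support allows BF16 at the corresponding operand/output position):
//
//   x.f32 = f32[2] convert(bf16[2] x)
//   y.f32 = f32[2] convert(bf16[2] y)
//   add   = f32[2] add(x.f32, y.f32)
//   out   = bf16[2] convert(add)
//
// becomes
//
//   add = bf16[2] add(bf16[2] x, bf16[2] y)
//
// with users of `out` rewired to use `add` directly; the now-dead converts are
// left for later cleanup passes.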
class BFloat16ConversionFoldingVisitor : public DfsHloVisitorWithDefault {
 public:
  explicit BFloat16ConversionFoldingVisitor(
      HloComputation* computation, const BFloat16Support* bfloat16_support)
      : computation_(computation), bfloat16_support_(bfloat16_support) {}

  Status DefaultAction(HloInstruction* hlo) override;

  static bool Run(HloComputation* computation,
                  const BFloat16Support* bfloat16_support) {
    BFloat16ConversionFoldingVisitor visitor(computation, bfloat16_support);
    TF_CHECK_OK(computation->Accept(&visitor));
    return visitor.changed_;
  }

 private:
  // Checks if the HLO has a BF16 -> F32 conversion as input, or a F32 -> BF16
  // conversion as output, and folds them to the HLO itself if feasible.
  Status TryFoldBF16Conversions(HloInstruction* hlo);

  // Folds the F32 -> BF16 conversions from the HLO's output.
  //
  // Precondition: all of the HLO's users are F32 -> BF16 conversions.
  Status FoldOutputConversions(HloInstruction* hlo);

  // Folds the BF16 -> F32 conversion operand to the HLO.
  //
  // Precondition: the operand is a BF16 -> F32 conversion.
  Status FoldOperandConversion(HloInstruction* hlo, int64 operand_index);

  HloComputation* computation_;
  const BFloat16Support* bfloat16_support_;
  bool changed_ = false;
};

Status BFloat16ConversionFoldingVisitor::FoldOutputConversions(
    HloInstruction* hlo) {
  std::vector<HloInstruction*> materialized_users = hlo->users();
  hlo->mutable_shape()->set_element_type(BF16);
  for (auto user : materialized_users) {
    CHECK_EQ(user->opcode(), HloOpcode::kConvert);
    TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(hlo));
    changed_ = true;
  }
  return Status::OK();
}

Status BFloat16ConversionFoldingVisitor::FoldOperandConversion(
    HloInstruction* hlo, int64 operand_index) {
  // The operand is a convert from BF16 to F32.
  auto operand = hlo->mutable_operand(operand_index);
  CHECK_EQ(operand->opcode(), HloOpcode::kConvert);
  TF_RETURN_IF_ERROR(
      hlo->ReplaceOperandWith(operand_index, operand->mutable_operand(0)));
  changed_ = true;
  return Status::OK();
}

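// Decision sketch for TryFoldBF16Conversions (mirrors the code below):
//  1. Collect the operand indices that are BF16 -> F32 converts where the
//     backend supports a BF16 operand at that index.
//  2. Decide whether the output conversion can be folded: the HLO must have
//     users, produce F32, support a BF16 output, not be the computation root
//     (changing the root would change the computation's result type), and
//     every user must be an F32 -> BF16 convert.
//  3. If some operands or the output would stay F32 and the backend does not
//     support mixed precisions for this HLO, fold nothing.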
Status BFloat16ConversionFoldingVisitor::TryFoldBF16Conversions(
    HloInstruction* hlo) {
  std::vector<int64> bf16_to_f32_operands;
  bool has_other_f32_operands = false;
  for (int64 i = 0; i < hlo->operands().size(); ++i) {
    auto operand = hlo->operand(i);
    if (operand->shape().element_type() == F32) {
      if (operand->opcode() == HloOpcode::kConvert &&
          operand->operand(0)->shape().element_type() == BF16 &&
          bfloat16_support_->SupportsBF16Operand(*hlo, i)) {
        // Operand is a convert from BF16 to F32 and we support BF16 input
        // directly in the current HLO at the operand index.
        bf16_to_f32_operands.push_back(i);
      } else {
        has_other_f32_operands = true;
      }
      continue;
    }
  }

  bool fold_output_conversion = hlo->user_count() > 0 &&
                                hlo->shape().element_type() == F32 &&
                                bfloat16_support_->SupportsBF16Output(*hlo) &&
                                hlo != computation_->root_instruction();
  if (fold_output_conversion) {
    for (auto user : hlo->users()) {
      if (user->opcode() == HloOpcode::kConvert &&
          user->shape().element_type() == BF16) {
        continue;
      }
      // We should not change the output type if any user is not a conversion
      // from F32 to BF16.
      fold_output_conversion = false;
      break;
    }
  }

  if (!bfloat16_support_->SupportsMixedPrecisions(*hlo)) {
    if (has_other_f32_operands ||
        (!fold_output_conversion && hlo->shape().element_type() == F32)) {
      // Some of the operands/output will remain F32, but we cannot use mixed
      // precisions, so we cannot do anything here.
      return Status::OK();
    }
  }

  if (fold_output_conversion) {
    TF_RETURN_IF_ERROR(FoldOutputConversions(hlo));
  }

  for (int64 i : bf16_to_f32_operands) {
    TF_RETURN_IF_ERROR(FoldOperandConversion(hlo, i));
  }
  return Status::OK();
}

Status BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) {
  // Do not fold BF16 conversions for instructions related to tuples, entry and
  // exit of a computation, fusion, convert, and control flow.
  if (hlo->opcode() == HloOpcode::kTuple ||            //
      hlo->opcode() == HloOpcode::kGetTupleElement ||  //
      hlo->opcode() == HloOpcode::kInfeed ||           //
      hlo->opcode() == HloOpcode::kOutfeed ||          //
      hlo->opcode() == HloOpcode::kConstant ||         //
      hlo->opcode() == HloOpcode::kParameter ||        //
      hlo->opcode() == HloOpcode::kFusion ||           //
      hlo->opcode() == HloOpcode::kConvert ||          //
      hlo->opcode() == HloOpcode::kCall ||             //
      hlo->opcode() == HloOpcode::kCustomCall ||       //
      hlo->opcode() == HloOpcode::kWhile ||            //
      hlo->opcode() == HloOpcode::kConditional) {
    return Status::OK();
  }
  if (hlo == computation_->root_instruction() &&
      !bfloat16_support_->SupportsMixedPrecisions(*hlo)) {
    // If hlo is the root instruction, we cannot change its output, so folding
    // can only happen when it supports mixed precision so that we can change
    // its operands.
    return Status::OK();
  }
  return TryFoldBF16Conversions(hlo);
}

StatusOr<bool> BFloat16ConversionFolding::Run(HloModule* module) {
  XLA_VLOG_LINES(
      2, "BFloat16ConversionFolding::Run(), before:\n" + module->ToString());
  bool changed = false;
  for (auto* comp : module->MakeNonfusionComputations()) {
    if (BFloat16ConversionFoldingVisitor::Run(comp, bfloat16_support_)) {
      changed = true;
    }
  }
  XLA_VLOG_LINES(
      2, "BFloat16ConversionFolding::Run(), after:\n" + module->ToString());
  return changed;
}
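
// Usage sketch (the pipeline name and `bf16_support` object are hypothetical,
// not part of this file): the pass is typically added to an HloPassPipeline
// with a backend-specific BFloat16Support implementation, e.g.
//
//   HloPassPipeline pipeline("bf16-folding");
//   pipeline.AddPass<BFloat16ConversionFolding>(&bf16_support);
//   TF_RETURN_IF_ERROR(pipeline.Run(module).status());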

}  // namespace xla