/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"

namespace xla {
namespace gpu {
namespace {

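// Rewrites each batchnorm HLO in the computation into a custom-call to the
// matching cudnn batchnorm routine, where cudnn supports the op.
//
// A rough sketch of the training rewrite (shapes elided):
//
//   (output, mean, variance) = batch-norm-training(operand, scale, offset)
//
// becomes
//
//   t = custom-call(operand, scale, offset, epsilon, feature_index)
//       with custom_call_target=kCudnnBatchNormForwardTrainingCallTarget,
//
// plus a few extra HLOs that turn t's rsqrt(variance + epsilon) element back
// into the plain variance the original HLO returned.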
class Visitor : public DfsHloVisitorWithDefault {
 public:
  explicit Visitor(HloComputation* computation) : computation_(computation) {}

  static bool Run(HloComputation* computation) {
    Visitor visitor(computation);
    TF_CHECK_OK(computation->Accept(&visitor));
    return visitor.changed_;
  }

  Status DefaultAction(HloInstruction* /*hlo_instruction*/) override {
    return Status::OK();
  }

  Status HandleBatchNormInference(HloInstruction* batch_norm) override;
  Status HandleBatchNormTraining(HloInstruction* batch_norm) override;
  Status HandleBatchNormGrad(HloInstruction* batch_norm) override;

 private:
  bool changed_ = false;
  HloComputation* computation_;
};

// cudnn defines CUDNN_BN_MIN_EPSILON = 1e-5 as the minimum acceptable epsilon
// for calls to its batchnorm ops.
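// Batchnorms with a smaller epsilon are left untouched here; presumably a
// later pass (e.g. BatchNormExpander) lowers whatever this rewriter skips.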
bool EpsilonInRange(HloInstruction* batch_norm) {
  return batch_norm->epsilon() >= 1e-5;
}

Status Visitor::HandleBatchNormInference(HloInstruction* batch_norm) {
  if (batch_norm->operand(0)->shape().element_type() != F32) {
    VLOG(1) << "Not rewriting op with non-F32 element type: "
            << batch_norm->ToString();
    return Status::OK();
  }

  // cudnn errors out on zero-sized inputs.
  if (ShapeUtil::ElementsIn(batch_norm->operand(0)->shape()) == 0) {
    return Status::OK();
  }

  if (!EpsilonInRange(batch_norm)) {
    return Status::OK();
  }

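  // Epsilon and feature_index are attributes of the batchnorm HLO, but the
  // custom-call form used here carries only operands, so we materialize them
  // as constants and append them to the operand list; the code that lowers
  // these custom-calls is expected to peel the two extra operands back off.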
  HloInstruction* epsilon = computation_->AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon())));
  HloInstruction* feature_index =
      computation_->AddInstruction(HloInstruction::CreateConstant(
          Literal::CreateR0(batch_norm->feature_index())));

  std::vector<HloInstruction*> operands(batch_norm->operands().begin(),
                                        batch_norm->operands().end());
  operands.push_back(epsilon);
  operands.push_back(feature_index);

  std::unique_ptr<HloInstruction> libcall = HloInstruction::CreateCustomCall(
      batch_norm->shape(), operands, kCudnnBatchNormForwardInferenceCallTarget);
  TF_RETURN_IF_ERROR(
      computation_->ReplaceWithNewInstruction(batch_norm, std::move(libcall)));
  changed_ = true;
  return Status::OK();
}

Status Visitor::HandleBatchNormTraining(HloInstruction* batch_norm) {
  if (batch_norm->operand(0)->shape().element_type() != F32) {
    VLOG(1) << "Not rewriting op with non-F32 element type: "
            << batch_norm->ToString();
    return Status::OK();
  }

  // cudnn errors out on zero-sized inputs.
  if (ShapeUtil::ElementsIn(batch_norm->operand(0)->shape()) == 0) {
    return Status::OK();
  }

  if (!EpsilonInRange(batch_norm)) {
    return Status::OK();
  }

  HloInstruction* epsilon = computation_->AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon())));
  HloInstruction* feature_index =
      computation_->AddInstruction(HloInstruction::CreateConstant(
          Literal::CreateR0(batch_norm->feature_index())));

  std::vector<HloInstruction*> operands(batch_norm->operands().begin(),
                                        batch_norm->operands().end());
  operands.push_back(epsilon);
  operands.push_back(feature_index);

  HloInstruction* libcall =
      computation_->AddInstruction(HloInstruction::CreateCustomCall(
          batch_norm->shape(), operands,
          kCudnnBatchNormForwardTrainingCallTarget));

  // The cudnn libcall returns a tuple
  //   {output, mean, rsqrt(variance + epsilon)},
  // but the batchnorm HLO returns {output, mean, variance}.  Fix it up.
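  // Since inverse_stddev = (variance + epsilon)^-0.5, raising it to the -2
  // power recovers variance + epsilon, and subtracting epsilon then yields
  // the plain variance:
  //
  //   variance = inverse_stddev^-2 - epsilon.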
  HloInstruction* inverse_stddev =
      computation_->AddInstruction(HloInstruction::CreateGetTupleElement(
          libcall->shape().tuple_shapes(2), libcall, 2));
  HloInstruction* variance_plus_epsilon =
      computation_->AddInstruction(HloInstruction::CreateBinary(
          inverse_stddev->shape(), HloOpcode::kPower, inverse_stddev,
          computation_->AddInstruction(
              HloInstruction::CreateConstant(Literal::CreateR0<float>(-2)))));
  HloInstruction* variance =
      computation_->AddInstruction(HloInstruction::CreateBinary(
          variance_plus_epsilon->shape(), HloOpcode::kSubtract,
          variance_plus_epsilon, epsilon));

  // Repackage the results.
  std::unique_ptr<HloInstruction> new_tuple = HloInstruction::CreateTuple({
      computation_->AddInstruction(HloInstruction::CreateGetTupleElement(
          libcall->shape().tuple_shapes(0), libcall, 0)),
      computation_->AddInstruction(HloInstruction::CreateGetTupleElement(
          libcall->shape().tuple_shapes(1), libcall, 1)),
      variance,
  });

  TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
      batch_norm, std::move(new_tuple)));
  changed_ = true;
  return Status::OK();
}

Status Visitor::HandleBatchNormGrad(HloInstruction* batch_norm) {
  if (batch_norm->operand(0)->shape().element_type() != F32) {
    VLOG(1) << "Not rewriting op with non-F32 element type: "
            << batch_norm->ToString();
    return Status::OK();
  }

  // cudnn errors out on zero-sized inputs.
  if (ShapeUtil::ElementsIn(batch_norm->operand(0)->shape()) == 0) {
    return Status::OK();
  }

  if (!EpsilonInRange(batch_norm)) {
    return Status::OK();
  }

  HloInstruction* epsilon = computation_->AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon())));
  HloInstruction* feature_index =
      computation_->AddInstruction(HloInstruction::CreateConstant(
          Literal::CreateR0(batch_norm->feature_index())));

  // The cudnn libcall expects its input to be rsqrt(variance + epsilon), but
  // the batchnorm HLO takes plain variance as input.  Fix it up.
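  // rsqrt(variance + epsilon) is built explicitly as
  //
  //   (variance + epsilon)^-0.5,
  //
  // i.e. a kAdd followed by a kPower with exponent -0.5.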
  HloInstruction* var_plus_epsilon =
      computation_->AddInstruction(HloInstruction::CreateBinary(
          batch_norm->operand(3)->shape(), HloOpcode::kAdd,
          batch_norm->mutable_operand(3), epsilon));
  HloInstruction* inverse_stddev =
      computation_->AddInstruction(HloInstruction::CreateBinary(
          var_plus_epsilon->shape(), HloOpcode::kPower, var_plus_epsilon,
          computation_->AddInstruction(
              HloInstruction::CreateConstant(Literal::CreateR0<float>(-.5)))));

  std::vector<HloInstruction*> operands(batch_norm->operands().begin(),
                                        batch_norm->operands().end());
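  // Operand 3 of batch-norm-grad is the variance; swap in the
  // rsqrt(variance + epsilon) value cudnn expects.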
  operands[3] = inverse_stddev;
  operands.push_back(epsilon);
  operands.push_back(feature_index);

  std::unique_ptr<HloInstruction> libcall = HloInstruction::CreateCustomCall(
      batch_norm->shape(), operands, kCudnnBatchNormBackwardCallTarget);

  TF_RETURN_IF_ERROR(
      computation_->ReplaceWithNewInstruction(batch_norm, std::move(libcall)));
  changed_ = true;
  return Status::OK();
}

}  // anonymous namespace

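// Runs the visitor over every non-fusion computation in the module and
// reports whether any batchnorm was rewritten.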
StatusOr<bool> CudnnBatchNormRewriter::Run(HloModule* module) {
  VLOG(2) << "CudnnBatchNormRewriter::Run(), before:";
  XLA_VLOG_LINES(2, module->ToString());

  bool changed = false;
  for (auto* comp : module->MakeNonfusionComputations()) {
    if (Visitor::Run(comp)) {
      changed = true;
    }
  }

  VLOG(2) << "CudnnBatchNormRewriter::Run(), after:";
  XLA_VLOG_LINES(2, module->ToString());
  return changed;
}

}  // namespace gpu
}  // namespace xla