/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/batchnorm_expander.h"

#include <algorithm>
#include <memory>
#include <numeric>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

// BatchNormExpanderVisitor traverses the HLO computation and rewrites BatchNorm
// operations into smaller operations.
class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
 public:
  // Default visitor action is to do nothing and return OK.
  Status DefaultAction(HloInstruction* /*hlo_instruction*/) override {
    return Status::OK();
  }

  Status HandleBatchNormTraining(HloInstruction* batch_norm) override;

  Status HandleBatchNormInference(HloInstruction* batch_norm) override;

  Status HandleBatchNormGrad(HloInstruction* batch_norm) override;

  // Runs the visitor on a computation.
  static bool Run(HloComputation* computation, bool rewrite_training_op,
                  bool rewrite_inference_op, bool rewrite_grad_op,
                  bool use_fusion);

  // Returns whether any batch norm ops were rewritten.
  bool changed() const { return changed_; }

  ~BatchNormExpanderVisitor() override = default;

 private:
  explicit BatchNormExpanderVisitor(HloComputation* computation,
                                    bool rewrite_training_op,
                                    bool rewrite_inference_op,
                                    bool rewrite_grad_op, bool use_fusion)
      : computation_(computation),
        rewrite_training_op_(rewrite_training_op),
        rewrite_inference_op_(rewrite_inference_op),
        rewrite_grad_op_(rewrite_grad_op),
        use_fusion_(use_fusion) {}

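  // Builds and registers a scalar computation `(lhs, rhs) -> lhs <opcode> rhs`
  // of the given element type; it is used as the reduction function for the
  // Reduce instructions created by the handlers below.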
  HloComputation* GetScalarBinaryComputation(PrimitiveType primitive_type,
                                             HloOpcode opcode) {
    HloComputation::Builder b("scalar_computation");
    auto scalar_lhs = b.AddInstruction(HloInstruction::CreateParameter(
        0, ShapeUtil::MakeShape(primitive_type, {}), "scalar_lhs"));
    auto scalar_rhs = b.AddInstruction(HloInstruction::CreateParameter(
        1, ShapeUtil::MakeShape(primitive_type, {}), "scalar_rhs"));
    auto scalar_op = b.AddInstruction(
        HloInstruction::CreateBinary(ShapeUtil::MakeShape(primitive_type, {}),
                                     opcode, scalar_lhs, scalar_rhs));
    return computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op));
  }

  // Current HloComputation instance the BatchNormExpander is
  // traversing.
  HloComputation* computation_;

  bool rewrite_training_op_;
  bool rewrite_inference_op_;
  bool rewrite_grad_op_;
  bool use_fusion_;

  // Whether any rewrite has occurred.
  bool changed_ = false;

  // Replaces the existing HLO instruction old_instruction with
  // new_instruction, and marks the optimizer status as changed.
  // Returns the Status representing the result of the replacement.
  Status ReplaceWithNewInstruction(
      HloInstruction* old_instruction,
      std::unique_ptr<HloInstruction> new_instruction) {
    TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
        old_instruction, std::move(new_instruction)));
    changed_ = true;
    return Status::OK();
  }

  // Replaces the existing HLO instruction old_instruction with
  // new_instruction, and marks the optimizer status as changed.
  // Returns the Status representing the result of the replacement.
  Status ReplaceInstruction(HloInstruction* old_instruction,
                            HloInstruction* new_instruction) {
    TF_RETURN_IF_ERROR(
        computation_->ReplaceInstruction(old_instruction, new_instruction));
    changed_ = true;
    return Status::OK();
  }
};

bool BatchNormExpanderVisitor::Run(HloComputation* computation,
                                   bool rewrite_training_op,
                                   bool rewrite_inference_op,
                                   bool rewrite_grad_op, bool use_fusion) {
  BatchNormExpanderVisitor visitor(
      computation,
      /*rewrite_training_op=*/rewrite_training_op,
      /*rewrite_inference_op=*/rewrite_inference_op,
      /*rewrite_grad_op=*/rewrite_grad_op,
      /*use_fusion=*/use_fusion);
  TF_CHECK_OK(computation->Accept(&visitor));
  return visitor.changed_;
}

Status BatchNormExpanderVisitor::HandleBatchNormTraining(
    HloInstruction* batch_norm) {
  if (!rewrite_training_op_) {
    return Status::OK();
  }

  std::vector<HloInstruction*> added_instructions;
  auto add = [&](std::unique_ptr<HloInstruction> inst) {
    HloInstruction* added_inst = computation_->AddInstruction(std::move(inst));
    added_instructions.push_back(added_inst);
    return added_inst;
  };
  int64 instruction_count_before = computation_->instruction_count();

  // Expand batch norm training into smaller HLO ops.
  HloInstruction* operand = batch_norm->mutable_operand(0);
  const Shape operand_shape = operand->shape();
  PrimitiveType ptype = operand_shape.element_type();
  int64 feature_index = batch_norm->feature_index();
  const int64 feature_count = operand_shape.dimensions(feature_index);
  const int64 size_in_elements = ShapeUtil::ElementsIn(operand_shape);
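  // elements_per_feature is N, the number of values that are reduced for each
  // feature; dividing the feature-wise sums below by it turns them into means.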
  auto elements_per_feature_literal =
      Literal::CreateR0<float>(size_in_elements / feature_count);
  TF_ASSIGN_OR_RETURN(elements_per_feature_literal,
                      elements_per_feature_literal->Convert(ptype));
  auto elements_per_feature = add(
      HloInstruction::CreateConstant(std::move(elements_per_feature_literal)));

  HloInstruction* scale = batch_norm->mutable_operand(1);
  HloInstruction* offset = batch_norm->mutable_operand(2);
  const Shape feature_shape = scale->shape();

  auto zero_literal = Literal::CreateR0(0.0f);
  TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype));
  auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal)));

  auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon());
  TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
  auto epsilon =
      add(HloInstruction::CreateConstant(std::move(epsilon_literal)));
  std::vector<int64> dimensions_without_feature;

  for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) {
    if (i != feature_index) {
      dimensions_without_feature.push_back(i);
    }
  }

  auto scale_broadcasted = add(
      HloInstruction::CreateBroadcast(operand_shape, scale, {feature_index}));

  auto offset_broadcasted = add(
      HloInstruction::CreateBroadcast(operand_shape, offset, {feature_index}));

  HloComputation* add_reduce_computation =
      GetScalarBinaryComputation(ptype, HloOpcode::kAdd);

  // X^2.
  auto operand_squared = add(HloInstruction::CreateBinary(
      operand_shape, HloOpcode::kMultiply, operand, operand));
  // Sum[X].
  auto sum = add(HloInstruction::CreateReduce(feature_shape, operand, zero,
                                              dimensions_without_feature,
                                              add_reduce_computation));

  // Sum[X^2].
  auto squared_sum = add(HloInstruction::CreateReduce(
      feature_shape, operand_squared, zero, dimensions_without_feature,
      add_reduce_computation));

  // Fuse two parallel reduces together to improve performance.
  if (use_fusion_ && !batch_norm->has_sharding()) {
    auto tuple = add(HloInstruction::CreateTuple({sum, squared_sum}));

    auto fused = computation_->CreateFusionInstruction(
        {tuple, sum, squared_sum, operand_squared},
        HloInstruction::FusionKind::kInput);

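    // Re-point sum and squared_sum at the two outputs of the fused
    // computation so the rest of the expansion consumes the fused reduces.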
    sum = add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 0));

    squared_sum =
        add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 1));
  }

  // E[X].
  auto mean = add(HloInstruction::CreateBinary(
      feature_shape, HloOpcode::kDivide, sum, elements_per_feature));

  auto mean_broadcasted = add(
      HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index}));

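  // The variance is computed as Var[X] = E[X^2] - E[X]^2 from the two
  // reduced sums above.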
  // E[X^2].
  auto square_mean = add(HloInstruction::CreateBinary(
      feature_shape, HloOpcode::kDivide, squared_sum, elements_per_feature));

  // E^2[X].
  auto mean_square = add(HloInstruction::CreateBinary(
      feature_shape, HloOpcode::kMultiply, mean, mean));

  // Var[X].
  auto var = add(HloInstruction::CreateBinary(
      feature_shape, HloOpcode::kSubtract, square_mean, mean_square));

  auto var_broadcasted =
      add(HloInstruction::CreateBroadcast(operand_shape, var, {feature_index}));

  // Var[X] + epsilon.
  auto var_add_epsilon = add(HloInstruction::CreateBinary(
      operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon));

  auto neg_half_literal = Literal::CreateR0(-0.5f);
  TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype));
  auto neg_half =
      add(HloInstruction::CreateConstant(std::move(neg_half_literal)));

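  // 1/sqrt(x) is expressed as pow(x, -0.5): kPower with the neg_half constant
  // created above.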
  // 1 / Sqrt[Var[X] + epsilon].
  auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary(
      operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half));

  // X - E[X].
  auto operand_minus_mean = add(HloInstruction::CreateBinary(
      operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted));

  // (X - E[X]) / Sqrt[Var[X] + epsilon].
  auto normalized = add(
      HloInstruction::CreateBinary(operand_shape, HloOpcode::kMultiply,
                                   operand_minus_mean, rsqrt_var_add_epsilon));

  // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale.
  auto scaled_normalized = add(HloInstruction::CreateBinary(
      operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted));

  // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset.
  auto shifted_normalized = add(HloInstruction::CreateBinary(
      operand_shape, HloOpcode::kAdd, scaled_normalized, offset_broadcasted));

  auto tuple = HloInstruction::CreateTuple({shifted_normalized, mean, var});

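  // If the original batch-norm op carried sharding, propagate it: instructions
  // with the operand's shape take the sharding of tuple element 0, everything
  // else (feature-shaped and scalar values) is replicated.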
  if (batch_norm->has_sharding()) {
    int64 instruction_count_after = computation_->instruction_count();
    CHECK_EQ(instruction_count_after,
             instruction_count_before + added_instructions.size());
    HloSharding operand_sharding =
        batch_norm->sharding().GetAsShapeTree(batch_norm->shape()).element({0});
    for (HloInstruction* inst : added_instructions) {
      if (ShapeUtil::Equal(inst->shape(), operand_shape)) {
        inst->set_sharding(operand_sharding);
      } else {
        inst->set_sharding(HloSharding::Replicate());
      }
    }
    tuple->set_sharding(batch_norm->sharding());
  }
  TF_CHECK_OK(ReplaceWithNewInstruction(batch_norm, std::move(tuple)));
  return Status::OK();
}

Status BatchNormExpanderVisitor::HandleBatchNormInference(
    HloInstruction* batch_norm) {
  if (!rewrite_inference_op_) {
    return Status::OK();
  }
  // Expand batch norm inference into smaller HLO ops.
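  // Unlike training, the mean and variance are not computed from the batch;
  // they are taken from operands 3 and 4 of the batch-norm op.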
  HloInstruction* operand = batch_norm->mutable_operand(0);
  const Shape operand_shape = operand->shape();
  int64 feature_index = batch_norm->feature_index();
  PrimitiveType ptype = operand_shape.element_type();

  HloInstruction* scale = batch_norm->mutable_operand(1);
  HloInstruction* offset = batch_norm->mutable_operand(2);
  HloInstruction* mean = batch_norm->mutable_operand(3);
  HloInstruction* var = batch_norm->mutable_operand(4);
  const Shape feature_shape = scale->shape();

  auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon());
  TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
  auto epsilon = computation_->AddInstruction(
      HloInstruction::CreateConstant(std::move(epsilon_literal)));

  std::vector<int64> dimensions_without_feature;

  for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) {
    if (i != feature_index) {
      dimensions_without_feature.push_back(i);
    }
  }

  std::vector<HloInstruction*> added_instructions;
  auto add = [&](std::unique_ptr<HloInstruction> inst) {
    HloInstruction* added_inst = computation_->AddInstruction(std::move(inst));
    added_instructions.push_back(added_inst);
    return added_inst;
  };
  int64 instruction_count_before = computation_->instruction_count();

  auto scale_broadcasted = add(
      HloInstruction::CreateBroadcast(operand_shape, scale, {feature_index}));

  auto offset_broadcasted = add(
      HloInstruction::CreateBroadcast(operand_shape, offset, {feature_index}));

  auto mean_broadcasted = add(
      HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index}));

  auto var_broadcasted =
      add(HloInstruction::CreateBroadcast(operand_shape, var, {feature_index}));

  // Var[X] + epsilon.
  auto var_add_epsilon = add(HloInstruction::CreateBinary(
      operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon));

  auto neg_half_literal = Literal::CreateR0(-0.5f);
  TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype));
  auto neg_half =
      add(HloInstruction::CreateConstant(std::move(neg_half_literal)));

  // 1 / Sqrt[Var[X] + epsilon].
  auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary(
      operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half));

  // X - E[X].
  auto operand_minus_mean = add(HloInstruction::CreateBinary(
      operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted));

  // (X - E[X]) / Sqrt[Var[X] + epsilon].
  auto normalized = add(
      HloInstruction::CreateBinary(operand_shape, HloOpcode::kMultiply,
                                   operand_minus_mean, rsqrt_var_add_epsilon));

  // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale.
  auto scaled_normalized = add(HloInstruction::CreateBinary(
      operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted));

  // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset.
  auto shifted_normalized = HloInstruction::CreateBinary(
      operand_shape, HloOpcode::kAdd, scaled_normalized, offset_broadcasted);

  int64 instruction_count_after = computation_->instruction_count();
  CHECK_EQ(instruction_count_after,
           instruction_count_before + added_instructions.size());
  if (batch_norm->has_sharding()) {
    for (HloInstruction* inst : added_instructions) {
      if (ShapeUtil::Equal(inst->shape(), operand_shape)) {
        inst->set_sharding(batch_norm->sharding());
      } else {
        inst->set_sharding(HloSharding::Replicate());
      }
    }
    shifted_normalized->set_sharding(batch_norm->sharding());
  }
  TF_CHECK_OK(
      ReplaceWithNewInstruction(batch_norm, std::move(shifted_normalized)));
  return Status::OK();
}

Status BatchNormExpanderVisitor::HandleBatchNormGrad(
    HloInstruction* batch_norm) {
  // Use the following formulas to calculate gradients:
  // scale_grad =
  //   sum(output_grad * (activation - mean(activation))) * rsqrt(var + epsilon)
  //
  // offset_grad =
  //   sum(output_grad)
  //
  // activation_grad =
  //   1/N * scale * rsqrt(var + epsilon) *
  //   (N * output_grad - sum(output_grad) - (activation - mean(activation)) *
  //   sum(output_grad * (activation - mean(activation))) / (variance +
  //   epsilon))
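  //
  // In the code below, I1 = N * output_grad, I2 = sum(output_grad),
  // I3 = sum(output_grad * (activation - mean)), I4 = (activation - mean) * I3,
  // I5 = I4 / (var + epsilon), and I6 = I1 - I2 - I5, so that
  // activation_grad = scale * rsqrt(var + epsilon) / N * I6.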
  if (!rewrite_grad_op_) {
    return Status::OK();
  }
  std::vector<HloInstruction*> added_instructions;
  auto add = [&](std::unique_ptr<HloInstruction> inst) {
    HloInstruction* added_inst = computation_->AddInstruction(std::move(inst));
    added_instructions.push_back(added_inst);
    return added_inst;
  };
  int64 instruction_count_before = computation_->instruction_count();

  HloInstruction* activation = batch_norm->mutable_operand(0);
  const Shape activation_shape = activation->shape();
  PrimitiveType ptype = activation_shape.element_type();
  HloInstruction* scale = batch_norm->mutable_operand(1);
  const Shape feature_shape = scale->shape();
  HloInstruction* mean = batch_norm->mutable_operand(2);
  HloInstruction* variance = batch_norm->mutable_operand(3);
  HloInstruction* grad_output = batch_norm->mutable_operand(4);

  int64 feature_index = batch_norm->feature_index();

  const int64 size_in_elements = ShapeUtil::ElementsIn(activation_shape);
  const int64 feature_count = activation_shape.dimensions(feature_index);
  auto elements_per_feature_literal =
      Literal::CreateR0<float>(size_in_elements / feature_count);
  TF_ASSIGN_OR_RETURN(elements_per_feature_literal,
                      elements_per_feature_literal->Convert(ptype));
  auto elements_per_feature = add(
      HloInstruction::CreateConstant(std::move(elements_per_feature_literal)));

  auto zero_literal = Literal::CreateR0(0.0f);
  TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype));
  auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal)));

  auto neg_half_literal = Literal::CreateR0(-0.5f);
  TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype));
  auto neg_half =
      add(HloInstruction::CreateConstant(std::move(neg_half_literal)));

  auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon());
  TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
  auto epsilon =
      add(HloInstruction::CreateConstant(std::move(epsilon_literal)));

  std::vector<int64> dimensions_without_feature;

  for (int64 i = 0; i < ShapeUtil::Rank(activation_shape); ++i) {
    if (i != feature_index) {
      dimensions_without_feature.push_back(i);
    }
  }

  auto scale_broadcasted = add(HloInstruction::CreateBroadcast(
      activation_shape, scale, {feature_index}));
  auto variance_broadcasted = add(HloInstruction::CreateBroadcast(
      activation_shape, variance, {feature_index}));

  // E[X].
  auto mean_broadcasted = add(
      HloInstruction::CreateBroadcast(activation_shape, mean, {feature_index}));

  // rsqrt[Var[X] + epsilon].
  auto rsqrt_var_add_epsilon_broadcasted = add(HloInstruction::CreateBinary(
      activation_shape, HloOpcode::kPower,
      add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kAdd,
                                       variance_broadcasted, epsilon)),
      neg_half));

  auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary(
      feature_shape, HloOpcode::kPower,
      add(HloInstruction::CreateBinary(feature_shape, HloOpcode::kAdd, variance,
                                       epsilon)),
      neg_half));

  // X - E[X].
  auto activation_minus_mean = add(HloInstruction::CreateBinary(
      activation_shape, HloOpcode::kSubtract, activation, mean_broadcasted));

  // Grad[Y] * (X - E[X]).
  auto grad_output_times_activation_minus_mean =
      add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply,
                                       grad_output, activation_minus_mean));

  HloComputation* add_reduce_computation =
      GetScalarBinaryComputation(ptype, HloOpcode::kAdd);

  // sum(Grad[Y] * (X - E[X])).
  auto sum_grad_output_times_activation_minus_mean =
      add(HloInstruction::CreateReduce(
          feature_shape, grad_output_times_activation_minus_mean, zero,
          dimensions_without_feature, add_reduce_computation));

  // Grad[beta] = Sum(Grad[Y]).
  auto grad_beta = add(HloInstruction::CreateReduce(
      feature_shape, grad_output, zero, dimensions_without_feature,
      add_reduce_computation));

  if (use_fusion_ && !batch_norm->has_sharding()) {
    auto tuple = add(HloInstruction::CreateTuple(
        {sum_grad_output_times_activation_minus_mean, grad_beta}));

    auto fused = computation_->CreateFusionInstruction(
        {tuple, sum_grad_output_times_activation_minus_mean, grad_beta},
        HloInstruction::FusionKind::kInput);

    sum_grad_output_times_activation_minus_mean =
        add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 0));

    grad_beta =
        add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 1));
  }

  // Grad[scale] = Sum(Grad[Y] * (X - E[X]) * rsqrt[Var[X] + epsilon]).
  auto grad_scale = add(HloInstruction::CreateBinary(
      feature_shape, HloOpcode::kMultiply,
      sum_grad_output_times_activation_minus_mean, rsqrt_var_add_epsilon));

  // I2 = Sum(Grad[Y])
  auto i2 = add(HloInstruction::CreateBroadcast(activation_shape, grad_beta,
                                                {feature_index}));

  // I3 = Sum(Grad[Y] * (X - E[X]))
  auto i3 = add(HloInstruction::CreateBroadcast(
      activation_shape, sum_grad_output_times_activation_minus_mean,
      {feature_index}));

  // I4 = (X - E[X]) * I3
  auto i4 = add(HloInstruction::CreateBinary(
      activation_shape, HloOpcode::kMultiply, i3, activation_minus_mean));

  // I5 = I4 / (Var[X] + epsilon)
  auto i5 = add(HloInstruction::CreateBinary(
      activation_shape, HloOpcode::kDivide, i4,
      add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kAdd,
                                       variance_broadcasted, epsilon))));

  // scale * rsqrt[Var[X] + epsilon] * 1/N
  auto scale_times_rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary(
      activation_shape, HloOpcode::kMultiply, scale_broadcasted,
      rsqrt_var_add_epsilon_broadcasted));

  scale_times_rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary(
      activation_shape, HloOpcode::kDivide, scale_times_rsqrt_var_add_epsilon,
      elements_per_feature));

  auto i1 =
      add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply,
                                       grad_output, elements_per_feature));

  // I6 = I1 - I2 - I5
  auto i6 = add(HloInstruction::CreateBinary(
      activation_shape, HloOpcode::kSubtract,
      add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kSubtract,
                                       i1, i2)),
      i5));

  // Grad[X] = scale * rsqrt[Var[X] + epsilon] * 1/N * I6.
  auto grad_activation =
      add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply,
                                       scale_times_rsqrt_var_add_epsilon, i6));
  auto tuple =
      HloInstruction::CreateTuple({grad_activation, grad_scale, grad_beta});
  if (batch_norm->has_sharding()) {
    int64 instruction_count_after = computation_->instruction_count();
    CHECK_EQ(instruction_count_after,
             instruction_count_before + added_instructions.size());
    HloSharding activation_sharding =
        batch_norm->sharding().GetAsShapeTree(batch_norm->shape()).element({0});
    for (HloInstruction* inst : added_instructions) {
      if (ShapeUtil::Equal(inst->shape(), activation_shape)) {
        inst->set_sharding(activation_sharding);
      } else {
        inst->set_sharding(HloSharding::Replicate());
      }
    }
    tuple->set_sharding(batch_norm->sharding());
  }

  TF_CHECK_OK(ReplaceWithNewInstruction(batch_norm, std::move(tuple)));

  return Status::OK();
}

StatusOr<bool> BatchNormExpander::Run(HloModule* module) {
  XLA_VLOG_LINES(2, "BatchNormExpander::Run(), before:\n" + module->ToString());
  bool changed = false;
  for (auto* comp : module->MakeNonfusionComputations()) {
    if (BatchNormExpanderVisitor::Run(comp, rewrite_training_op_,
                                      rewrite_inference_op_, rewrite_grad_op_,
                                      use_fusion_)) {
      changed = true;
    }
  }
  XLA_VLOG_LINES(2, "BatchNormExpander::Run(), after:\n" + module->ToString());
  return changed;
}

}  // namespace xla
    615