/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"

#include <algorithm>
#include <memory>
#include <numeric>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace {

// Returns whether operand is a literal with the given value.
bool IsLiteralWithValue(const HloInstruction* operand, int8 value) {
  return operand->opcode() == HloOpcode::kConstant &&
         operand->literal().IsAll(value);
}

// Returns whether `op` is the given value everywhere: either a literal with
// that value, or a (possibly nested) broadcast of such a literal.
bool IsAll(const HloInstruction* op, int8 value) {
  if (IsLiteralWithValue(op, value)) {
    return true;
  }
  if (op->opcode() == HloOpcode::kBroadcast && IsAll(op->operand(0), value)) {
    return true;
  }
  return false;
}

// Returns whether the given transpose produces a result which is bit-wise
// identical to its operand and thus may be replaced with a bitcast.
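// Illustrative example (a sketch, assuming row-major {1,0} layouts on both
// sides): transposing f32[1,4] to f32[4,1] with permutation {1,0} visits the
// same four elements in the same linear order, so the result is bit-wise
// identical to the operand.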
bool TransposeIsBitcast(const HloInstruction* transpose) {
  CHECK_EQ(HloOpcode::kTranspose, transpose->opcode());
  const HloInstruction* operand = transpose->operand(0);
  return ShapeUtil::TransposeIsBitcast(operand->shape(), transpose->shape(),
                                       transpose->dimensions());
}

// Returns true if the given reshape produces a result which is bit-wise
// identical to its operand and thus may be replaced with a bitcast.
//
// This function is conservative -- even if it returns false, the reshape may
// still be a bitcast. For example, a reshape from [28x28] to [784].
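// Illustrative example (a sketch): with row-major layouts, reshaping
// f32[28,28] to f32[784] leaves the linearized element order untouched, so
// the reshape can become a bitcast if valid_bitcast_callback approves the
// pair of layouts.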
bool ReshapeIsBitcast(
    const HloInstruction* reshape,
    const AlgebraicSimplifier::ValidBitcastCallback& valid_bitcast_callback) {
  CHECK_EQ(HloOpcode::kReshape, reshape->opcode());

  const HloInstruction* operand = reshape->operand(0);
  // Can't insert bitcasts if the compiler used a memory layout which isn't
  // compatible.
  return ShapeUtil::ReshapeIsBitcast(operand->shape(), reshape->shape()) &&
         valid_bitcast_callback(operand->shape(), reshape->shape());
}

// Adds a scalar computation to the module so a dot can be rewritten as a
// multiply followed by a reduction.
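// For opcode == kAdd and primitive_type == F32, the embedded computation
// built below looks roughly like this HLO (a sketch):
//
//   scalar_computation {
//     scalar_lhs = f32[] parameter(0)
//     scalar_rhs = f32[] parameter(1)
//     ROOT scalar_op = f32[] add(scalar_lhs, scalar_rhs)
//   }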
HloComputation* CreateScalarBinaryComputation(HloModule* module,
                                              PrimitiveType primitive_type,
                                              HloOpcode opcode) {
  HloComputation::Builder b("scalar_computation");
  // Use the requested primitive type for the parameters as well as the
  // result so the computation is consistently typed.
  auto scalar_lhs = b.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(primitive_type, {}), "scalar_lhs"));
  auto scalar_rhs = b.AddInstruction(HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(primitive_type, {}), "scalar_rhs"));
  auto scalar_op = b.AddInstruction(
      HloInstruction::CreateBinary(ShapeUtil::MakeShape(primitive_type, {}),
                                   opcode, scalar_lhs, scalar_rhs));
  HloComputation* scalar_computation =
      module->AddEmbeddedComputation(b.Build(scalar_op));
  return scalar_computation;
}
}  // namespace

// AlgebraicSimplifierVisitor traverses the HLO computation and reduces
// certain algebraic expressions to simplified forms. Note: This only supports
// simplifications that look directly at the operands of an instruction; the
// more general case would require a worklist-based approach.
class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
 public:
  // Default visitor action is to do nothing and return OK.
  Status DefaultAction(HloInstruction* /*hlo_instruction*/) override {
    return Status::OK();
  }

  Status HandleAdd(HloInstruction* add) override;

  Status HandleBitcast(HloInstruction* bitcast) override;

  Status HandleBroadcast(HloInstruction* broadcast) override;

  Status HandleConcatenate(HloInstruction* concatenate) override;

  Status HandleConstant(HloInstruction* constant) override;

  Status HandleCopy(HloInstruction* copy) override;

  Status HandleConvert(HloInstruction* convert) override;

  Status HandleComplex(HloInstruction* complex) override;

  Status HandleReal(HloInstruction* real) override;

  Status HandleImag(HloInstruction* imag) override;

  Status HandleConvolution(HloInstruction* convolution) override;

  Status HandleDivide(HloInstruction* divide) override;

  Status HandleDot(HloInstruction* dot) override;

  Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;

  Status HandleLog(HloInstruction* log) override;

  Status HandleMultiply(HloInstruction* multiply) override;

  Status HandlePad(HloInstruction* pad) override;

  Status HandlePower(HloInstruction* power) override;

  Status HandleReshape(HloInstruction* reshape) override;

  Status HandleReduce(HloInstruction* reduce) override;

  Status HandleReduceWindow(HloInstruction* reduce_window) override;

  Status HandleReverse(HloInstruction* reverse) override;
  Status HandleSlice(HloInstruction* slice) override;
  Status HandleDynamicSlice(HloInstruction* dynamic_slice) override;
  Status HandleDynamicUpdateSlice(
      HloInstruction* dynamic_update_slice) override;

  Status HandleTranspose(HloInstruction* transpose) override;

  Status HandleSubtract(HloInstruction* sub) override;

  Status HandleMaximum(HloInstruction* maximum) override;
  Status HandleMinimum(HloInstruction* minimum) override;

  // Returns whether algebraic simplification has occurred.
  bool changed() const { return changed_; }

  // Runs the visitor on a computation.
  static bool Run(
      HloComputation* computation, bool is_layout_sensitive,
      AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback,
      bool enable_dot_strength_reduction, bool enable_conv_simplification);

 private:
  explicit AlgebraicSimplifierVisitor(
      HloComputation* computation, bool is_layout_sensitive,
      AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback,
      bool enable_dot_strength_reduction, bool enable_conv_simplification)
      : computation_(computation),
        is_layout_sensitive_(is_layout_sensitive),
        valid_bitcast_callback_(std::move(valid_bitcast_callback)),
        enable_dot_strength_reduction_(enable_dot_strength_reduction),
        enable_conv_simplification_(enable_conv_simplification) {}

  // Transforms a dot where at least one input is a vector or has a degenerate
  // dimension into a multiply and reduce. This should enable more fusion than
  // leaving the node as a Dot operation.
  StatusOr<bool> HandleDotStrengthReduction(HloInstruction* dot);

  // Reshapes an instruction to rank 1 if it is not already rank 1.
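  // For example, an f32[2,3] instruction becomes a reshape to f32[6].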
  HloInstruction* Flatten(HloInstruction* hlo) {
    if (ShapeUtil::Rank(hlo->shape()) == 1) {
      return hlo;
    }
    return computation_->AddInstruction(HloInstruction::CreateReshape(
        ShapeUtil::MakeShape(hlo->shape().element_type(),
                             {ShapeUtil::ElementsIn(hlo->shape())}),
        hlo));
  }

  // Helper that adds a reduce-sum of `hlo` over a single dimension.
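  // For example, AddReduce(hlo, /*dim=*/1) on an f32[4,5] instruction emits
  // a reduce over dimension 1 with a 0.0f init value, producing f32[4].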
  HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) {
    HloInstruction* zero = computation_->AddInstruction(
        HloInstruction::CreateConstant(Literal::CreateR0(0.0f)));
    HloComputation* AddReduce_computation = CreateScalarBinaryComputation(
        computation_->parent(), F32, HloOpcode::kAdd);
    Shape shape = ShapeUtil::DeleteDimension(dim, hlo->shape());
    return computation_->AddInstruction(HloInstruction::CreateReduce(
        shape, hlo, zero, {dim}, AddReduce_computation));
  }

  // Convenience method for replacing an instruction with a bitcast.
  void ReplaceWithBitcast(HloInstruction* instruction);

  // Replaces old_instruction with new_instruction if the two have the same
  // shape. Updates uses and the root instruction. Returns whether a
  // replacement was made.
  bool ReplaceInstructionIfSameShape(HloInstruction* old_instruction,
                                     HloInstruction* new_instruction);

  // Returns whether the shapes of the outputs of the given instructions are
  // the same for the purposes of simplification. If is_layout_sensitive_ is
  // true, this tests shape equality including layout (ShapeUtil::Equal); if
  // it is false, this tests shape compatibility (ShapeUtil::Compatible).
  bool SameShape(const HloInstruction* lhs, const HloInstruction* rhs) const;

  // Tries to transform `root` into a clamp instruction, where `min` is a
  // minimum instruction with operand `min_operand`, and `max` is a maximum
  // instruction with operand `max_operand`. Returns whether the
  // transformation was possible.
  // Precondition: root is either a minimum or a maximum.
  bool TransformToClampIfSameShape(HloInstruction* root, HloInstruction* min,
                                   HloInstruction* min_operand,
                                   HloInstruction* operand, HloInstruction* max,
                                   HloInstruction* max_operand);

  // A Reshape or Broadcast that feeds an element-wise operation with a unique
  // non-scalar operand can sink to after the operation.
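  // For example (a sketch of the intent): unary_op(broadcast(x)) can become
  // broadcast(unary_op(x)), doing the element-wise work on the smaller
  // pre-broadcast shape.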
  StatusOr<bool> TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand(
      HloInstruction* reshape_or_broadcast);

  // Replaces the existing HLO instruction old_instruction with
  // new_instruction, and marks the optimizer status as changed.
  // Returns the Status representing the result of the replace operation.
  Status ReplaceWithNewInstruction(
      HloInstruction* old_instruction,
      std::unique_ptr<HloInstruction> new_instruction) {
    VLOG(3) << "Replacing instruction:";
    VLOG(3) << "  old: " << old_instruction->ToString();
    VLOG(3) << "  new: " << new_instruction->ToString();
    TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
        old_instruction, std::move(new_instruction)));
    changed_ = true;
    return Status::OK();
  }

  // Replaces the existing HLO instruction old_instruction with
  // new_instruction, and marks the optimizer status as changed.
  // Returns the Status representing the result of the replace operation.
  Status ReplaceInstruction(HloInstruction* old_instruction,
                            HloInstruction* new_instruction) {
    VLOG(3) << "Replacing instruction:";
    VLOG(3) << "  old: " << old_instruction->ToString();
    VLOG(3) << "  new: " << new_instruction->ToString();
    TF_RETURN_IF_ERROR(
        computation_->ReplaceInstruction(old_instruction, new_instruction));
    changed_ = true;
    return Status::OK();
  }

  StatusOr<HloInstruction*> OptimizeDotOfConcat(HloInstruction* dot);
  StatusOr<HloInstruction*> OptimizeDotOfConcatHelper(
      const Shape& dot_shape, HloInstruction* lhs, int64 lhs_contracting_dim,
      HloInstruction* rhs, int64 rhs_contracting_dim, bool swapped);

  // Current HloComputation instance the AlgebraicSimplifierVisitor is
  // traversing.
  HloComputation* computation_;

  // Whether algebraic simplification has occurred.
  bool changed_ = false;

  // Whether layout is considered during transformation.
  bool is_layout_sensitive_;

  // Callback used to determine if a bitcast is possible.
  AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback_;

  // Whether to apply dot strength reduction; disabled on platforms where it
  // causes a slowdown.
  bool enable_dot_strength_reduction_;

  // Whether to simplify convolutions; disabled on platforms where it causes
  // a slowdown.
  bool enable_conv_simplification_;
};

bool AlgebraicSimplifierVisitor::Run(
    HloComputation* computation, bool is_layout_sensitive,
    AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback,
    bool enable_dot_strength_reduction, bool enable_conv_simplification) {
  AlgebraicSimplifierVisitor visitor(
      computation, is_layout_sensitive, std::move(valid_bitcast_callback),
      enable_dot_strength_reduction, enable_conv_simplification);
  TF_CHECK_OK(computation->Accept(&visitor));
  return visitor.changed_;
}

bool AlgebraicSimplifierVisitor::SameShape(const HloInstruction* lhs,
                                           const HloInstruction* rhs) const {
  if (is_layout_sensitive_) {
    return ShapeUtil::Equal(lhs->shape(), rhs->shape());
  } else {
    return ShapeUtil::Compatible(lhs->shape(), rhs->shape());
  }
}

void AlgebraicSimplifierVisitor::ReplaceWithBitcast(
    HloInstruction* instruction) {
  CHECK_EQ(1, instruction->operand_count());
  CHECK_EQ(ShapeUtil::ElementsIn(instruction->shape()),
           ShapeUtil::ElementsIn(instruction->operand(0)->shape()));
  CHECK_EQ(ShapeUtil::ByteSizeOf(instruction->shape()),
           ShapeUtil::ByteSizeOf(instruction->operand(0)->shape()));

  auto bitcast = computation_->AddInstruction(
      HloInstruction::CreateUnary(instruction->shape(), HloOpcode::kBitcast,
                                  instruction->mutable_operand(0)));
  TF_CHECK_OK(ReplaceInstruction(instruction, bitcast));
}

bool AlgebraicSimplifierVisitor::ReplaceInstructionIfSameShape(
    HloInstruction* old_instruction, HloInstruction* new_instruction) {
  if (!SameShape(old_instruction, new_instruction)) {
    return false;
  }
  TF_CHECK_OK(ReplaceInstruction(old_instruction, new_instruction));
  return true;
}

Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) {
  auto lhs = add->mutable_operand(0);
  auto rhs = add->mutable_operand(1);
  // A + 0 => A
  VLOG(10) << "trying transform [A + 0 => A]: " << add->ToString();
  if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(add, lhs)) {
    return Status::OK();
  }
  // 0 + A => A
  VLOG(10) << "trying transform [0 + A => A]: " << add->ToString();
  if (IsAll(lhs, 0) && ReplaceInstructionIfSameShape(add, rhs)) {
    return Status::OK();
  }

  // Canonicalization: Put constants on the right.  This makes the
  // reassociation rules below simpler.
  VLOG(10) << "trying transform [Const + A => A + Const]";
  if (lhs->IsConstant() && !rhs->IsConstant()) {
    return ReplaceWithNewInstruction(
        add,
        HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, rhs, lhs));
  }

  // Reassociate to allow constant folding.
  //
  // Note: This is not general.  For example, we won't reassociate
  //
  //   (A + C1) + (B + C2) =>  A + B + (C1 + C2).
  //
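  // For example (illustrative): (A + 7) + 3 becomes A + (7 + 3), and a later
  // constant-folding pass can collapse (7 + 3) to 10.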
  VLOG(10) << "trying transform [(A + C1) + C2 => A + (C1 + C2)]";
  if (rhs->IsConstant() && lhs->opcode() == HloOpcode::kAdd &&
      !lhs->operand(0)->IsConstant() && lhs->operand(1)->IsConstant()) {
    auto* c1 = lhs->mutable_operand(1);
    auto* c2 = rhs;
    TF_ASSIGN_OR_RETURN(
        Shape sum_of_constants_shape,
        ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, c1, c2));

    auto* sum_of_constants =
        computation_->AddInstruction(HloInstruction::CreateBinary(
            sum_of_constants_shape, HloOpcode::kAdd, c1, c2));
    return ReplaceWithNewInstruction(
        add, HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd,
                                          lhs->mutable_operand(0),
                                          sum_of_constants));
  }

  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleBitcast(HloInstruction* bitcast) {
  // If a bitcast feeds a bitcast, make it a single bitcast.
  if (bitcast->operand(0)->opcode() == HloOpcode::kBitcast) {
    return ReplaceWithNewInstruction(
        bitcast, HloInstruction::CreateUnary(
                     bitcast->shape(), HloOpcode::kBitcast,
                     bitcast->mutable_operand(0)->mutable_operand(0)));
  }
  // All bitcasts can be eliminated (assuming layout constraints are
  // satisfied).
  ReplaceInstructionIfSameShape(bitcast, bitcast->mutable_operand(0));
  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) {
  // If a copy feeds a copy, make it a single copy.
  if (copy->operand(0)->opcode() == HloOpcode::kCopy) {
    return ReplaceWithNewInstruction(
        copy, HloInstruction::CreateUnary(
                  copy->shape(), HloOpcode::kCopy,
                  copy->mutable_operand(0)->mutable_operand(0)));
  }
  // All copies can be eliminated (assuming layout constraints are satisfied).
  ReplaceInstructionIfSameShape(copy, copy->mutable_operand(0));
  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleConcatenate(
    HloInstruction* concatenate) {
  tensorflow::gtl::ArraySlice<HloInstruction*> operands(
      concatenate->operands());
  if (operands.size() == 1) {
    // Unary concatenates are useless.
    ReplaceInstructionIfSameShape(concatenate, operands[0]);
    return Status::OK();
  }
  // Filter out and remove empty operands.
  std::vector<HloInstruction*> nonempty_operands;
  for (HloInstruction* operand : operands) {
    if (!ShapeUtil::HasZeroElements(operand->shape())) {
      nonempty_operands.push_back(operand);
    }
  }
  if (nonempty_operands.size() < operands.size()) {
    HloInstruction* replacement;
    if (nonempty_operands.empty()) {
      replacement = operands[0];
    } else if (nonempty_operands.size() == 1) {
      replacement = nonempty_operands[0];
    } else {
      replacement =
          computation_->AddInstruction(concatenate->CloneWithNewOperands(
              concatenate->shape(), nonempty_operands));
    }
    VLOG(10) << "trying to replace " << concatenate->ToString() << " with "
             << replacement->ToString();
    ReplaceInstructionIfSameShape(concatenate, replacement);
  } else if (operands.size() == 2) {
    // A binary concat with a broadcasted scalar as an operand can be converted
    // into a pad which is simpler to fold into other operations.
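    // For example (illustrative): concatenating broadcast(s) of shape [3,N]
    // with x of shape [M,N] along dimension 0 is equivalent to pad(x, s)
    // with edge_padding_low = 3 on dimension 0.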
    bool is_effective_low_pad =
        operands[0]->opcode() == HloOpcode::kBroadcast &&
        ShapeUtil::IsScalar(operands[0]->operand(0)->shape());
    bool is_effective_high_pad =
        operands[1]->opcode() == HloOpcode::kBroadcast &&
        ShapeUtil::IsScalar(operands[1]->operand(0)->shape());
    if (!is_effective_low_pad && !is_effective_high_pad) {
      return Status::OK();
    }
    PaddingConfig padding_config;
    for (int64 dim = 0; dim < ShapeUtil::Rank(operands[0]->shape()); ++dim) {
      auto padding_config_dim = padding_config.add_dimensions();
      padding_config_dim->set_edge_padding_high(0);
      padding_config_dim->set_edge_padding_low(0);
      padding_config_dim->set_interior_padding(0);
      if (dim == concatenate->concatenate_dimension()) {
        if (is_effective_low_pad) {
          padding_config_dim->set_edge_padding_low(
              operands[0]->shape().dimensions(dim));
        } else {
          padding_config_dim->set_edge_padding_high(
              operands[1]->shape().dimensions(dim));
        }
      }
    }
    int64 operand_to_pad = is_effective_low_pad ? 1 : 0;
    int64 pad_value_operand = is_effective_low_pad ? 0 : 1;
    HloInstruction* pad =
        computation_->AddInstruction(HloInstruction::CreatePad(
            concatenate->shape(), operands[operand_to_pad],
            operands[pad_value_operand]->mutable_operand(0), padding_config));
    return ReplaceInstruction(concatenate, pad);
  }
  return Status::OK();
}

static HloInstruction* BuildTupleConstant(HloComputation* computation,
                                          const Literal& literal) {
  if (ShapeUtil::IsTuple(literal.shape())) {
    std::vector<HloInstruction*> elems;
    elems.reserve(ShapeUtil::TupleElementCount(literal.shape()));
    for (int i = 0; i < ShapeUtil::TupleElementCount(literal.shape()); ++i) {
      elems.push_back(
          BuildTupleConstant(computation, LiteralView::Create(literal, {i})));
    }
    return computation->AddInstruction(HloInstruction::CreateTuple(elems));
  } else {
    return computation->AddInstruction(
        HloInstruction::CreateConstant(literal.CloneToUnique()));
  }
}

Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) {
  // Tuple constants aren't directly supported by any backend. Expand them into
  // explicit Tuple instructions.
  if (ShapeUtil::IsTuple(constant->shape())) {
    return ReplaceInstruction(
        constant, BuildTupleConstant(computation_, constant->literal()));
  }
  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) {
  auto lhs = sub->mutable_operand(0);
  auto rhs = sub->mutable_operand(1);
  // A - 0 => A
  VLOG(10) << "trying transform [A - 0 => A]: " << sub->ToString();
  if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(sub, lhs)) {
    return Status::OK();
  }

  // Canonicalize subtraction of a constant to addition.
  VLOG(10) << "trying transform [A - Const => A + (-Const)]";
  if (rhs->IsConstant() && !lhs->IsConstant()) {
    HloInstruction* negative_const = computation_->AddInstruction(
        HloInstruction::CreateUnary(rhs->shape(), HloOpcode::kNegate, rhs));
    return ReplaceWithNewInstruction(
        sub, HloInstruction::CreateBinary(sub->shape(), HloOpcode::kAdd, lhs,
                                          negative_const));
  }

  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
  auto lhs = divide->mutable_operand(0);
  auto rhs = divide->mutable_operand(1);
  // A/1 => A
  VLOG(10) << "trying transform [A/1 => A]: " << divide->ToString();
  if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(divide, lhs)) {
    return Status::OK();
  }

  // exp(A)/exp(B) => exp(A-B)
  if (lhs->opcode() == HloOpcode::kExp && rhs->opcode() == HloOpcode::kExp) {
    VLOG(10) << "transform [exp(A)/exp(B) => exp(A-B)]: " << divide->ToString();
    HloInstruction* subtract =
        computation_->AddInstruction(HloInstruction::CreateBinary(
            divide->shape(), HloOpcode::kSubtract, lhs->mutable_operand(0),
            rhs->mutable_operand(0)));
    return ReplaceWithNewInstruction(
        divide, HloInstruction::CreateUnary(divide->shape(), HloOpcode::kExp,
                                            subtract));
  }

  // A/exp(B) => A*exp(-B)
  if (rhs->opcode() == HloOpcode::kExp) {
    VLOG(10) << "transform [A/exp(B) => A*exp(-B)]: " << divide->ToString();
    HloInstruction* negate =
        computation_->AddInstruction(HloInstruction::CreateUnary(
            divide->shape(), HloOpcode::kNegate, rhs->mutable_operand(0)));
    HloInstruction* new_exp = computation_->AddInstruction(
        HloInstruction::CreateUnary(divide->shape(), HloOpcode::kExp, negate));
    return ReplaceWithNewInstruction(
        divide, HloInstruction::CreateBinary(
                    divide->shape(), HloOpcode::kMultiply, lhs, new_exp));
  }

  // A/pow(B,C) => A*pow(B,-C)
  if (rhs->opcode() == HloOpcode::kPower) {
    VLOG(10) << "transform [A/pow(B,C) => A*pow(B,-C)]: " << divide->ToString();
    // The output shape of the created negate operator should be the same as
    // its input.
    const Shape& negate_shape = rhs->operand(1)->shape();
    HloInstruction* negate =
        computation_->AddInstruction(HloInstruction::CreateUnary(
            negate_shape, HloOpcode::kNegate, rhs->mutable_operand(1)));
    // And the power operator should retain the output shape of the old one.
    const Shape& new_power_shape = rhs->shape();
    HloInstruction* new_power = computation_->AddInstruction(
        HloInstruction::CreateBinary(new_power_shape, HloOpcode::kPower,
                                     rhs->mutable_operand(0), negate));
    return ReplaceWithNewInstruction(
        divide, HloInstruction::CreateBinary(
                    divide->shape(), HloOpcode::kMultiply, lhs, new_power));
  }

  // Simplifying integral division would produce unexpected results.
  if (ShapeUtil::ElementIsIntegral(divide->shape())) {
    return Status::OK();
  }

  // A / Const => A * (1 / Const)
  //
  // (Backends can do this transformation, but generally only if the constant
  // is a scalar.)
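  // For example (illustrative): A / 4.0 becomes A * (1.0 / 4.0), where the
  // reciprocal is a candidate for later constant folding.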
  if (lhs->opcode() != HloOpcode::kConstant &&
      rhs->opcode() == HloOpcode::kConstant) {
    HloInstruction* one =
        computation_->AddInstruction(HloInstruction::CreateConstant(
            Literal::One(lhs->shape().element_type()).CloneToUnique()));
    HloInstruction* inverse =
        computation_->AddInstruction(HloInstruction::CreateBinary(
            rhs->shape(), HloOpcode::kDivide, one, rhs));
    return ReplaceWithNewInstruction(
        divide, HloInstruction::CreateBinary(
                    divide->shape(), HloOpcode::kMultiply, lhs, inverse));
  }

  // (A / B) / (C / D)  =>  (A / B)*(D / C) => (A * D) / (B * C)
  if (lhs->opcode() == HloOpcode::kDivide &&
      rhs->opcode() == HloOpcode::kDivide) {
    TF_ASSIGN_OR_RETURN(
        const Shape a_times_d_shape,
        ShapeInference::InferBinaryOpShape(HloOpcode::kMultiply,
                                           lhs->operand(0), rhs->operand(1)));
    auto a_times_d = computation_->AddInstruction(HloInstruction::CreateBinary(
        a_times_d_shape, HloOpcode::kMultiply, lhs->mutable_operand(0),
        rhs->mutable_operand(1)));
    TF_ASSIGN_OR_RETURN(
        const Shape b_times_c_shape,
        ShapeInference::InferBinaryOpShape(HloOpcode::kMultiply,
                                           lhs->operand(1), rhs->operand(0)));
    auto b_times_c = computation_->AddInstruction(HloInstruction::CreateBinary(
        b_times_c_shape, HloOpcode::kMultiply, lhs->mutable_operand(1),
        rhs->mutable_operand(0)));
    return ReplaceWithNewInstruction(
        divide, HloInstruction::CreateBinary(
                    divide->shape(), HloOpcode::kDivide, a_times_d, b_times_c));
  }

  // (A / B) / C => A / (B * C)
  if (lhs->opcode() == HloOpcode::kDivide) {
    TF_ASSIGN_OR_RETURN(const Shape b_times_c_shape,
                        ShapeInference::InferBinaryOpShape(
                            HloOpcode::kMultiply, lhs->operand(1), rhs));
    auto b_times_c = computation_->AddInstruction(HloInstruction::CreateBinary(
        b_times_c_shape, HloOpcode::kMultiply, lhs->mutable_operand(1), rhs));
    return ReplaceWithNewInstruction(
        divide,
        HloInstruction::CreateBinary(divide->shape(), HloOpcode::kDivide,
                                     lhs->mutable_operand(0), b_times_c));
  }

  // A / (B / C) => (A*C) / B
  if (rhs->opcode() == HloOpcode::kDivide) {
    TF_ASSIGN_OR_RETURN(const Shape a_times_c_shape,
                        ShapeInference::InferBinaryOpShape(
                            HloOpcode::kMultiply, lhs, rhs->operand(1)));
    auto a_times_c = computation_->AddInstruction(HloInstruction::CreateBinary(
        a_times_c_shape, HloOpcode::kMultiply, lhs, rhs->mutable_operand(1)));
    return ReplaceWithNewInstruction(
        divide,
        HloInstruction::CreateBinary(divide->shape(), HloOpcode::kDivide,
                                     a_times_c, rhs->mutable_operand(0)));
  }

  return Status::OK();
}

StatusOr<bool> AlgebraicSimplifierVisitor::HandleDotStrengthReduction(
    HloInstruction* dot) {
  HloInstruction* lhs = dot->mutable_operand(0);
  HloInstruction* rhs = dot->mutable_operand(1);
  int64 lhs_collapsing_dim =
      dot->dot_dimension_numbers().lhs_contracting_dimensions(0);
  if (lhs->IsRank2Transpose()) {
    lhs = lhs->mutable_operand(0);
    lhs_collapsing_dim = 1 - lhs_collapsing_dim;
  }
  const int64 lhs_kept_dim = 1 - lhs_collapsing_dim;

  int64 rhs_collapsing_dim =
      dot->dot_dimension_numbers().rhs_contracting_dimensions(0);
  if (rhs->IsRank2Transpose()) {
    rhs = rhs->mutable_operand(0);
    rhs_collapsing_dim = 1 - rhs_collapsing_dim;
  }
  const int64 rhs_kept_dim = 1 - rhs_collapsing_dim;

  auto reshape_if_necessary = [&](HloInstruction* hlo) {
    if (ShapeUtil::SameDimensions(hlo->shape(), dot->shape())) {
      return hlo;
    }
    return computation_->AddInstruction(
        HloInstruction::CreateReshape(dot->shape(), hlo));
  };

  auto broadcast_to_dim = [&](HloInstruction* hlo, const Shape& shape,
                              int64 dim) {
    return computation_->AddInstruction(
        HloInstruction::CreateBroadcast(shape, hlo, {dim}));
  };

  auto multiply = [&](HloInstruction* local_lhs, HloInstruction* local_rhs) {
    return computation_->AddInstruction(HloInstruction::CreateBinary(
        local_lhs->shape(), HloOpcode::kMultiply, local_lhs, local_rhs));
  };

  // Strength reduce dot(a[K] , b[K]) =
  //  reshape(result.shape,
  //          reduce_sum(multiply(a, b), {0}))
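  // For example (illustrative): dot(f32[4], f32[4]) becomes
  // reduce_sum(multiply(a, b), {0}), which already matches the dot's scalar
  // result shape, so no trailing reshape is needed in that case.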
  if (ShapeUtil::Rank(rhs->shape()) == 1 &&
      ShapeUtil::Rank(lhs->shape()) == 1) {
    TF_RETURN_IF_ERROR(
        ReplaceInstruction(dot, reshape_if_necessary(AddReduce(
                                    multiply(Flatten(lhs), Flatten(rhs)), 0))));
    return true;
  }

  if (ShapeUtil::IsEffectiveScalar(rhs->shape()) &&
      ShapeUtil::IsEffectiveScalar(lhs->shape())) {
    TF_RETURN_IF_ERROR(ReplaceInstruction(
        dot, reshape_if_necessary(multiply(Flatten(lhs), Flatten(rhs)))));
    return true;
  }

  // Simplify outer product into a multiply with broadcasting.
  //
  // A dot(a[M, 1], b[1, N]) = multiply(a [M,1], b [1, N])
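  // For example (illustrative): dot(a[3,1], b[1,4]) becomes
  // multiply(broadcast(reshape(a, [3]), {0}) -> [3,4],
  //          broadcast(reshape(b, [4]), {1}) -> [3,4]).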
  if (ShapeUtil::Rank(rhs->shape()) == 2 &&
      rhs->shape().dimensions(rhs_collapsing_dim) == 1) {
    TF_RETURN_IF_ERROR(ReplaceInstruction(
        dot, multiply(broadcast_to_dim(Flatten(lhs), dot->shape(), 0),
                      broadcast_to_dim(Flatten(rhs), dot->shape(), 1))));
    return true;
  }

  // Strength reduce dot(a[1, K], b) =
  //    reshape(result.shape,
  //      reduce_sum(
  //        multiply(broadcast(reshape(a, [K]), {0}), b),
  //        {0})
  //      )
  //    )
  if (ShapeUtil::Rank(lhs->shape()) == 1 ||
      (ShapeUtil::Rank(lhs->shape()) == 2 &&
       lhs->shape().dimensions(lhs_kept_dim) == 1)) {
    if (ShapeUtil::Rank(rhs->shape()) == 1) {
      TF_RETURN_IF_ERROR(ReplaceInstruction(
          dot,
          reshape_if_necessary(AddReduce(multiply(Flatten(lhs), rhs), 0))));
      return true;
    }
    TF_RETURN_IF_ERROR(ReplaceInstruction(
        dot, reshape_if_necessary(
                 AddReduce(multiply(broadcast_to_dim(Flatten(lhs), rhs->shape(),
                                                     rhs_collapsing_dim),
                                    rhs),
                           rhs_collapsing_dim))));
    return true;
  }

  // Strength reduce dot(a, b[K, 1]) =
  //  reshape(result.shape,
  //    reduce_sum(multiply(a, broadcast(reshape(b, [K]), {1})), {0})
  //  )
  if (ShapeUtil::Rank(rhs->shape()) == 1 ||
      (ShapeUtil::Rank(rhs->shape()) == 2 &&
       rhs->shape().dimensions(rhs_kept_dim) == 1)) {
    TF_RETURN_IF_ERROR(ReplaceInstruction(
        dot, reshape_if_necessary(AddReduce(
                 multiply(lhs, broadcast_to_dim(Flatten(rhs), lhs->shape(),
                                                lhs_collapsing_dim)),
                 lhs_collapsing_dim))));
    return true;
  }
  return false;
}

StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcat(
    HloInstruction* dot) {
  const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
  if (dnums.lhs_contracting_dimensions_size() != 1 ||
      dnums.lhs_batch_dimensions_size() != 0) {
    return nullptr;
  }

  const int64 lhs_contracting_dim = dnums.lhs_contracting_dimensions(0);
  const int64 rhs_contracting_dim = dnums.rhs_contracting_dimensions(0);
  HloInstruction* lhs = dot->mutable_operand(0);
  HloInstruction* rhs = dot->mutable_operand(1);

  TF_ASSIGN_OR_RETURN(
      HloInstruction * optimized_lhs_concat,
      OptimizeDotOfConcatHelper(dot->shape(), lhs, lhs_contracting_dim, rhs,
                                rhs_contracting_dim, /*swapped=*/false));
  if (optimized_lhs_concat) {
    return optimized_lhs_concat;
  }

  return OptimizeDotOfConcatHelper(dot->shape(), rhs, rhs_contracting_dim, lhs,
                                   lhs_contracting_dim, /*swapped=*/true);
}

StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper(
    const Shape& dot_shape, HloInstruction* lhs, int64 lhs_contracting_dim,
    HloInstruction* rhs, int64 rhs_contracting_dim, bool swapped) {
  bool can_optimize = lhs->opcode() == HloOpcode::kConcatenate &&
                      lhs->concatenate_dimension() == lhs_contracting_dim &&
                      rhs->opcode() == HloOpcode::kConstant;
  if (!can_optimize) {
    return nullptr;
  }

  // We're replacing this:
  //
  //   +-----+-----+-----+      +-------------------+
  //   |     |     |     |      |                   |
  //   |     |     |     |      |        R_0        |
  //   |     |     |     |      |                   |
  //   |     |     |     |      +-------------------+
  //   |     |     |     |      |                   |
  //   | L_0 | L_1 | L_2 |   *  |        R_1        |
  //   |     |     |     |      |                   |
  //   |     |     |     |      +-------------------+
  //   |     |     |     |      |                   |
  //   |     |     |     |      |        R_2        |
  //   |     |     |     |      |                   |
  //   +-----+-----+-----+      +-------------------+
  //
  // with this:
  //
  // [Sum over i]
  //
  //   +-----+     +-------------------+
  //   |     |     |                   |
  //   |     |  *  |        R_i        |
  //   |     |     |                   |
  //   |     |     +-------------------+
  //   |     |
  //   | L_i |
  //   |     |
  //   |     |
  //   |     |
  //   |     |
  //   |     |
  //   +-----+
  //
  // where the LHS is a concatenate operation (so we can "split" the LHS tensor
  // for free) and the RHS is a constant tensor (and thus can be split at
  // compile time).  In the future, we may also want to do this when both the
  // LHS and the RHS are concatenate operations that line up along the dimension
  // being contracted over.
  //
  // We should be able to generalize this transform to work on a non-constant
  // RHS when/if we have in-place slices or support input-fusing slices into
  // Dots.

  // Dimension numbers for the new dot instructions we'll create (L_i * R_i in
  // the diagram above).
  DotDimensionNumbers new_dot_dnums;
  new_dot_dnums.add_lhs_contracting_dimensions(swapped ? rhs_contracting_dim
                                                       : lhs_contracting_dim);
  new_dot_dnums.add_rhs_contracting_dimensions(swapped ? lhs_contracting_dim
                                                       : rhs_contracting_dim);

  // Here we use the MKN notation, where the contracted dimension has K
  // elements and the two non-contracted dimensions have M and N elements.
  HloInstruction* add_result = nullptr;
  int64 rhs_contracting_dim_offset = 0;
  int64 n = rhs->shape().dimensions(1 - rhs_contracting_dim);
  for (HloInstruction* concat_op : lhs->operands()) {
    int64 sub_k = concat_op->shape().dimensions(lhs_contracting_dim);
    Shape rhs_slice_shape(rhs->shape());
    rhs_slice_shape.set_dimensions(rhs_contracting_dim, sub_k);

    std::array<int64, 2> start_indices;
    start_indices[rhs_contracting_dim] = rhs_contracting_dim_offset;
    start_indices[1 - rhs_contracting_dim] = 0;

    std::array<int64, 2> limit_indices;
    limit_indices[rhs_contracting_dim] = rhs_contracting_dim_offset + sub_k;
    limit_indices[1 - rhs_contracting_dim] = n;

    HloInstruction* rhs_slice =
        computation_->AddInstruction(HloInstruction::CreateSlice(
            rhs_slice_shape, rhs, /*start_indices=*/start_indices,
            /*limit_indices=*/limit_indices, /*strides=*/{1, 1}));

    // TODO(b/69062148): We can get rid of `swapped` once all backends support
    // "non-canonical" contraction dimensions (that contracts dimension 1 of the
    // LHS with dimension 0 of the RHS).  But for now we keep the same
    // contraction dimensions as the incoming dot operation to ensure the new
    // dot operations can be lowered.
    HloInstruction *new_dot_lhs, *new_dot_rhs;
    if (swapped) {
      new_dot_lhs = rhs_slice;
      new_dot_rhs = concat_op;
    } else {
      new_dot_lhs = concat_op;
      new_dot_rhs = rhs_slice;
    }

    auto* new_dot = computation_->AddInstruction(HloInstruction::CreateDot(
        dot_shape, new_dot_lhs, new_dot_rhs, new_dot_dnums));

    if (add_result) {
      add_result = computation_->AddInstruction(HloInstruction::CreateBinary(
          dot_shape, HloOpcode::kAdd, add_result, new_dot));
    } else {
      add_result = new_dot;
    }

    rhs_contracting_dim_offset += sub_k;
  }

  return add_result;
}

Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
  auto lhs = dot->mutable_operand(0);
  auto rhs = dot->mutable_operand(1);

  // Only optimize F32 dot operations where the dot, rhs and lhs are rank 2 or
  // below.
  if (dot->shape().element_type() != F32 || ShapeUtil::Rank(lhs->shape()) > 2 ||
      ShapeUtil::Rank(rhs->shape()) > 2 || ShapeUtil::Rank(dot->shape()) > 2) {
    return Status::OK();
  }

  // Replace a zero element dot with a broadcast of the constant 0.
  if (ShapeUtil::HasZeroElements(dot->shape()) ||
      ShapeUtil::HasZeroElements(lhs->shape()) ||
      ShapeUtil::HasZeroElements(rhs->shape())) {
    auto zero = computation_->AddInstruction(
        HloInstruction::CreateConstant(Literal::CreateR0(0.0f)));
    return ReplaceWithNewInstruction(
        dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {}));
  }

  TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_concat_optimized,
                      OptimizeDotOfConcat(dot));
  if (dot_of_concat_optimized) {
    VLOG(10) << "Replaced dot(concat(...), constant) with add(dot(..., "
                "constant)...)";
    return ReplaceInstruction(dot, dot_of_concat_optimized);
  }

  if (enable_dot_strength_reduction_ && !is_layout_sensitive_) {
    TF_ASSIGN_OR_RETURN(bool did_strength_reduction,
                        HandleDotStrengthReduction(dot));
    if (did_strength_reduction) {
      return Status::OK();
    }
  }

  // Simplify dot(transpose(a), transpose(b)) to transpose(dot(b,a)).
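  // This is the identity (B A)^T = A^T B^T read right to left: computing
  // dot(b, a) and transposing the result avoids materializing the two
  // transposed operands.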
  if (lhs->IsRank2Transpose() && rhs->IsRank2Transpose()) {
    DotDimensionNumbers dot_dimension_numbers;
    dot_dimension_numbers.add_lhs_contracting_dimensions(1);
    dot_dimension_numbers.add_rhs_contracting_dimensions(0);
    auto new_dot = computation_->AddInstruction(HloInstruction::CreateDot(
        ShapeUtil::PermuteDimensions({1, 0}, dot->shape()),
        rhs->mutable_operand(0), lhs->mutable_operand(0),
        dot_dimension_numbers));
    return ReplaceWithNewInstruction(
        dot, HloInstruction::CreateTranspose(dot->shape(), new_dot, {1, 0}));
  }

  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) {
  auto lhs = multiply->mutable_operand(0);
  auto rhs = multiply->mutable_operand(1);
  // A*1 => A
  VLOG(10) << "trying transform [A*1 => A]: " << multiply->ToString();
  if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(multiply, lhs)) {
    return Status::OK();
  }
  // 1*A => A
  VLOG(10) << "trying transform [1*A => A]: " << multiply->ToString();
  if (IsAll(lhs, 1) && ReplaceInstructionIfSameShape(multiply, rhs)) {
    return Status::OK();
  }

  // exp(A) * exp(B) => exp(A+B)
  if (lhs->opcode() == HloOpcode::kExp && rhs->opcode() == HloOpcode::kExp) {
    auto add = computation_->AddInstruction(HloInstruction::CreateBinary(
        multiply->shape(), HloOpcode::kAdd, lhs->mutable_operand(0),
        rhs->mutable_operand(0)));
    return ReplaceWithNewInstruction(
        multiply,
        HloInstruction::CreateUnary(multiply->shape(), HloOpcode::kExp, add));
  }
  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleLog(HloInstruction* log) {
  // ln(exp(A)) => A
  VLOG(10) << "trying transform [ln(exp(A)) => A]: " << log->ToString();
  auto operand = log->mutable_operand(0);
  if (operand->opcode() == HloOpcode::kExp &&
      ReplaceInstructionIfSameShape(log, operand->mutable_operand(0))) {
    return Status::OK();
  }

  // ln(pow(A,B)) => B*ln(A)
  if (operand->opcode() == HloOpcode::kPower) {
    auto new_log = computation_->AddInstruction(HloInstruction::CreateUnary(
        log->shape(), HloOpcode::kLog, operand->mutable_operand(0)));
    return ReplaceWithNewInstruction(
        log,
        HloInstruction::CreateBinary(log->shape(), HloOpcode::kMultiply,
                                     new_log, operand->mutable_operand(1)));
  }

  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleGetTupleElement(
    HloInstruction* get_tuple_element) {
  auto operand = get_tuple_element->mutable_operand(0);
  if (operand->opcode() == HloOpcode::kTuple) {
    // get_tuple_element(make_tuple({A_0, A_1, ..., A_n}), i) => A_i
    VLOG(10) << "trying transform "
             << "[get_tuple_element(make_tuple({...,A_i,...}), i)] => A_i: "
             << get_tuple_element->ToString();
    if (ReplaceInstructionIfSameShape(
            get_tuple_element,
            operand->mutable_operand(get_tuple_element->tuple_index()))) {
      return Status::OK();
    }
  }
  return Status::OK();
}

namespace {

// Returns whether the given reshape instruction leaves the dimensions at the
// given input indices unmodified, and if so, returns their output indices.
//
// Example:
//   input_dim_indices = {2, 3}
//   input  shape = T[a, b, x, y, cd]
//   output shape = T[ab, x, 1, y, c, d]
//   return value = {1, 3}
//
// Precondition: input_dim_indices is sorted.
std::pair<bool, std::vector<int64>> ReshapeLeavesDimensionsUnmodified(
    const HloInstruction* hlo,
    tensorflow::gtl::ArraySlice<int64> input_dim_indices) {
  CHECK_EQ(HloOpcode::kReshape, hlo->opcode());
  CHECK(std::is_sorted(input_dim_indices.begin(), input_dim_indices.end()));

  std::vector<int64> output_dim_indices;
  std::vector<std::pair<int64, int64>> unmodified_dims =
      ShapeUtil::DimensionsUnmodifiedByReshape(hlo->operand(0)->shape(),
                                               hlo->shape());
  size_t i = 0;  // index to unmodified_dims
  for (int64 input_dim_index : input_dim_indices) {
    // Search unmodified_dims for input_dim_index. We can search from the last
    // matching position because input_dim_indices is guaranteed to be sorted.
    while (i < unmodified_dims.size() &&
           unmodified_dims[i].first < input_dim_index) {
      ++i;
    }
    if (i >= unmodified_dims.size() ||
        unmodified_dims[i].first != input_dim_index) {
      return std::make_pair(false, std::vector<int64>());
    }
    output_dim_indices.push_back(unmodified_dims[i].second);
  }
  return std::make_pair(true, output_dim_indices);
}

// Returns true if the output of "instruction" is a permutation of the
// elements of "operand". Precondition: "operand" is an operand of
// "instruction".
bool OutputIsPermutationOfOperandElements(HloInstruction* instruction,
                                          HloInstruction* operand) {
  DCHECK(!instruction->OperandIndices(operand).empty());
  switch (instruction->opcode()) {
    case HloOpcode::kReshape:
    case HloOpcode::kReverse:
    case HloOpcode::kSort:
    case HloOpcode::kTranspose:
      return true;
    default:
      return false;
  }
}

// Returns true if the output of "instruction" is a subset of the elements of
// "operand". Precondition: "operand" is an operand of "instruction".
bool OutputIsSubsetOfOperandElements(HloInstruction* instruction,
                                     HloInstruction* operand) {
  std::vector<int64> operand_indices = instruction->OperandIndices(operand);
  CHECK(!operand_indices.empty());
  if (operand_indices.size() != 1) {
    return false;
  }
  int64 operand_index = operand_indices[0];
  switch (instruction->opcode()) {
    case HloOpcode::kSlice:
      CHECK_EQ(0, operand_index);
      return true;
    case HloOpcode::kDynamicSlice:
      return operand_index == 0;
    default:
      return false;
  }
}

}  // namespace

Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) {
  auto operand = broadcast->mutable_operand(0);
  // A degenerate broadcast that does not change the number of elements (and
  // whose dimensions are sorted) can be replaced by a reshape.
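  // For example (illustrative): broadcast(f32[2,3], dimensions={0,2}) to
  // f32[2,1,3] keeps the element count at 6 and has sorted dimensions, so it
  // is equivalent to reshape(f32[2,3] -> f32[2,1,3]).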
  if (std::is_sorted(broadcast->dimensions().begin(),
                     broadcast->dimensions().end()) &&
      ShapeUtil::ElementsIn(broadcast->shape()) ==
          ShapeUtil::ElementsIn(operand->shape())) {
    VLOG(10) << "transform broadcast(X) -> reshape(X) where "
                "n(broadcast(X)) == n(X)";
    return ReplaceWithNewInstruction(
        broadcast, HloInstruction::CreateReshape(broadcast->shape(), operand));
  }

  // A degenerate broadcast that has the same input and output rank can be
  // converted into a transpose.
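  // For example (illustrative): broadcast(f32[2,3], dimensions={1,0}) to
  // f32[3,2] preserves rank and element count, so it amounts to a transpose
  // that swaps the two dimensions.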
  if (ShapeUtil::Rank(broadcast->shape()) ==
          ShapeUtil::Rank(operand->shape()) &&
      ShapeUtil::ElementsIn(broadcast->shape()) ==
          ShapeUtil::ElementsIn(operand->shape())) {
    VLOG(10) << "transform broadcast(X) -> transpose(X) where "
                "n(broadcast(X)) == n(X)";
    return ReplaceWithNewInstruction(
        broadcast, HloInstruction::CreateTranspose(broadcast->shape(), operand,
                                                   broadcast->dimensions()));
  }

  // A broadcast of a reshape which merely inserts 1-sized dimensions can
  // elide its operand.
  {
    bool merely_inserts_or_deletes_1_sized_dimensions;
    std::vector<int64> inserted_indices, deleted_indices;
    std::tie(merely_inserts_or_deletes_1_sized_dimensions, deleted_indices,
             inserted_indices) =
        operand->ReshapeMerelyInsertsOrDeletes1SizedDimensions();
    if (merely_inserts_or_deletes_1_sized_dimensions &&
        deleted_indices.empty()) {
      std::reverse(inserted_indices.begin(), inserted_indices.end());
      auto dims = broadcast->dimensions();
      for (auto inserted_index : inserted_indices) {
        dims.erase(dims.begin() + inserted_index);
      }
      return ReplaceWithNewInstruction(
          broadcast,
          HloInstruction::CreateBroadcast(broadcast->shape(),
                                          operand->mutable_operand(0), dims));
    }
  }

  // A broadcast that feeds an element-wise operation, where it is the unique
  // non-scalar operand, can be sunk below that operation.
  TF_ASSIGN_OR_RETURN(
      bool sink_succeeded,
      TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand(broadcast));
  changed_ |= sink_succeeded;
  if (sink_succeeded) {
    return Status::OK();
  }

  // A scalar broadcast feeding an instruction which only permutes (reshape,
  // transpose, sort, reverse) or selects a subset of operand elements (slice,
  // dynamic slice) can be replaced with a broadcast directly to the output
  // shape of the instruction.
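  // For illustration (hypothetical example): if b = broadcast(f32[] c) ->
  // f32[10] and s = slice(b) over [0:5], then every element of s equals c, so
  // s can be replaced by broadcast(c) -> f32[5].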
  if (ShapeUtil::IsScalar(operand->shape())) {
    for (HloInstruction* user : broadcast->users()) {
      // Skip if the broadcast user has no uses itself.
      if (user->user_count() == 0 && user != computation_->root_instruction()) {
        continue;
      }
      if (OutputIsPermutationOfOperandElements(user, broadcast) ||
          OutputIsSubsetOfOperandElements(user, broadcast)) {
        VLOG(10) << "transform permuting/subset of a scalar broadcast into "
                 << "a single broadcast";
        HloInstruction* new_broadcast = computation_->AddInstruction(
            HloInstruction::CreateBroadcast(user->shape(), operand, {}));
        // Use HloInstruction::ReplaceAllUsesWith instead of
        // HloComputation::ReplaceWithNewInstruction because we are replacing an
        // instruction other than the visited instruction.
        changed_ = true;
        return user->ReplaceAllUsesWith(new_broadcast);
      }
    }
  }
  return Status::OK();
}

// A conversion to the same element type as the operand is a nop and can be
// removed.  A conversion of a constant can be simplified by making a new
// constant.
Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert) {
  PrimitiveType src_type = convert->operand(0)->shape().element_type();
  PrimitiveType dest_type = convert->shape().element_type();
  if (src_type == dest_type) {
    return ReplaceInstruction(convert, convert->mutable_operand(0));
  }
  return Status::OK();
}

// Complex(Real(c), Imag(c)) -> c
Status AlgebraicSimplifierVisitor::HandleComplex(HloInstruction* complex) {
  auto real = complex->mutable_operand(0);
  auto imag = complex->mutable_operand(1);
  if (real->opcode() == HloOpcode::kReal &&
      imag->opcode() == HloOpcode::kImag &&
      real->operand(0) == imag->operand(0)) {
    return ReplaceInstruction(complex, real->mutable_operand(0));
  }
  return Status::OK();
}

// Real(Complex(r, i)) -> r
Status AlgebraicSimplifierVisitor::HandleReal(HloInstruction* real) {
  auto operand = real->mutable_operand(0);
  if (operand->opcode() == HloOpcode::kComplex) {
    return ReplaceInstruction(real, operand->mutable_operand(0));
  }
  return Status::OK();
}

// Imag(Complex(r, i)) -> i
Status AlgebraicSimplifierVisitor::HandleImag(HloInstruction* imag) {
  auto operand = imag->mutable_operand(0);
  if (operand->opcode() == HloOpcode::kComplex) {
    return ReplaceInstruction(imag, operand->mutable_operand(1));
  }
  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) {
  // Padding a zero-element operand produces an array that contains only the
  // pad value, i.e. a broadcast of the pad value.
  if (ShapeUtil::HasZeroElements(pad->operand(0)->shape())) {
    return ReplaceWithNewInstruction(
        pad, HloInstruction::CreateBroadcast(pad->shape(),
                                             pad->mutable_operand(1), {}));
  }
  // Eliminate nop pads (padding all zero), and replace a pad with negative
  // padding with a pad with non-negative padding followed by a slice.
  bool all_zero = true;
  bool has_negative = false;
  for (auto& padding_dimension : pad->padding_config().dimensions()) {
    if (padding_dimension.edge_padding_low() < 0 ||
        padding_dimension.edge_padding_high() < 0) {
      has_negative = true;
    }
    if (padding_dimension.edge_padding_low() != 0 ||
        padding_dimension.edge_padding_high() != 0) {
      all_zero = false;
    }
  }

  if (all_zero) {
    ReplaceInstructionIfSameShape(pad, pad->mutable_operand(0));
    return Status::OK();
  }

  if (has_negative) {
    // Pad has negative padding. Replace with a pad with the non-negative
    // padding followed by a slice which effectively performs the negative
    // padding.
    // TODO(b/34628603): Add support for negative padding in the backends, or
    // change kPad semantics to disallow negative padding and use slice
    // instead.
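    // For illustration (hypothetical example): padding f32[5] with
    // edge_padding_low=-2 and edge_padding_high=1 yields f32[4]. This becomes
    // a pad with edge_padding_low=0 and edge_padding_high=1 (giving f32[6])
    // followed by the slice [2:6), which produces the same f32[4] result.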

    // First, construct the padding config with non-negative entries, then
    // compute the shape of this new pad instruction.
    PaddingConfig nonzero_padding = pad->padding_config();
    for (int i = 0; i < pad->padding_config().dimensions_size(); ++i) {
      PaddingConfig::PaddingConfigDimension* padding_dimension =
          nonzero_padding.mutable_dimensions(i);
      // Set negative padding to zero.
      if (padding_dimension->edge_padding_low() < 0) {
        padding_dimension->set_edge_padding_low(0);
      }
      if (padding_dimension->edge_padding_high() < 0) {
        padding_dimension->set_edge_padding_high(0);
      }
    }
    TF_ASSIGN_OR_RETURN(Shape nonzero_pad_shape,
                        ShapeInference::InferPadShape(pad->operand(0)->shape(),
                                                      pad->operand(1)->shape(),
                                                      nonzero_padding));
    // Copy the layout from the original pad instruction. The new pad and the
    // slice instruction should both have the same layout.
    TF_RETURN_IF_ERROR(
        LayoutUtil::CopyLayoutBetweenShapes(pad->shape(), &nonzero_pad_shape));
    HloInstruction* nonzero_pad = computation_->AddInstruction(
        HloInstruction::CreatePad(nonzero_pad_shape, pad->mutable_operand(0),
                                  pad->mutable_operand(1), nonzero_padding));

    // Second, construct the slice instruction to perform the negative padding.
    std::vector<int64> start_indices;
    std::vector<int64> end_indices;
    std::vector<int64> strides;
    for (int64 i = 0; i < pad->padding_config().dimensions_size(); ++i) {
      const PaddingConfig::PaddingConfigDimension& padding_dimension =
          pad->padding_config().dimensions(i);
      int64 start = 0;
      if (padding_dimension.edge_padding_low() < 0) {
        start = -padding_dimension.edge_padding_low();
      }
      int64 end = nonzero_pad_shape.dimensions(i);
      if (padding_dimension.edge_padding_high() < 0) {
        end += padding_dimension.edge_padding_high();
      }
      start_indices.push_back(start);
      end_indices.push_back(end);
      strides.push_back(1);
    }

    // Verify that the slice shape matches the pad shape.
    TF_ASSIGN_OR_RETURN(
        Shape inferred_slice_shape,
        ShapeInference::InferSliceShape(nonzero_pad_shape, start_indices,
                                        end_indices, strides));
    TF_RET_CHECK(ShapeUtil::Compatible(inferred_slice_shape, pad->shape()));

    std::unique_ptr<HloInstruction> slice = HloInstruction::CreateSlice(
        pad->shape(), nonzero_pad, start_indices, end_indices, strides);
    return ReplaceWithNewInstruction(pad, std::move(slice));
  }

  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
  VLOG(10) << "trying transform [pow(A, 0) => 1]: " << power->ToString();
  auto lhs = power->mutable_operand(0);
  auto rhs = power->mutable_operand(1);
  if (IsAll(rhs, 0)) {
    auto one = HloInstruction::CreateConstant(
        Literal::One(power->shape().element_type()).CloneToUnique());
    std::unique_ptr<HloInstruction> ones;
    if (ShapeUtil::IsScalar(power->shape())) {
      ones = std::move(one);
    } else {
      ones = HloInstruction::CreateBroadcast(
          power->shape(), computation_->AddInstruction(std::move(one)), {});
    }
    return ReplaceWithNewInstruction(power, std::move(ones));
  }

  VLOG(10) << "trying transform [pow(A, 1) => A]: " << power->ToString();
  if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(power, lhs)) {
    return Status::OK();
  }

  // pow(exp(A),B) => exp(A*B)
  if (lhs->opcode() == HloOpcode::kExp) {
    auto a_times_b = computation_->AddInstruction(HloInstruction::CreateBinary(
        power->shape(), HloOpcode::kMultiply, lhs->operands()[0], rhs));
    return ReplaceWithNewInstruction(
        power, HloInstruction::CreateUnary(power->shape(), HloOpcode::kExp,
                                           a_times_b));
  }
  VLOG(10) << "trying transform [pow(A, 2) => A*A]: " << power->ToString();
  if (IsAll(rhs, 2)) {
    return ReplaceWithNewInstruction(
        power, HloInstruction::CreateBinary(power->shape(),
                                            HloOpcode::kMultiply, lhs, lhs));
  }

  VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString();
  if (IsAll(rhs, -1)) {
    auto* one = computation_->AddInstruction(HloInstruction::CreateConstant(
        Literal::One(rhs->shape().element_type()).CloneToUnique()));

    // Explicitly broadcast the scalar 1 to the output shape, to avoid an
    // implicit broadcast in the divide HLO, since we are trying to eliminate
    // implicit broadcasting at the HLO level.
    auto* broadcast_one = computation_->AddInstruction(
        HloInstruction::CreateBroadcast(power->shape(), one, {}));
    return ReplaceWithNewInstruction(
        power, HloInstruction::CreateBinary(power->shape(), HloOpcode::kDivide,
                                            broadcast_one, lhs));
  }

  VLOG(10) << "trying transform [pow(pow(A, X), Y) => pow(A, X*Y)]: "
           << power->ToString();

  // Don't perform this optimization if either of the exponents is complex;
  // this identity is true only for real-valued exponents.  In addition, we
  // cowardly refuse to do this transformation if the two exponents have
  // different element types.
  if (lhs->opcode() == HloOpcode::kPower &&
      !ShapeUtil::ElementIsComplex(lhs->operand(1)->shape()) &&
      !ShapeUtil::ElementIsComplex(rhs->shape()) &&
      ShapeUtil::SameElementType(lhs->operand(1)->shape(), rhs->shape())) {
    auto exponent_product =
        computation_->AddInstruction(HloInstruction::CreateBinary(
            rhs->shape(), HloOpcode::kMultiply, lhs->mutable_operand(1), rhs));
    return ReplaceWithNewInstruction(
        power, HloInstruction::CreateBinary(power->shape(), HloOpcode::kPower,
                                            lhs->mutable_operand(0),
                                            exponent_product));
  }

  return Status::OK();
}

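// Illustrative sketch of the rewrite below (hypothetical example):
//   b = broadcast(f32[4] x) -> f32[2,4]; e = exp(b)
// becomes
//   e' = exp(x) -> f32[4]; b' = broadcast(e') -> f32[2,4]
// so the element-wise op runs on the smaller, pre-broadcast shape.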
StatusOr<bool> AlgebraicSimplifierVisitor::
    TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand(
        HloInstruction* reshape_or_broadcast) {
  bool changed = false;
  if (ShapeUtil::IsScalar(reshape_or_broadcast->shape())) {
    return false;
  }
  HloInstruction* operand = reshape_or_broadcast->mutable_operand(0);
  for (HloInstruction* user : reshape_or_broadcast->users()) {
    if (user->user_count() == 0 && user != computation_->root_instruction()) {
      continue;
    }
    // Do not move reshapes or broadcasts past copies since the shape the copy
    // will operate on will change.
    if (user->opcode() == HloOpcode::kCopy) {
      continue;
    }
    // Do not change the shape of fusion nodes in case there are multiple
    // shapes inside the fusion node already.
    if (user->opcode() == HloOpcode::kFusion) {
      continue;
    }
    if (!user->IsElementwise()) {
      continue;
    }

    int64 reshape_or_broadcast_operand_index = -1;
    // Find the unique non-scalar operand or continue if there isn't one.
    int64 scalar_count = 0;
    for (int64 i = 0; i < user->operand_count(); ++i) {
      if (ShapeUtil::IsScalar(user->operand(i)->shape())) {
        ++scalar_count;
      } else {
        reshape_or_broadcast_operand_index = i;
      }
    }
    if (scalar_count != user->operand_count() - 1) {
      continue;
    }
    VLOG(4) << "Sinking reshape or broadcast after user:";
    VLOG(4) << "  old reshape/broadcast: " << reshape_or_broadcast->ToString();
    VLOG(4) << "  old user: " << user->ToString();
    CHECK_EQ(user->operand(reshape_or_broadcast_operand_index),
             reshape_or_broadcast);
    auto new_user_operands = user->operands();
    new_user_operands[reshape_or_broadcast_operand_index] = operand;
    auto new_user = computation_->AddInstruction(user->CloneWithNewOperands(
        ShapeUtil::MakeShapeWithLayout(
            user->shape().element_type(),
            AsInt64Slice(operand->shape().dimensions()),
            LayoutUtil::MinorToMajor(operand->shape())),
        new_user_operands));
    VLOG(4) << "  new user: " << new_user->ToString();
    HloInstruction* new_reshape_or_broadcast = nullptr;
    if (reshape_or_broadcast->opcode() == HloOpcode::kReshape) {
      new_reshape_or_broadcast =
          computation_->AddInstruction(HloInstruction::CreateReshape(
              ShapeUtil::MakeShapeWithLayout(
                  user->shape().element_type(),
                  AsInt64Slice(reshape_or_broadcast->shape().dimensions()),
                  LayoutUtil::MinorToMajor(reshape_or_broadcast->shape())),
              new_user));
    } else {
      TF_RET_CHECK(reshape_or_broadcast->opcode() == HloOpcode::kBroadcast);
      new_reshape_or_broadcast =
          computation_->AddInstruction(HloInstruction::CreateBroadcast(
              ShapeUtil::MakeShapeWithLayout(
                  user->shape().element_type(),
                  AsInt64Slice(reshape_or_broadcast->shape().dimensions()),
                  LayoutUtil::MinorToMajor(reshape_or_broadcast->shape())),
              new_user, reshape_or_broadcast->dimensions()));
    }
    VLOG(4) << "  new reshape/broadcast: "
            << new_reshape_or_broadcast->ToString();
    TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(new_reshape_or_broadcast));
    changed = true;
  }
  return changed;
}

Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) {
  auto operand = reshape->mutable_operand(0);

  // Reshape directly to an empty constant if the shape contains a zero-element
  // dimension.
  if (ShapeUtil::HasZeroElements(reshape->shape())) {
    auto empty_constant = HloInstruction::CreateConstant(
        Literal::CreateFromShape(reshape->shape()));

    return ReplaceWithNewInstruction(reshape, std::move(empty_constant));
  }

  // Delete no-op reshapes, i.e. where shape = operand shape.
  if (SameShape(reshape, operand)) {
    VLOG(10) << "deleting no-op reshape";
    return ReplaceInstruction(reshape, operand);
  }

  // Merge reshapes.
  if (HloOpcode::kReshape == operand->opcode()) {
    return ReplaceWithNewInstruction(
        reshape, HloInstruction::CreateReshape(reshape->shape(),
                                               operand->mutable_operand(0)));
  }

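  // A reshape of a broadcast can become a broadcast to the reshaped shape if
  // the reshape leaves the broadcast's operand dimensions unmodified. For
  // illustration (hypothetical example): if b = broadcast(f32[4] x) ->
  // f32[2,3,4] with dimensions={2} and r = reshape(b) -> f32[6,4], dimension 2
  // of b survives the reshape as dimension 1, so r can be rewritten as
  // broadcast(x) -> f32[6,4] with dimensions={1}.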
  if (HloOpcode::kBroadcast == reshape->operand(0)->opcode()) {
    auto opt_dims = ReshapeLeavesDimensionsUnmodified(
        reshape, reshape->operand(0)->dimensions());
    if (opt_dims.first) {
      return ReplaceWithNewInstruction(
          reshape,
          HloInstruction::CreateBroadcast(
              reshape->shape(), reshape->mutable_operand(0)->mutable_operand(0),
              opt_dims.second));
    }
  }

  // A reshape that feeds an element-wise operation, where it is the unique
  // non-scalar operand, can be sunk below that operation.
  TF_ASSIGN_OR_RETURN(
      bool sink_succeeded,
      TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand(reshape));
  changed_ |= sink_succeeded;
  if (sink_succeeded) {
    return Status::OK();
  }

  // Make this a bitcast if possible.
  if (is_layout_sensitive_ &&
      ReshapeIsBitcast(reshape, valid_bitcast_callback_)) {
    ReplaceWithBitcast(reshape);
    return Status::OK();
  }

  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleReverse(HloInstruction* reverse) {
  // When all the dimensions to reverse are trivial (i.e. the bound is 1),
  // there is nothing to be done.
  auto dim_is_one = [&](int64 i) -> bool {
    return reverse->shape().dimensions(i) == 1;
  };
  if (std::all_of(reverse->dimensions().begin(), reverse->dimensions().end(),
                  dim_is_one)) {
    return ReplaceInstruction(reverse, reverse->mutable_operand(0));
  }
  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) {
  // Delete no-op slices, i.e. where shape = operand shape.
  if (ReplaceInstructionIfSameShape(slice, slice->mutable_operand(0))) {
    return Status::OK();
  }
  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleDynamicSlice(
    HloInstruction* dynamic_slice) {
  auto operand = dynamic_slice->mutable_operand(0);
  auto start_indices = dynamic_slice->operand(1);
  if (ShapeUtil::IsScalar(dynamic_slice->shape())) {
    return ReplaceInstruction(dynamic_slice, operand);
  }
  // DynamicSlice where operand has the same size as the output and
  // start_indices are all zero is simply equal to operand.
  if (IsAll(start_indices, 0) && SameShape(operand, dynamic_slice)) {
    return ReplaceInstruction(dynamic_slice, operand);
  }
  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice(
    HloInstruction* dynamic_update_slice) {
  auto update = dynamic_update_slice->mutable_operand(1);
  auto start_indices = dynamic_update_slice->operand(2);
  // DynamicUpdateSlice on a scalar just passes through the update argument.
  if (ShapeUtil::IsScalar(dynamic_update_slice->shape())) {
    return ReplaceInstruction(dynamic_update_slice, update);
  }

  // DynamicUpdateSlice where operand and update have the same size and
  // start_indices are all zero is simply equal to update.
  //
  // (We require start_indices to be all zero because we want this optimization
  // not to affect the visible behavior of this op even when the indices are out
  // of range.  Currently dynamic-update-slice wraps out-of-range indices, so
  // we can only remove the op if its indices never wrap.)
  if (IsAll(start_indices, 0) && SameShape(dynamic_update_slice, update)) {
    return ReplaceInstruction(dynamic_update_slice, update);
  }
  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) {
  auto arg = reduce->mutable_operand(0);
  auto init_value = reduce->mutable_operand(1);
  tensorflow::gtl::ArraySlice<int64> dimensions(reduce->dimensions());
  HloComputation* function = reduce->to_apply();
  if (ShapeUtil::HasZeroElements(arg->shape()) ||
      ShapeUtil::HasZeroElements(reduce->shape())) {
    return ReplaceWithNewInstruction(
        reduce,
        HloInstruction::CreateBroadcast(reduce->shape(), init_value, {}));
  }

  // A transpose feeding a reduce can simply permute the reduction dimensions
  // field if the output of the reduce is a vector or scalar. A higher-rank
  // result may require a transpose of the output.
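  // For illustration (hypothetical example): reduce(transpose(f32[2,3] a,
  // {1,0}), dims={0}) reduces over the transposed dimension 0, which is
  // dimension 1 of a, so it equals reduce(a, dims={1}).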
  if (ShapeUtil::Rank(reduce->shape()) <= 1 &&
      arg->opcode() == HloOpcode::kTranspose) {
    auto transpose_dimensions = arg->dimensions();
    std::vector<int64> new_reduce_dimensions;
    for (auto dim : dimensions) {
      new_reduce_dimensions.push_back(transpose_dimensions[dim]);
    }
    return ReplaceWithNewInstruction(
        reduce, HloInstruction::CreateReduce(
                    reduce->shape(), arg->mutable_operand(0), init_value,
                    new_reduce_dimensions, function));
  }

  // A reshape that collapses multiple dimensions into a dimension being
  // reduced can just reduce all of those dimensions instead of doing a
  // collapsing reshape before a reduction.
  if (arg->opcode() == HloOpcode::kReshape) {
    std::vector<std::pair<int64, int64>> unmodified_dims =
        ShapeUtil::DimensionsUnmodifiedByReshape(arg->operand(0)->shape(),
                                                 arg->shape());
    std::vector<bool> arg_dim_in_output(ShapeUtil::Rank(arg->shape()), true);
    std::vector<bool> arg_dim_unmodified(ShapeUtil::Rank(arg->shape()), false);
    for (auto dim : dimensions) {
      arg_dim_in_output[dim] = false;
    }
    for (auto dim_pair : unmodified_dims) {
      arg_dim_unmodified[dim_pair.second] = true;
    }
    // The goal is to verify that all dimensions that are not removed in the
    // reduce are unmodified by the reshape. For example:
    // reduce(reshape([A,B*C], a[A,B,C]),[1]) = reduce(a[A, B, C], [1, 2])
    bool can_move_reshape_into_reduce = true;
    for (int64 i = 0; i < arg_dim_in_output.size(); ++i) {
      if (arg_dim_in_output[i] && !arg_dim_unmodified[i]) {
        can_move_reshape_into_reduce = false;
      }
    }
    if (can_move_reshape_into_reduce) {
      changed_ = true;
      std::unordered_set<int64> dimensions_not_to_reduce;
      for (auto dim_pair : unmodified_dims) {
        if (arg_dim_in_output[dim_pair.second]) {
          dimensions_not_to_reduce.insert(dim_pair.first);
        }
      }
      std::vector<int64> new_reduce_dimensions;
      for (int64 i = 0; i < ShapeUtil::Rank(arg->operand(0)->shape()); ++i) {
        if (dimensions_not_to_reduce.count(i) == 0) {
          new_reduce_dimensions.push_back(i);
        }
      }
      return ReplaceWithNewInstruction(
          reduce, HloInstruction::CreateReduce(
                      reduce->shape(), arg->mutable_operand(0), init_value,
                      new_reduce_dimensions, function));
    }
  }
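  // If the reduce does not actually reduce anything (the output has as many
  // elements as the input, e.g. every reduced dimension has size 1), each
  // output element is the reduction function applied once to the matching
  // input element and the init value. For illustration (hypothetical
  // example): reduce(f32[4,1] a, c, dims={1}, add) is
  // map(add, reshape(a) -> f32[4], c).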
  if (ShapeUtil::ElementsIn(reduce->shape()) ==
          ShapeUtil::ElementsIn(arg->shape()) ||
      ShapeUtil::HasZeroElements(arg->shape())) {
    auto reshape = computation_->AddInstruction(
        HloInstruction::CreateReshape(reduce->shape(), arg));
    return ReplaceWithNewInstruction(
        reduce, HloInstruction::CreateMap(reduce->shape(),
                                          {reshape, init_value}, function));
  }
  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleReduceWindow(
    HloInstruction* reduce_window) {
  if (ShapeUtil::HasZeroElements(reduce_window->operand(0)->shape())) {
    return ReplaceWithNewInstruction(
        reduce_window,
        HloInstruction::CreateBroadcast(reduce_window->shape(),
                                        reduce_window->mutable_operand(1), {}));
  }
  auto operand = reduce_window->mutable_operand(0);
  const Window& window = reduce_window->window();
  auto function = reduce_window->to_apply();
  if (ShapeUtil::IsScalar(operand->shape())) {
    TF_RET_CHECK(ShapeUtil::IsScalar(reduce_window->shape()));
    return ReplaceWithNewInstruction(
        reduce_window,
        HloInstruction::CreateMap(reduce_window->shape(),
                                  {operand, reduce_window->mutable_operand(1)},
                                  function));
  }

  VLOG(10) << "Considering folding Pad: " << operand->ToString()
           << "\ninto reduce-window: " << reduce_window->ToString();

  // This optimization folds a pad op into reduce_window.
  if (operand->opcode() != HloOpcode::kPad) {
    VLOG(10) << "Not folding pad into reduce-window as there is no pad.";
    return Status::OK();
  }

  // Do not fold interior padding into ReduceWindow since the backends do not
  // support it.
  const PaddingConfig& pad_config = operand->padding_config();
  if (HasInteriorPadding(pad_config)) {
    VLOG(10) << "Not folding pad into reduce-window due to interior padding.";
    return Status::OK();
  }

  // If reduce_window already has padding, the pad value of the pad op and the
  // init value of reduce_window must match to allow folding the pad.
  const HloInstruction* pad_value = operand->operand(1);
  const HloInstruction* reduce_init_value = reduce_window->operand(1);
  if (pad_value != reduce_init_value) {
    // The pad value is usually a constant, so we handle that case and do not
    // try to get more fancy about proving equivalence in cases beyond that.
    if (pad_value->opcode() != HloOpcode::kConstant ||
        reduce_init_value->opcode() != HloOpcode::kConstant ||
        pad_value->literal() != reduce_init_value->literal()) {
      VLOG(10) << "Not folding pad into reduce-window due to different pad "
                  "values.";
      return Status::OK();
    }
  }

  // If the pad puts a single non-identity value in each window that we're
  // reducing, then this is a broadcast.
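  // For illustration (hypothetical example, assuming the init value is the
  // identity of the reduction, e.g. 0 for add): padding f32[1] x by 2 on each
  // side gives f32[5] = [init, init, x, init, init]; a reduce-window of size 3
  // and stride 1 then produces f32[3], where every window contains x exactly
  // once, i.e. a broadcast of x.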
  HloInstruction* pad_operand = operand->mutable_operand(0);
  auto is_effective_broadcast = [&] {
    if (window_util::HasStride(window)) {
      VLOG(10) << "Window has stride.";
      return false;
    }
    if (!window_util::HasSymmetricPadding(pad_config)) {
      VLOG(10) << "Window has uneven padding.";
      return false;
    }
    for (int64 i = 0; i < pad_config.dimensions_size(); ++i) {
      const auto& pad_dimension = pad_config.dimensions(i);
      if ((pad_dimension.edge_padding_low() != 0 ||
           pad_dimension.edge_padding_high() != 0) &&
          pad_operand->shape().dimensions(i) != 1) {
        VLOG(10) << "Found non-trivial dimension being padded: " << i;
        return false;
      }
    }
    VLOG(10) << "Found to be padding trivial dimensions only.";

    for (int64 i = 0; i < window.dimensions_size(); ++i) {
      const auto& pad_dimension = pad_config.dimensions(i);
      const WindowDimension& window_dimension = window.dimensions(i);
      bool dimension_has_padding = (pad_dimension.edge_padding_low() != 0 ||
                                    pad_dimension.edge_padding_high() != 0);
      if (dimension_has_padding &&
          window_dimension.size() < pad_dimension.edge_padding_low() + 1) {
        VLOG(10) << "Found window did not cover single unpadded element in "
                    "dimension: "
                 << i;
        return false;
      }
      if (pad_operand->shape().dimensions(i) != 1 &&
          window_dimension.size() != 1) {
        VLOG(10) << "Found window covers more than one element in non-trivial "
                    "dimension: "
                 << i;
        return false;
      }
    }
    VLOG(10) << "Found window covers a single unpadded element.";
    return true;
  };
  if (is_effective_broadcast()) {
    VLOG(10) << "Replacing pad/reduce-window with (implicit) broadcast.";
    auto fadd = [this](std::unique_ptr<HloInstruction> x) {
      return computation_->AddInstruction(std::move(x));
    };
    return ReplaceWithNewInstruction(
        reduce_window, HloInstruction::CreateBroadcastSequence(
                           /*output_shape=*/reduce_window->shape(),
                           /*operand=*/pad_operand, fadd));
  }

  // Carry out the folding of the pad into reduce_window.
  VLOG(10) << "Folding pad into reduce-window.";
  Window new_window = window;
  const int64 rank = ShapeUtil::Rank(reduce_window->shape());
  TF_RET_CHECK(pad_config.dimensions_size() == rank);
  TF_RET_CHECK(window.dimensions_size() == rank);
  for (int64 i = 0; i < rank; ++i) {
    const auto& pad_dim = pad_config.dimensions(i);
    auto& window_dim = *new_window.mutable_dimensions(i);
    window_dim.set_padding_low(window_dim.padding_low() +
                               pad_dim.edge_padding_low());
    window_dim.set_padding_high(window_dim.padding_high() +
                                pad_dim.edge_padding_high());
  }
  return ReplaceWithNewInstruction(
      reduce_window, HloInstruction::CreateReduceWindow(
                         /*shape=*/reduce_window->shape(),
                         /*operand=*/pad_operand,
                         /*init_value=*/reduce_window->mutable_operand(1),
                         /*window=*/new_window,
                         /*reduce_computation=*/function));
}

Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) {
  auto operand = transpose->mutable_operand(0);
  if (std::is_sorted(transpose->dimensions().begin(),
                     transpose->dimensions().end())) {
    VLOG(10) << "deleting no-op transpose";
    return ReplaceInstruction(transpose, operand);
  }

  if (HloOpcode::kTranspose == operand->opcode()) {
    return ReplaceWithNewInstruction(
        transpose, HloInstruction::CreateTranspose(
                       transpose->shape(), operand->mutable_operand(0),
                       ComposePermutations(operand->dimensions(),
                                           transpose->dimensions())));
  }

  if (is_layout_sensitive_ && TransposeIsBitcast(transpose)) {
    ReplaceWithBitcast(transpose);
    return Status::OK();
  }

  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleConvolution(
    HloInstruction* convolution) {
  auto lhs = convolution->mutable_operand(0);
  auto rhs = convolution->mutable_operand(1);
  if (ShapeUtil::HasZeroElements(lhs->shape()) ||
      ShapeUtil::HasZeroElements(rhs->shape())) {
    return ReplaceWithNewInstruction(
        convolution,
        HloInstruction::CreateBroadcast(
            convolution->shape(),
            computation_->AddInstruction(HloInstruction::CreateConvert(
                ShapeUtil::MakeShape(convolution->shape().element_type(), {}),
                computation_->AddInstruction(
                    HloInstruction::CreateConstant(Literal::CreateR0(0.0f))))),
            {}));
  }
  const auto& window = convolution->window();
  if (!enable_conv_simplification_) {
    return Status::OK();
  }
  // HandleConvolution tries to replace a convolution with a DOT instruction.
  //
  // Only do this when bitcasts can be used:
  // - if bitcasts are not supported, then reshapes could be used but will
  //   end up with another copy.
  // - if bitcasts are supported, the simplifier will be called again with
  //   bitcasts_ == true.

  // TODO(cwhipkey): b/31337498, make this layout insensitive.
  if (!is_layout_sensitive_) {
    return Status::OK();
  }

  const ConvolutionDimensionNumbers& dnums =
      convolution->convolution_dimension_numbers();
  const Shape& input_shape = lhs->shape();
  const Shape& filter_shape = rhs->shape();
  const Shape& convolution_shape = convolution->shape();
  TF_RET_CHECK(LayoutUtil::HasLayout(input_shape));
  TF_RET_CHECK(LayoutUtil::HasLayout(filter_shape));
  TF_RET_CHECK(LayoutUtil::HasLayout(convolution_shape));

  // Require the spatial dimensions in the kernel to have a bound of one.
  for (int64 i = 0; i < dnums.kernel_spatial_dimensions_size(); ++i) {
    if (filter_shape.dimensions(dnums.kernel_spatial_dimensions(i)) != 1) {
      return Status::OK();
    }
  }

  // Stride ignores part of the output, which matrix multiplication does not
  // do, so require no stride. Padding and base (lhs) dilation both implicitly
  // extend the data, which matrix multiplication also does not do, so require
  // no padding and no base (lhs) dilation. Window (rhs) dilation has no effect
  // for a 1x1 window, so window dilation is no problem.
  if (window_util::HasStride(window) || window_util::HasPadding(window) ||
      window_util::HasBaseDilation(window)) {
    return Status::OK();
  }

  // Also, the shapes must align for a rowmajor matmul:
  // - the input and output have the same layout.
  // - for input/output, the channel dimension must be the most minor. Other
  //   spatial dims can be in any order.
  // - for filters, the input channel dimension must be more major than the
  //   output channel dimension. The width+height don't matter because
  //   they are 1.
  //
  // These constraints are harsh. If the channel dimension is the most major
  // and/or the layout of input/output feature dimensions are reversed, we can
  // still convert Conv into more efficient Matmul with operand transposition
  // (such as the transposition flags in cuBLAS SGEMM).
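  // For illustration (hypothetical example, assuming NHWC input and HWIO
  // filters with row-major layouts): a 1x1 convolution of f32[8,5,5,16] with
  // f32[1,1,16,32] producing f32[8,5,5,32] is bitwise identical to the
  // row-major matmul [8*5*5,16] x [16,32] = [200,32], bitcast back to
  // f32[8,5,5,32].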
  if (!LayoutUtil::Equal(input_shape.layout(), convolution_shape.layout()) ||
      LayoutUtil::Minor(input_shape.layout(), 0) !=
          dnums.input_feature_dimension() ||
      LayoutUtil::Minor(convolution_shape.layout(), 0) !=
          dnums.output_feature_dimension() ||
      // The input feature dimension should come later in the minor-to-major
      // order.
      (PositionInContainer(LayoutUtil::MinorToMajor(filter_shape),
                           dnums.kernel_input_feature_dimension()) <
       PositionInContainer(LayoutUtil::MinorToMajor(filter_shape),
                           dnums.kernel_output_feature_dimension()))) {
    return Status::OK();
  }

  auto add_bitcast = [&](Shape shape, HloInstruction* operand) {
    return computation_->AddInstruction(
        HloInstruction::CreateUnary(shape, HloOpcode::kBitcast, operand));
  };

  // Replace it with a dot, with bitcasts around it to get the right shape.
  const int64 input_channels =
      input_shape.dimensions(dnums.input_feature_dimension());
  const int64 output_channels =
      filter_shape.dimensions(dnums.kernel_output_feature_dimension());

  // Computes the product of the non-feature dimensions.
  int64 conv_width = 1;
  for (int i = 0; i < input_shape.dimensions_size(); ++i) {
    if (i != dnums.input_feature_dimension()) {
      conv_width *= input_shape.dimensions(i);
    }
  }

  // We already checked feature_dimension is most minor, so data in input_shape
  // and row-major {conv_width,input_channels} are bitwise identical.
  const Shape new_input_shape = ShapeUtil::MakeShapeWithDescendingLayout(
      input_shape.element_type(), {conv_width, input_channels});
  // We already checked input_feature_dimension is more major than
  // output_feature_dimension, so data in filter_shape and row-major
  // {input_channels,output_channels} are bitwise identical.
  const Shape new_filter_shape = ShapeUtil::MakeShapeWithDescendingLayout(
      filter_shape.element_type(), {input_channels, output_channels});
  const Shape dot_output_shape = ShapeUtil::MakeShapeWithDescendingLayout(
      convolution_shape.element_type(), {conv_width, output_channels});

  // We cannot insert bitcasts if the layouts will not be compatible.
  // TODO(b/33178038): Consider inserting a transpose if a bitcast would be
  // invalid.
  if (!valid_bitcast_callback_(input_shape, new_input_shape) ||
      !valid_bitcast_callback_(filter_shape, new_filter_shape) ||
      !valid_bitcast_callback_(dot_output_shape, convolution_shape)) {
    return Status::OK();
  }

  auto new_lhs = add_bitcast(new_input_shape, lhs);
  auto new_rhs = add_bitcast(new_filter_shape, rhs);
  DotDimensionNumbers dot_dimension_numbers;
  dot_dimension_numbers.add_lhs_contracting_dimensions(1);
  dot_dimension_numbers.add_rhs_contracting_dimensions(0);
  auto dot = computation_->AddInstruction(HloInstruction::CreateDot(
      dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers));
  return ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot));
}

bool AlgebraicSimplifierVisitor::TransformToClampIfSameShape(
    HloInstruction* root, HloInstruction* min, HloInstruction* min_operand,
    HloInstruction* operand, HloInstruction* max, HloInstruction* max_operand) {
  // Ensure shapes of min and max operand are equal to match current shape
  // inference.
  if (!SameShape(min_operand, max_operand)) {
    return false;
  }

  auto clamp = HloInstruction::CreateTernary(root->shape(), HloOpcode::kClamp,
                                             max_operand, operand, min_operand);
  TF_CHECK_OK(ReplaceWithNewInstruction(root, std::move(clamp)));
  return true;
}

Status AlgebraicSimplifierVisitor::HandleMaximum(HloInstruction* maximum) {
  // Match the following tree:
  //          min_operand     operand
  //                     \   /
  //      max_operand     min
  //                 \   /
  //                  max
  // where max_operand and min_operand are scalar constants.
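  // In other words (with hypothetical names): max(min(x, hi), lo), with
  // scalar constants hi and lo, becomes clamp(lo, x, hi).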
  {
    HloInstruction* min;
    HloInstruction* max_operand;
    HloInstruction* min_operand;
    HloInstruction* operand;

    if (hlo_query::MatchBinaryInstructionOperandOpcode(
            HloOpcode::kMinimum, maximum,
            /*matching_operand=*/&min,
            /*other_operand=*/&max_operand) &&
        hlo_query::MatchBinaryInstructionOperand(
            hlo_query::IsScalarConstant, min,
            /*matching_operand=*/&min_operand,
            /*other_operand=*/&operand) &&
        TransformToClampIfSameShape(maximum, min, min_operand, operand, maximum,
                                    max_operand)) {
      return Status::OK();
    }
  }

  return Status::OK();
}

Status AlgebraicSimplifierVisitor::HandleMinimum(HloInstruction* minimum) {
  // Match the following tree:
  //          max_operand     operand
  //                     \   /
  //      min_operand     max
  //                 \   /
  //                  min
  // where max_operand and min_operand are scalar constants.
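  // In other words (with hypothetical names): min(max(x, lo), hi), with
  // scalar constants lo and hi, also becomes clamp(lo, x, hi).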
  {
    HloInstruction* max;
    HloInstruction* max_operand;
    HloInstruction* min_operand;
    HloInstruction* operand;

    if (hlo_query::MatchBinaryInstructionOperandOpcode(
            HloOpcode::kMaximum, minimum,
            /*matching_operand=*/&max,
            /*other_operand=*/&min_operand) &&
        hlo_query::MatchBinaryInstructionOperand(
            hlo_query::IsScalarConstant, max,
            /*matching_operand=*/&max_operand,
            /*other_operand=*/&operand) &&
        TransformToClampIfSameShape(minimum, minimum, min_operand, operand, max,
                                    max_operand)) {
      return Status::OK();
    }
  }

  return Status::OK();
}

StatusOr<bool> AlgebraicSimplifier::Run(HloModule* module) {
  XLA_VLOG_LINES(2,
                 "AlgebraicSimplifier::Run(), before:\n" + module->ToString());
  bool changed = false;
  for (auto* comp : module->MakeNonfusionComputations()) {
    if (AlgebraicSimplifierVisitor::Run(
            comp, is_layout_sensitive_, valid_bitcast_callback_,
            enable_dot_strength_reduction_, enable_conv_simplification_)) {
      changed = true;
    }
  }
  XLA_VLOG_LINES(2,
                 "AlgebraicSimplifier::Run(), after:\n" + module->ToString());
  return changed;
}

}  // namespace xla