Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
     17 
     18 #include <memory>
     19 #include <utility>
     20 
     21 #include "tensorflow/compiler/xla/layout_util.h"
     22 #include "tensorflow/compiler/xla/literal_util.h"
     23 #include "tensorflow/compiler/xla/ptr_util.h"
     24 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     25 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     26 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
     27 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
     28 #include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
     29 #include "tensorflow/compiler/xla/shape_util.h"
     30 #include "tensorflow/compiler/xla/test.h"
     31 #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
     32 #include "tensorflow/compiler/xla/types.h"
     33 #include "tensorflow/compiler/xla/window_util.h"
     34 #include "tensorflow/compiler/xla/xla_data.pb.h"
     35 #include "tensorflow/core/lib/core/status_test_util.h"
     36 #include "tensorflow/core/lib/strings/str_util.h"
     37 
     38 namespace xla {
     39 namespace {
     40 
     41 namespace op = xla::testing::opcode_matchers;
     42 
     43 AlgebraicSimplifier::ValidBitcastCallback bitcasting_callback() {
     44   return [](const Shape&, const Shape&) { return true; };
     45 }
     46 
     47 AlgebraicSimplifier::ValidBitcastCallback non_bitcasting_callback() {
     48   return [](const Shape&, const Shape&) { return false; };
     49 }
     50 
     51 class AlgebraicSimplifierTest : public HloVerifiedTestBase {};
     52 
     53 // Test that A + 0 is simplified to A
     54 TEST_F(AlgebraicSimplifierTest, AddZero) {
     55   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
     56   HloComputation::Builder builder(TestName());
     57   HloInstruction* param0 = builder.AddInstruction(
     58       HloInstruction::CreateParameter(0, r0f32, "param0"));
     59   HloInstruction* zero = builder.AddInstruction(
     60       HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
     61   builder.AddInstruction(
     62       HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, zero));
     63 
     64   auto computation = module().AddEntryComputation(builder.Build());
     65   HloInstruction* root = computation->root_instruction();
     66   EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
     67   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
     68                                  non_bitcasting_callback());
     69   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
     70   root = computation->root_instruction();
     71   EXPECT_EQ(root, param0);
     72 }
     73 
     74 // Test that Const + A is canonicalized to A + Const.
     75 TEST_F(AlgebraicSimplifierTest, AddConstOnLHS) {
     76   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
     77   HloComputation::Builder builder(TestName());
     78   HloInstruction* param0 = builder.AddInstruction(
     79       HloInstruction::CreateParameter(0, r0f32, "param0"));
     80   HloInstruction* constant = builder.AddInstruction(
     81       HloInstruction::CreateConstant(Literal::CreateR0(42.0f)));
     82   builder.AddInstruction(
     83       HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, constant, param0));
     84 
     85   auto computation = module().AddEntryComputation(builder.Build());
     86   HloInstruction* root = computation->root_instruction();
     87   EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
     88   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
     89                                  non_bitcasting_callback());
     90   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
     91   root = computation->root_instruction();
     92   EXPECT_THAT(root, op::Add(param0, op::Constant()));
     93 }
     94 
     95 // Test that [(A + C1) + C2] => [A + (C1 + C2)] for constants C1 and C2.
     96 TEST_F(AlgebraicSimplifierTest, AddReassociateMergeConstants) {
     97   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
     98   HloComputation::Builder builder(TestName());
     99   HloInstruction* param0 = builder.AddInstruction(
    100       HloInstruction::CreateParameter(0, r0f32, "param0"));
    101   HloInstruction* constant1 = builder.AddInstruction(
    102       HloInstruction::CreateConstant(Literal::CreateR0(42.0f)));
    103   HloInstruction* constant2 = builder.AddInstruction(
    104       HloInstruction::CreateConstant(Literal::CreateR0(3.14159f)));
    105 
    106   HloInstruction* add1 = builder.AddInstruction(
    107       HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, constant1));
    108   builder.AddInstruction(
    109       HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, add1, constant2));
    110 
    111   auto computation = module().AddEntryComputation(builder.Build());
    112   HloInstruction* root = computation->root_instruction();
    113   EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
    114   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
    115                                  non_bitcasting_callback());
    116   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
    117   root = computation->root_instruction();
    118   EXPECT_THAT(root, op::Add(param0, op::Add(constant1, constant2)));
    119 }
    120 
    121 TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) {
    122   Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 2});
    123   HloComputation::Builder builder(TestName());
    124   HloInstruction* param0 = builder.AddInstruction(
    125       HloInstruction::CreateParameter(0, r2f32, "param0"));
    126   HloInstruction* zero = builder.AddInstruction(
    127       HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
    128   HloInstruction* bcast = builder.AddInstruction(
    129       HloInstruction::CreateBroadcast(r2f32, zero, {0, 1}));
    130   builder.AddInstruction(
    131       HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, bcast, param0));
    132 
    133   auto computation = module().AddEntryComputation(builder.Build());
    134   HloInstruction* root = computation->root_instruction();
    135   EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
    136   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
    137                                  non_bitcasting_callback());
    138   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
    139   root = computation->root_instruction();
    140   EXPECT_EQ(root, param0);
    141 }
    142 
    143 TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) {
    144   Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 2});
    145   HloComputation::Builder builder(TestName());
    146   HloInstruction* param0 = builder.AddInstruction(
    147       HloInstruction::CreateParameter(0, r2f32, "param0"));
    148   HloInstruction* zero = builder.AddInstruction(
    149       HloInstruction::CreateConstant(Literal::CreateR1<float>({0, 0, 0})));
    150   HloInstruction* bcast =
    151       builder.AddInstruction(HloInstruction::CreateBroadcast(r2f32, zero, {1}));
    152   builder.AddInstruction(
    153       HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, bcast, param0));
    154 
    155   auto computation = module().AddEntryComputation(builder.Build());
    156   HloInstruction* root = computation->root_instruction();
    157   EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
    158   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
    159                                  non_bitcasting_callback());
    160   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
    161   root = computation->root_instruction();
    162   EXPECT_EQ(root, param0);
    163 }
    164 
    165 // Test that A - 0 is simplified to A
    166 TEST_F(AlgebraicSimplifierTest, SubZero) {
    167   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
    168   HloComputation::Builder builder(TestName());
    169   HloInstruction* param0 = builder.AddInstruction(
    170       HloInstruction::CreateParameter(0, r0f32, "param0"));
    171   HloInstruction* zero = builder.AddInstruction(
    172       HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
    173   builder.AddInstruction(
    174       HloInstruction::CreateBinary(r0f32, HloOpcode::kSubtract, param0, zero));
    175 
    176   auto computation = module().AddEntryComputation(builder.Build());
    177   HloInstruction* root = computation->root_instruction();
    178   EXPECT_EQ(root->opcode(), HloOpcode::kSubtract);
    179   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
    180                                  non_bitcasting_callback());
    181   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
    182   root = computation->root_instruction();
    183   EXPECT_EQ(root, param0);
    184 }
    185 
    186 // Test that A - Const is canonicalized to A + (-Const).
    187 TEST_F(AlgebraicSimplifierTest, SubConstCanonicalization) {
    188   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
    189   HloComputation::Builder builder(TestName());
    190   HloInstruction* param0 = builder.AddInstruction(
    191       HloInstruction::CreateParameter(0, r0f32, "param0"));
    192   HloInstruction* constant = builder.AddInstruction(
    193       HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
    194   builder.AddInstruction(HloInstruction::CreateBinary(
    195       r0f32, HloOpcode::kSubtract, param0, constant));
    196 
    197   auto computation = module().AddEntryComputation(builder.Build());
    198   HloInstruction* root = computation->root_instruction();
    199   EXPECT_EQ(root->opcode(), HloOpcode::kSubtract);
    200   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
    201                                  non_bitcasting_callback());
    202   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
    203   root = computation->root_instruction();
    204   EXPECT_THAT(root, op::Add(param0, op::Negate(constant)));
    205 }
    206 
    207 // Test that (A/B)/C is simplified to A/(B*C).
    208 TEST_F(AlgebraicSimplifierTest, LhsDivOfDiv) {
    209   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
    210   HloComputation::Builder builder(TestName());
    211   HloInstruction* param0 = builder.AddInstruction(
    212       HloInstruction::CreateParameter(0, r0f32, "param0"));
    213   HloInstruction* param1 = builder.AddInstruction(
    214       HloInstruction::CreateParameter(1, r0f32, "param1"));
    215   HloInstruction* param2 = builder.AddInstruction(
    216       HloInstruction::CreateParameter(2, r0f32, "param2"));
    217   HloInstruction* div = builder.AddInstruction(
    218       HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, param1));
    219   builder.AddInstruction(
    220       HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, div, param2));
    221 
    222   auto computation = module().AddEntryComputation(builder.Build());
    223 
    224   EXPECT_THAT(computation->root_instruction(),
    225               op::Divide(op::Divide(param0, param1), param2));
    226 
    227   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
    228                                  non_bitcasting_callback());
    229   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
    230 
    231   EXPECT_THAT(computation->root_instruction(),
    232               op::Divide(param0, op::Multiply(param1, param2)));
    233 }
    234 
    235 // Test that A/(B/C) is simplified to (A*C)/B.
    236 TEST_F(AlgebraicSimplifierTest, RhsDivOfDiv) {
    237   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
    238   HloComputation::Builder builder(TestName());
    239   HloInstruction* param0 = builder.AddInstruction(
    240       HloInstruction::CreateParameter(0, r0f32, "param0"));
    241   HloInstruction* param1 = builder.AddInstruction(
    242       HloInstruction::CreateParameter(1, r0f32, "param1"));
    243   HloInstruction* param2 = builder.AddInstruction(
    244       HloInstruction::CreateParameter(2, r0f32, "param2"));
    245   HloInstruction* div = builder.AddInstruction(
    246       HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param1, param2));
    247   builder.AddInstruction(
    248       HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, div));
    249 
    250   auto computation = module().AddEntryComputation(builder.Build());
    251 
    252   EXPECT_THAT(computation->root_instruction(),
    253               op::Divide(param0, op::Divide(param1, param2)));
    254 
    255   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
    256                                  non_bitcasting_callback());
    257   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
    258 
    259   EXPECT_THAT(computation->root_instruction(),
    260               op::Divide(op::Multiply(param0, param2), param1));
    261 }
    262 
    263 // Test that (A/B)/(C/D) is simplified to (A*D)/(B*C).
    264 TEST_F(AlgebraicSimplifierTest, DivOfDivAndDiv) {
    265   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
    266   Shape r2f32 = ShapeUtil::MakeShape(F32, {42, 123});
    267   HloComputation::Builder builder(TestName());
    268   HloInstruction* param0 = builder.AddInstruction(
    269       HloInstruction::CreateParameter(0, r0f32, "param0"));
    270   HloInstruction* param1 = builder.AddInstruction(
    271       HloInstruction::CreateParameter(1, r2f32, "param1"));
    272   HloInstruction* param2 = builder.AddInstruction(
    273       HloInstruction::CreateParameter(2, r2f32, "param2"));
    274   HloInstruction* param3 = builder.AddInstruction(
    275       HloInstruction::CreateParameter(3, r0f32, "param3"));
    276   HloInstruction* div0 = builder.AddInstruction(
    277       HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param0, param1));
    278   HloInstruction* div1 = builder.AddInstruction(
    279       HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param2, param3));
    280   builder.AddInstruction(
    281       HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, div0, div1));
    282 
    283   auto computation = module().AddEntryComputation(builder.Build());
    284 
    285   EXPECT_THAT(
    286       computation->root_instruction(),
    287       op::Divide(op::Divide(param0, param1), op::Divide(param2, param3)));
    288 
    289   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
    290                                  non_bitcasting_callback());
    291   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
    292 
    293   EXPECT_THAT(
    294       computation->root_instruction(),
    295       op::Divide(op::Multiply(param0, param3), op::Multiply(param1, param2)));
    296   EXPECT_TRUE(
    297       ShapeUtil::Compatible(computation->root_instruction()->shape(), r2f32));
    298 }
    299 
    300 // Test that A/exp(B) is simplified to A*exp(-B).
    301 TEST_F(AlgebraicSimplifierTest, DivOfExp) {
    302   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
    303   HloComputation::Builder builder(TestName());
    304   HloInstruction* param0 = builder.AddInstruction(
    305       HloInstruction::CreateParameter(0, r0f32, "param0"));
    306   HloInstruction* param1 = builder.AddInstruction(
    307       HloInstruction::CreateParameter(1, r0f32, "param1"));
    308   HloInstruction* exp = builder.AddInstruction(
    309       HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param1));
    310   builder.AddInstruction(
    311       HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, exp));
    312 
    313   auto computation = module().AddEntryComputation(builder.Build());
    314 
    315   EXPECT_THAT(computation->root_instruction(),
    316               op::Divide(param0, op::Exp(param1)));
    317 
    318   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
    319                                  non_bitcasting_callback());
    320   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
    321 
    322   EXPECT_THAT(computation->root_instruction(),
    323               op::Multiply(param0, op::Exp(op::Negate(param1))));
    324 }
    325 
    326 // Test that A/pow(B,C) is simplified to A*pow(B,-C).
    327 TEST_F(AlgebraicSimplifierTest, DivOfPower) {
    328   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
    329   HloComputation::Builder builder(TestName());
    330   HloInstruction* param0 = builder.AddInstruction(
    331       HloInstruction::CreateParameter(0, r0f32, "param0"));
    332   HloInstruction* param1 = builder.AddInstruction(
    333       HloInstruction::CreateParameter(1, r0f32, "param1"));
    334   HloInstruction* param2 = builder.AddInstruction(
    335       HloInstruction::CreateParameter(2, r0f32, "param2"));
    336   HloInstruction* power = builder.AddInstruction(
    337       HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param1, param2));
    338   builder.AddInstruction(
    339       HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, power));
    340 
    341   auto computation = module().AddEntryComputation(builder.Build());
    342 
    343   EXPECT_THAT(computation->root_instruction(),
    344               op::Divide(param0, op::Power(param1, param2)));
    345 
    346   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
    347                                  non_bitcasting_callback());
    348   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
    349 
    350   EXPECT_THAT(computation->root_instruction(),
    351               op::Multiply(param0, op::Power(param1, op::Negate(param2))));
    352 }
    353 
    354 // Test that broadcasting is done on the right step when simplifying A/pow(B,C)
    355 // to A*pow(B,-C).
    356 TEST_F(AlgebraicSimplifierTest, DivOfBroadcastingPower) {
    357   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
    358   Shape r1f32 = ShapeUtil::MakeShape(F32, {7});
    359   HloComputation::Builder builder(TestName());
    360   HloInstruction* param0 = builder.AddInstruction(
    361       HloInstruction::CreateParameter(0, r1f32, "param0"));
    362   HloInstruction* param1 = builder.AddInstruction(
    363       HloInstruction::CreateParameter(1, r1f32, "param1"));
    364   HloInstruction* param2 = builder.AddInstruction(
    365       HloInstruction::CreateParameter(2, r0f32, "param2"));
    366   HloInstruction* power = builder.AddInstruction(
    367       HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, param1, param2));
    368   builder.AddInstruction(
    369       HloInstruction::CreateBinary(r1f32, HloOpcode::kDivide, param0, power));
    370 
    371   auto computation = module().AddEntryComputation(builder.Build());
    372 
    373   EXPECT_THAT(computation->root_instruction(),
    374               op::Divide(param0, op::Power(param1, param2)));
    375 
    376   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
    377                                  non_bitcasting_callback());
    378   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
    379 
    380   ASSERT_THAT(computation->root_instruction(),
    381               op::Multiply(param0, op::Power(param1, op::Negate(param2))));
    382 
    383   const HloInstruction* negate =
    384       computation->root_instruction()->operand(1)->operand(1);
    385   const Shape& negate_shape = negate->shape();
    386   EXPECT_EQ(0, negate_shape.dimensions_size());
    387 }
    388 
    389 // A / Const => A * (1 / Const)
    390 TEST_F(AlgebraicSimplifierTest, DivideByConstant) {
    391   Shape r1f32 = ShapeUtil::MakeShape(F32, {3});
    392   HloComputation::Builder builder(TestName());
    393   HloInstruction* param0 = builder.AddInstruction(
    394       HloInstruction::CreateParameter(0, r1f32, "param0"));
    395   HloInstruction* constant =
    396       builder.AddInstruction(HloInstruction::CreateConstant(
    397           Literal::CreateR1<float>({0.f, 1.f, 2.f})));
    398   builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kDivide,
    399                                                       param0, constant));
    400 
    401   auto computation = module().AddEntryComputation(builder.Build());
    402 
    403   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
    404                                  non_bitcasting_callback());
    405   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
    406 
    407   EXPECT_THAT(computation->root_instruction(),
    408               op::Multiply(param0, op::Divide(op::Constant(), constant)));
    409 }
    410 
    411 // pow(pow(A, X), Y) => pow(A, X*Y)
    412 TEST_F(AlgebraicSimplifierTest, PowerOfPower) {
    413   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
    414   Shape r1f32 = ShapeUtil::MakeShape(F32, {7});
    415   HloComputation::Builder builder(TestName());
    416   HloInstruction* base = builder.AddInstruction(
    417       HloInstruction::CreateParameter(0, r1f32, "param0"));
    418   HloInstruction* exp1 = builder.AddInstruction(
    419       HloInstruction::CreateParameter(1, r0f32, "param1"));
    420   HloInstruction* exp2 = builder.AddInstruction(
    421       HloInstruction::CreateParameter(2, r0f32, "param2"));
    422   HloInstruction* inner_power = builder.AddInstruction(
    423       HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, base, exp1));
    424   builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kPower,
    425                                                       inner_power, exp2));
    426 
    427   auto computation = module().AddEntryComputation(builder.Build());
    428   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
    429                                  non_bitcasting_callback());
    430   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
    431   EXPECT_THAT(computation->root_instruction(),
    432               op::Power(base, op::Multiply(exp1, exp2)));
    433 }
    434 
    435 // Don't simplify pow(pow(A, X), Y) => pow(A, X*Y) if X and Y are complex
    436 // numbers.
    437 TEST_F(AlgebraicSimplifierTest, PowerOfPowerComplex) {
    438   Shape r0c64 = ShapeUtil::MakeShape(C64, {});
    439   Shape r1c64 = ShapeUtil::MakeShape(C64, {7});
    440   HloComputation::Builder builder(TestName());
    441   HloInstruction* base = builder.AddInstruction(
    442       HloInstruction::CreateParameter(0, r1c64, "param0"));
    443   HloInstruction* exp1 = builder.AddInstruction(
    444       HloInstruction::CreateParameter(1, r0c64, "param1"));
    445   HloInstruction* exp2 = builder.AddInstruction(
    446       HloInstruction::CreateParameter(2, r0c64, "param2"));
    447   HloInstruction* inner_power = builder.AddInstruction(
    448       HloInstruction::CreateBinary(r1c64, HloOpcode::kPower, base, exp1));
    449   builder.AddInstruction(HloInstruction::CreateBinary(r1c64, HloOpcode::kPower,
    450                                                       inner_power, exp2));
    451 
    452   module().AddEntryComputation(builder.Build());
    453   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
    454                                  non_bitcasting_callback());
    455   ASSERT_FALSE(simplifier.Run(&module()).ValueOrDie());
    456 }
    457 
    458 // Test that A/1 is simplified to A for a scalar.
    459 TEST_F(AlgebraicSimplifierTest, DivOneScalar) {
    460   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
    461   HloComputation::Builder builder(TestName());
    462   HloInstruction* param0 = builder.AddInstruction(
    463       HloInstruction::CreateParameter(0, r0f32, "param0"));
    464   HloInstruction* one = builder.AddInstruction(
    465       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
    466   HloInstruction* div = builder.AddInstruction(
    467       HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, one));
    468 
    469   auto computation = module().AddEntryComputation(builder.Build());
    470   HloInstruction* root = computation->root_instruction();
    471   EXPECT_EQ(root, div);
    472   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
    473                                  non_bitcasting_callback());
    474   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
    475   root = computation->root_instruction();
    476   EXPECT_EQ(root, param0);
    477 }
    478 
    479 // Test that A/1 is simplified to A for an array.
    480 TEST_F(AlgebraicSimplifierTest, DivOneArray) {
    481   Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
    482   HloComputation::Builder builder(TestName());
    483   HloInstruction* param0 = builder.AddInstruction(
    484       HloInstruction::CreateParameter(0, r2f32, "param0"));
    485   HloInstruction* one = builder.AddInstruction(HloInstruction::CreateConstant(
    486       Literal::CreateR2<float>({{1.0, 1.0}, {1.0, 1.0}})));
    487   HloInstruction* div = builder.AddInstruction(
    488       HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param0, one));
    489 
    490   auto computation = module().AddEntryComputation(builder.Build());
    491   HloInstruction* root = computation->root_instruction();
    492   EXPECT_EQ(root, div);
    493   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
    494                                  non_bitcasting_callback());
    495   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
    496   root = computation->root_instruction();
    497   EXPECT_EQ(root, param0);
    498 }
    499 
    500 // Test that complex(real(c), imag(c)) is simplified to c.
    501 TEST_F(AlgebraicSimplifierTest, ComplexOfRealImagC) {
    502   Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
    503   Shape r2c64 = ShapeUtil::MakeShape(C64, {2, 2});
    504   HloComputation::Builder builder(TestName());
    505   HloInstruction* param0 = builder.AddInstruction(
    506       HloInstruction::CreateParameter(0, r2c64, "param0"));
    507   HloInstruction* real = builder.AddInstruction(
    508       HloInstruction::CreateUnary(r2f32, HloOpcode::kReal, param0));
    509   HloInstruction* imag = builder.AddInstruction(
    510       HloInstruction::CreateUnary(r2f32, HloOpcode::kImag, param0));
    511   HloInstruction* cplx = builder.AddInstruction(
    512       HloInstruction::CreateBinary(r2c64, HloOpcode::kComplex, real, imag));
    513 
    514   auto computation = module().AddEntryComputation(builder.Build());
    515   HloInstruction* root = computation->root_instruction();
    516   EXPECT_EQ(root, cplx);
    517   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
    518                                  non_bitcasting_callback());
    519   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
    520   root = computation->root_instruction();
    521   EXPECT_EQ(root, param0);
    522 }
    523 
    524 // Test that real(complex(r,i)) is simplified to r.
    525 TEST_F(AlgebraicSimplifierTest, RealOfComplex) {
    526   Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
    527   HloComputation::Builder builder(TestName());
    528   HloInstruction* param0 = builder.AddInstruction(
    529       HloInstruction::CreateParameter(0, r2f32, "param0"));
    530   HloInstruction* param1 = builder.AddInstruction(
    531       HloInstruction::CreateParameter(1, r2f32, "param1"));
    532   HloInstruction* cplx = builder.AddInstruction(
    533       HloInstruction::CreateBinary(ShapeUtil::ChangeElementType(r2f32, C64),
    534                                    HloOpcode::kComplex, param0, param1));
    535   HloInstruction* real = builder.AddInstruction(
    536       HloInstruction::CreateUnary(r2f32, HloOpcode::kReal, cplx));
    537 
    538   auto computation = module().AddEntryComputation(builder.Build());
    539   HloInstruction* root = computation->root_instruction();
    540   EXPECT_EQ(root, real);
    541   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
    542                                  non_bitcasting_callback());
    543   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
    544   root = computation->root_instruction();
    545   EXPECT_EQ(root, param0);
    546 }
    547 
    548 // Test that imag(complex(r,i)) is simplified to i.
    549 TEST_F(AlgebraicSimplifierTest, ImagOfComplex) {
    550   Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
    551   HloComputation::Builder builder(TestName());
    552   HloInstruction* param0 = builder.AddInstruction(
    553       HloInstruction::CreateParameter(0, r2f32, "param0"));
    554   HloInstruction* param1 = builder.AddInstruction(
    555       HloInstruction::CreateParameter(1, r2f32, "param1"));
    556   HloInstruction* cplx = builder.AddInstruction(
    557       HloInstruction::CreateBinary(ShapeUtil::ChangeElementType(r2f32, C64),
    558                                    HloOpcode::kComplex, param0, param1));
    559   HloInstruction* imag = builder.AddInstruction(
    560       HloInstruction::CreateUnary(r2f32, HloOpcode::kImag, cplx));
    561 
    562   auto computation = module().AddEntryComputation(builder.Build());
    563   HloInstruction* root = computation->root_instruction();
    564   EXPECT_EQ(root, imag);
    565   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
    566                                  non_bitcasting_callback());
    567   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
    568   root = computation->root_instruction();
    569   EXPECT_EQ(root, param1);
    570 }
    571 
    572 // Test that get_element(make_tuple({A,B}),1) is simplified to B
    573 TEST_F(AlgebraicSimplifierTest, SelectMakeTuple) {
    574   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
    575   HloComputation::Builder builder(TestName());
    576   HloInstruction* param0 = builder.AddInstruction(
    577       HloInstruction::CreateParameter(0, r0f32, "param0"));
    578   HloInstruction* param1 = builder.AddInstruction(
    579       HloInstruction::CreateParameter(1, r0f32, "param1"));
    580   HloInstruction* param2 = builder.AddInstruction(
    581       HloInstruction::CreateParameter(2, r0f32, "param2"));
    582   HloInstruction* tuple =
    583       builder.AddInstruction(HloInstruction::CreateTuple({param0, param1}));
    584   HloInstruction* get = builder.AddInstruction(
    585       HloInstruction::CreateGetTupleElement(r0f32, tuple, 1));
    586   HloInstruction* add = builder.AddInstruction(
    587       HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, get, param2));
    588 
    589   auto computation = module().AddEntryComputation(builder.Build());
    590   HloInstruction* root = computation->root_instruction();
    591   EXPECT_EQ(root, add);
    592   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
    593                                  non_bitcasting_callback());
    594   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
    595   root = computation->root_instruction();
    596   EXPECT_THAT(root, op::Add(param1, param2));
    597 }
    598 
    599 // Test that exp(A)/exp(B) is simplified to exp(A-B)
    600 TEST_F(AlgebraicSimplifierTest, ExpDiv) {
    601   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
    602   HloComputation::Builder builder(TestName());
    603   HloInstruction* param0 = builder.AddInstruction(
    604       HloInstruction::CreateParameter(0, r0f32, "param0"));
    605   HloInstruction* param1 = builder.AddInstruction(
    606       HloInstruction::CreateParameter(1, r0f32, "param1"));
    607   HloInstruction* exp0 = builder.AddInstruction(
    608       HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param0));
    609   HloInstruction* exp1 = builder.AddInstruction(
    610       HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param1));
    611   builder.AddInstruction(
    612       HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, exp0, exp1));
    613 
    614   auto computation = module().AddEntryComputation(builder.Build());
    615 
    616   EXPECT_THAT(computation->root_instruction(),
    617               op::Divide(op::Exp(param0), op::Exp(param1)));
    618 
    619   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
    620                                  non_bitcasting_callback());
    621   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
    622 
    623   EXPECT_THAT(computation->root_instruction(),
    624               op::Exp(op::Subtract(param0, param1)));
    625 }
    626 
    627 // Test that exp(A)*exp(B) is simplified to exp(A+B)
    628 TEST_F(AlgebraicSimplifierTest, ExpMul) {
    629   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
    630   HloComputation::Builder builder(TestName());
    631   HloInstruction* param0 = builder.AddInstruction(
    632       HloInstruction::CreateParameter(0, r0f32, "param0"));
    633   HloInstruction* param1 = builder.AddInstruction(
    634       HloInstruction::CreateParameter(1, r0f32, "param1"));
    635   HloInstruction* exp0 = builder.AddInstruction(
    636       HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param0));
    637   HloInstruction* exp1 = builder.AddInstruction(
    638       HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param1));
    639   builder.AddInstruction(
    640       HloInstruction::CreateBinary(r0f32, HloOpcode::kMultiply, exp0, exp1));
    641 
    642   auto computation = module().AddEntryComputation(builder.Build());
    643 
    644   EXPECT_THAT(computation->root_instruction(),
    645               op::Multiply(op::Exp(param0), op::Exp(param1)));
    646 
    647   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
    648                                  non_bitcasting_callback());
    649   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
    650 
    651   EXPECT_THAT(computation->root_instruction(),
    652               op::Exp(op::Add(param0, param1)));
    653 }
    654 
    655 // Test that pow(exp(A), B) is simplified to exp(A*B)
    656 TEST_F(AlgebraicSimplifierTest, PowExp) {
    657   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
    658   HloComputation::Builder builder(TestName());
    659   HloInstruction* param0 = builder.AddInstruction(
    660       HloInstruction::CreateParameter(0, r0f32, "param0"));
    661   HloInstruction* param1 = builder.AddInstruction(
    662       HloInstruction::CreateParameter(1, r0f32, "param1"));
    663   HloInstruction* exp0 = builder.AddInstruction(
    664       HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param0));
    665   builder.AddInstruction(
    666       HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, exp0, param1));
    667 
    668   auto computation = module().AddEntryComputation(builder.Build());
    669 
    670   EXPECT_THAT(computation->root_instruction(),
    671               op::Power(op::Exp(param0), param1));
    672 
    673   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
    674                                  non_bitcasting_callback());
    675   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
    676 
    677   EXPECT_THAT(computation->root_instruction(),
    678               op::Exp(op::Multiply(param0, param1)));
    679 }
    680 
    681 // Test that ln(pow(A, B)) is simplified to ln(A)*B
    682 TEST_F(AlgebraicSimplifierTest, LnPow) {
    683   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
    684   HloComputation::Builder builder(TestName());
    685   HloInstruction* param0 = builder.AddInstruction(
    686       HloInstruction::CreateParameter(0, r0f32, "param0"));
    687   HloInstruction* param1 = builder.AddInstruction(
    688       HloInstruction::CreateParameter(1, r0f32, "param1"));
    689   HloInstruction* pow = builder.AddInstruction(
    690       HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, param1));
    691   builder.AddInstruction(
    692       HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, pow));
    693 
    694   auto computation = module().AddEntryComputation(builder.Build());
    695 
    696   EXPECT_THAT(computation->root_instruction(),
    697               op::Log(op::Power(param0, param1)));
    698 
    699   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
    700                                  non_bitcasting_callback());
    701   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
    702 
    703   EXPECT_THAT(computation->root_instruction(),
    704               op::Multiply(op::Log(param0), param1));
    705 }
    706 
    707 // Test that ln(exp(A)) is simplified to A
    708 TEST_F(AlgebraicSimplifierTest, LnExp) {
    709   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
    710   HloComputation::Builder builder(TestName());
    711   HloInstruction* param0 = builder.AddInstruction(
    712       HloInstruction::CreateParameter(0, r0f32, "param0"));
    713   HloInstruction* exp0 = builder.AddInstruction(
    714       HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param0));
    715   builder.AddInstruction(
    716       HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, exp0));
    717 
    718   auto computation = module().AddEntryComputation(builder.Build());
    719 
    720   EXPECT_THAT(computation->root_instruction(), op::Log(op::Exp(param0)));
    721 
    722   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
    723                                  non_bitcasting_callback());
    724   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
    725 
    726   EXPECT_EQ(computation->root_instruction(), param0);
    727 }
    728 
    729 // Test that ln(exp(A)/exp(B)) is simplified to A-B
    730 TEST_F(AlgebraicSimplifierTest, LnExpDiv) {
    731   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
    732   HloComputation::Builder builder(TestName());
    733   HloInstruction* param0 = builder.AddInstruction(
    734       HloInstruction::CreateParameter(0, r0f32, "param0"));
    735   HloInstruction* param1 = builder.AddInstruction(
    736       HloInstruction::CreateParameter(1, r0f32, "param1"));
    737   HloInstruction* exp0 = builder.AddInstruction(
    738       HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param0));
    739   HloInstruction* exp1 = builder.AddInstruction(
    740       HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param1));
    741   HloInstruction* div = builder.AddInstruction(
    742       HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, exp0, exp1));
    743   builder.AddInstruction(
    744       HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, div));
    745 
    746   auto computation = module().AddEntryComputation(builder.Build());
    747 
    748   EXPECT_THAT(computation->root_instruction(),
    749               op::Log(op::Divide(op::Exp(param0), op::Exp(param1))));
    750 
    751   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
    752                                  non_bitcasting_callback());
    753   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
    754 
    755   EXPECT_THAT(computation->root_instruction(), op::Subtract(param0, param1));
    756 }
    757 
    758 // Test that pow(A, 0) where A is a scalar is simplified to the scalar
    759 // constant 1.
    760 TEST_F(AlgebraicSimplifierTest, Pow0Scalar) {
    761   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
    762   HloComputation::Builder builder(TestName());
    763   HloInstruction* param0 = builder.AddInstruction(
    764       HloInstruction::CreateParameter(0, r0f32, "param0"));
    765   HloInstruction* zero = builder.AddInstruction(
    766       HloInstruction::CreateConstant(Literal::CreateR0<float>(0)));
    767   builder.AddInstruction(
    768       HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, zero));
    769 
    770   auto computation = module().AddEntryComputation(builder.Build());
    771 
    772   EXPECT_THAT(computation->root_instruction(), op::Power(param0, zero));
    773 
    774   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
    775                                  non_bitcasting_callback());
    776   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
    777 
    778   HloInstruction* root = computation->root_instruction();
    779   EXPECT_THAT(root, op::Constant());
    780   EXPECT_EQ(root->literal().GetFirstElement<float>(), 1);
    781 }
    782 
    783 // Test that pow(A, 0) where A is not a scalar is simplified to broadcast(1).
    784 TEST_F(AlgebraicSimplifierTest, Pow0Vector) {
    785   Shape r1f32 = ShapeUtil::MakeShape(F32, {42});
    786   HloComputation::Builder builder(TestName());
    787   HloInstruction* param0 = builder.AddInstruction(
    788       HloInstruction::CreateParameter(0, r1f32, "param0"));
    789   HloInstruction* zero = builder.AddInstruction(
    790       HloInstruction::CreateConstant(Literal::CreateR0<float>(0)));
    791   builder.AddInstruction(
    792       HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, param0, zero));
    793 
    794   auto computation = module().AddEntryComputation(builder.Build());
    795 
    796   EXPECT_THAT(computation->root_instruction(), op::Power(param0, zero));
    797 
    798   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
    799                                  non_bitcasting_callback());
    800   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
    801 
    802   HloInstruction* root = computation->root_instruction();
    803   EXPECT_THAT(root, op::Broadcast());
    804   EXPECT_TRUE(ShapeUtil::Equal(root->shape(), r1f32))
    805       << ShapeUtil::HumanString(root->shape());
    806   EXPECT_EQ(root->dimensions().size(), 0);
    807   EXPECT_TRUE(ShapeUtil::IsScalar(root->operand(0)->shape()));
    808   EXPECT_EQ(root->operand(0)->literal().GetFirstElement<float>(), 1);
    809 }
    810 
    811 // Test that pow(A, 1) is simplified to A.
    812 TEST_F(AlgebraicSimplifierTest, Pow1) {
    813   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
    814   HloComputation::Builder builder(TestName());
    815   HloInstruction* param0 = builder.AddInstruction(
    816       HloInstruction::CreateParameter(0, r0f32, "param0"));
    817   HloInstruction* one = builder.AddInstruction(
    818       HloInstruction::CreateConstant(Literal::CreateR0<float>(1)));
    819   builder.AddInstruction(
    820       HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, one));
    821 
    822   auto computation = module().AddEntryComputation(builder.Build());
    823 
    824   EXPECT_THAT(computation->root_instruction(), op::Power(param0, one));
    825 
    826   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
    827                                  non_bitcasting_callback());
    828   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
    829 
    830   EXPECT_EQ(computation->root_instruction(), param0);
    831 }
    832 
    833 // Test that pow(A, 2) is simplified to A*A.
    834 TEST_F(AlgebraicSimplifierTest, Pow2) {
    835   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
    836   HloComputation::Builder builder(TestName());
    837   HloInstruction* param0 = builder.AddInstruction(
    838       HloInstruction::CreateParameter(0, r0f32, "param0"));
    839   HloInstruction* two = builder.AddInstruction(
    840       HloInstruction::CreateConstant(Literal::CreateR0<float>(2)));
    841   builder.AddInstruction(
    842       HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, two));
    843 
    844   auto computation = module().AddEntryComputation(builder.Build());
    845 
    846   EXPECT_THAT(computation->root_instruction(), op::Power(param0, two));
    847 
    848   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
    849                                  non_bitcasting_callback());
    850   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
    851 
    852   EXPECT_THAT(computation->root_instruction(), op::Multiply(param0, param0));
    853 }
    854 
    855 // Test that pow(A, -1) is simplified to 1/A.
    856 TEST_F(AlgebraicSimplifierTest, PowNegative1) {
    857   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
    858   HloComputation::Builder builder(TestName());
    859   HloInstruction* param0 = builder.AddInstruction(
    860       HloInstruction::CreateParameter(0, r0f32, "param0"));
    861   HloInstruction* negative_one = builder.AddInstruction(
    862       HloInstruction::CreateConstant(Literal::CreateR0<float>(-1)));
    863   builder.AddInstruction(HloInstruction::CreateBinary(r0f32, HloOpcode::kPower,
    864                                                       param0, negative_one));
    865 
    866   auto computation = module().AddEntryComputation(builder.Build());
    867 
    868   EXPECT_THAT(computation->root_instruction(), op::Power(param0, negative_one));
    869 
    870   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
    871                                  non_bitcasting_callback());
    872   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
    873 
    874   HloInstruction* root = computation->root_instruction();
    875   EXPECT_THAT(root, op::Divide(op::Broadcast(), param0));
    876   EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kBroadcast);
    877   EXPECT_EQ(root->operand(0)->operand(0)->literal().GetFirstElement<float>(),
    878             1);
    879 }
    880 
    881 TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) {
    882   auto builder = HloComputation::Builder(TestName());
    883   HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateParameter(
    884       0, ShapeUtil::MakeShape(F32, {3, 3, 0}), "lhs"));
    885 
    886   HloInstruction* rhs = builder.AddInstruction(HloInstruction::CreateParameter(
    887       1, ShapeUtil::MakeShape(F32, {3, 0, 3}), "rhs"));
    888 
    889   ConvolutionDimensionNumbers dnums;
    890   dnums.set_input_batch_dimension(0);
    891   dnums.add_input_spatial_dimensions(1);
    892   dnums.set_input_feature_dimension(2);
    893 
    894   dnums.set_output_batch_dimension(0);
    895   dnums.add_output_spatial_dimensions(1);
    896   dnums.set_output_feature_dimension(2);
    897 
    898   dnums.add_kernel_spatial_dimensions(0);
    899   dnums.set_kernel_input_feature_dimension(1);
    900   dnums.set_kernel_output_feature_dimension(2);
    901   Window window;
    902   WindowDimension* dim = window.add_dimensions();
    903   dim->set_size(3);
    904   dim->set_padding_low(0);
    905   dim->set_padding_high(0);
    906   dim->set_stride(1);
    907   dim->set_window_dilation(1);
    908   dim->set_base_dilation(1);
    909   dim->set_window_reversal(false);
    910   // Create add computation.
    911   builder.AddInstruction(HloInstruction::CreateConvolve(
    912       ShapeUtil::MakeShape(F32, {3, 3, 3}), lhs, rhs, window, dnums));
    913   module().AddEntryComputation(builder.Build());
    914   HloPassFix<AlgebraicSimplifier> simplifier(/*is_layout_sensitive=*/false,
    915                                              non_bitcasting_callback());
    916   EXPECT_THAT(module().entry_computation()->root_instruction(),
    917               op::Convolution(lhs, rhs));
    918   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
    919   EXPECT_THAT(module().entry_computation()->root_instruction(),
    920               op::Broadcast(op::Constant()));
    921 }
    922 
    923 TEST_F(AlgebraicSimplifierTest, ZeroSizedReduceWindow) {
    924   auto builder = HloComputation::Builder(TestName());
    925   HloInstruction* param =
    926       builder.AddInstruction(HloInstruction::CreateParameter(
    927           0, ShapeUtil::MakeShape(F32, {3, 0}), "op"));
    928   Window window;
    929   for (int64 i = 0; i < 2; ++i) {
    930     WindowDimension* dim = window.add_dimensions();
    931     dim->set_size(1);
    932     dim->set_padding_low(1);
    933     dim->set_padding_high(1);
    934     dim->set_window_dilation(1);
    935     dim->set_base_dilation(1);
    936   }
    937   // Create add computation.
    938   HloComputation* add_computation = nullptr;
    939   {
    940     HloComputation::Builder builder(TestName() + ".add");
    941     const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
    942     HloInstruction* p0 = builder.AddInstruction(
    943         HloInstruction::CreateParameter(0, scalar_shape, "p0"));
    944     HloInstruction* p1 = builder.AddInstruction(
    945         HloInstruction::CreateParameter(1, scalar_shape, "p1"));
    946     builder.AddInstruction(
    947         HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
    948     add_computation = module().AddEmbeddedComputation(builder.Build());
    949   }
    950   builder.AddInstruction(HloInstruction::CreateReduceWindow(
    951       ShapeUtil::MakeShape(F32, {5, 2}), param,
    952       builder.AddInstruction(
    953           HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f))),
    954       window, add_computation));
    955   module().AddEntryComputation(builder.Build());
    956   HloPassFix<AlgebraicSimplifier> simplifier(/*is_layout_sensitive=*/false,
    957                                              non_bitcasting_callback());
    958   EXPECT_THAT(module().entry_computation()->root_instruction(),
    959               op::ReduceWindow(param, op::Constant()));
    960   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
    961   EXPECT_THAT(module().entry_computation()->root_instruction(),
    962               op::Broadcast(op::Constant()));
    963 }
    964 
    965 TEST_F(AlgebraicSimplifierTest, ZeroSizedPad) {
    966   auto builder = HloComputation::Builder(TestName());
    967   HloInstruction* param =
    968       builder.AddInstruction(HloInstruction::CreateParameter(
    969           0, ShapeUtil::MakeShape(F32, {3, 0}), "op"));
    970   PaddingConfig padding;
    971   for (int i = 0; i < 2; ++i) {
    972     PaddingConfig::PaddingConfigDimension* dimension = padding.add_dimensions();
    973     dimension->set_edge_padding_low(1);
    974     dimension->set_edge_padding_high(1);
    975     dimension->set_interior_padding(0);
    976   }
    977   builder.AddInstruction(HloInstruction::CreatePad(
    978       ShapeUtil::MakeShape(F32, {5, 2}), param,
    979       builder.AddInstruction(
    980           HloInstruction::CreateConstant(Literal::CreateR0(0.0f))),
    981       padding));
    982   module().AddEntryComputation(builder.Build());
    983   EXPECT_THAT(module().entry_computation()->root_instruction(),
    984               op::Pad(param, op::Constant()));
    985   HloPassFix<AlgebraicSimplifier> simplifier(/*is_layout_sensitive=*/false,
    986                                              non_bitcasting_callback());
    987   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
    988   EXPECT_THAT(module().entry_computation()->root_instruction(),
    989               op::Broadcast(op::Constant()));
    990 }
    991 
    992 TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) {
    993   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
    994 
    995   auto builder = HloComputation::Builder(TestName());
    996   auto op = builder.AddInstruction(HloInstruction::CreateParameter(
    997       0, ShapeUtil::MakeShape(F32, {3, 2}), "op"));
    998   auto reshape1 = builder.AddInstruction(
    999       HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {6}), op));
   1000   auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
   1001       ShapeUtil::MakeShape(F32, {1, 6}), reshape1, {1}));
   1002   builder.AddInstruction(HloInstruction::CreateReshape(
   1003       ShapeUtil::MakeShape(F32, {3, 2}), broadcast));
   1004 
   1005   auto computation = builder.Build();
   1006   module().AddEntryComputation(std::move(computation));
   1007 
   1008   EXPECT_THAT(module().entry_computation()->root_instruction(),
   1009               op::Reshape(op::Broadcast(op::Reshape(op))));
   1010 
   1011   HloPassFix<AlgebraicSimplifier> simplifier(/*is_layout_sensitive=*/false,
   1012                                              non_bitcasting_callback());
   1013   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
   1014 
   1015   EXPECT_THAT(module().entry_computation()->root_instruction(), op);
   1016 }
   1017 
   1018 // Test that convert(A, $TYPE) is simplified to A if A is of type $TYPE.
   1019 TEST_F(AlgebraicSimplifierTest, ConvertBetweenSameType) {
   1020   HloComputation::Builder builder(TestName());
   1021   HloInstruction* input = builder.AddInstruction(
   1022       HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
   1023   builder.AddInstruction(
   1024       HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input));
   1025 
   1026   auto computation = module().AddEntryComputation(builder.Build());
   1027 
   1028   EXPECT_THAT(computation->root_instruction(), op::Convert(input));
   1029 
   1030   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
   1031                                  non_bitcasting_callback());
   1032   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
   1033 
   1034   EXPECT_THAT(computation->root_instruction(), input);
   1035 }
   1036 
   1037 // Test that copies are removed.
   1038 TEST_F(AlgebraicSimplifierTest, RemoveCopy) {
   1039   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
   1040   HloComputation::Builder builder(TestName());
   1041   HloInstruction* param0 = builder.AddInstruction(
   1042       HloInstruction::CreateParameter(0, r0f32, "param0"));
   1043   builder.AddInstruction(
   1044       HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0));
   1045 
   1046   auto computation = module().AddEntryComputation(builder.Build());
   1047 
   1048   EXPECT_THAT(computation->root_instruction(), op::Copy(param0));
   1049 
   1050   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
   1051                                  non_bitcasting_callback());
   1052   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
   1053 
   1054   EXPECT_THAT(computation->root_instruction(), param0);
   1055 }
   1056 
   1057 // Test that unary concatenates are removed.
   1058 TEST_F(AlgebraicSimplifierTest, RemoveUnaryConcatenate) {
   1059   Shape r1f32 = ShapeUtil::MakeShape(F32, {100});
   1060   HloComputation::Builder builder(TestName());
   1061   HloInstruction* param0 = builder.AddInstruction(
   1062       HloInstruction::CreateParameter(0, r1f32, "param0"));
   1063   builder.AddInstruction(
   1064       HloInstruction::CreateConcatenate(param0->shape(), {param0}, 0));
   1065 
   1066   auto computation = module().AddEntryComputation(builder.Build());
   1067 
   1068   EXPECT_THAT(computation->root_instruction(), op::Concatenate(param0));
   1069 
   1070   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
   1071                                  non_bitcasting_callback());
   1072   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
   1073 
   1074   EXPECT_THAT(computation->root_instruction(), param0);
   1075 }
   1076 
   1077 // Test that empty operands of concatenates are removed.
   1078 TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) {
   1079   const int kParamLength = 100;
   1080   Shape r1f32 = ShapeUtil::MakeShape(F32, {kParamLength});
   1081   HloComputation::Builder builder(TestName());
   1082   HloInstruction* param0 = builder.AddInstruction(
   1083       HloInstruction::CreateParameter(0, r1f32, "param0"));
   1084   HloInstruction* param1 = builder.AddInstruction(
   1085       HloInstruction::CreateParameter(1, r1f32, "param1"));
   1086   HloInstruction* empty_literal = builder.AddInstruction(
   1087       HloInstruction::CreateConstant(Literal::CreateR1<float>({})));
   1088   HloInstruction* empty_slice =
   1089       builder.AddInstruction(HloInstruction::CreateSlice(
   1090           ShapeUtil::MakeShape(F32, {0}), param1, {42}, {42}, {1}));
   1091   Shape result_shape = ShapeUtil::MakeShape(F32, {3 * kParamLength});
   1092   builder.AddInstruction(HloInstruction::CreateConcatenate(
   1093       result_shape, {empty_literal, param0, param0, empty_slice, param1}, 0));
   1094 
   1095   auto computation = module().AddEntryComputation(builder.Build());
   1096 
   1097   EXPECT_THAT(
   1098       computation->root_instruction(),
   1099       op::Concatenate(empty_literal, param0, param0, empty_slice, param1));
   1100 
   1101   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
   1102                                  non_bitcasting_callback());
   1103   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
   1104 
   1105   EXPECT_THAT(computation->root_instruction(),
   1106               op::Concatenate(param0, param0, param1));
   1107 }
   1108 
   1109 // Test a concatenate with only empty operands is removed.
   1110 TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) {
   1111   const int kParamLength = 100;
   1112   Shape r1f32 = ShapeUtil::MakeShape(F32, {kParamLength});
   1113   HloComputation::Builder builder(TestName());
   1114   HloInstruction* param0 = builder.AddInstruction(
   1115       HloInstruction::CreateParameter(0, r1f32, "param0"));
   1116   HloInstruction* empty_literal = builder.AddInstruction(
   1117       HloInstruction::CreateConstant(Literal::CreateR1<float>({})));
   1118   HloInstruction* empty_slice =
   1119       builder.AddInstruction(HloInstruction::CreateSlice(
   1120           ShapeUtil::MakeShape(F32, {0}), param0, {42}, {42}, {1}));
   1121   Shape result_shape = ShapeUtil::MakeShape(F32, {0});
   1122   builder.AddInstruction(HloInstruction::CreateConcatenate(
   1123       result_shape, {empty_literal, empty_slice}, 0));
   1124 
   1125   auto computation = module().AddEntryComputation(builder.Build());
   1126 
   1127   EXPECT_THAT(computation->root_instruction(),
   1128               op::Concatenate(empty_literal, empty_slice));
   1129 
   1130   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
   1131                                  non_bitcasting_callback());
   1132   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
   1133 
   1134   EXPECT_EQ(computation->root_instruction(), empty_literal);
   1135 }
   1136 
   1137 // Test that concat with a scalar broadcast becomes a pad.
   1138 TEST_F(AlgebraicSimplifierTest, ConcatenateOfBroadcastBecomesPad) {
   1139   Shape r1f32 = ShapeUtil::MakeShape(F32, {100});
   1140   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
   1141   HloComputation::Builder builder(TestName());
   1142   HloInstruction* param0 = builder.AddInstruction(
   1143       HloInstruction::CreateParameter(0, r1f32, "param0"));
   1144   HloInstruction* param1 = builder.AddInstruction(
   1145       HloInstruction::CreateParameter(1, r0f32, "param1"));
   1146   HloInstruction* broadcast = builder.AddInstruction(
   1147       HloInstruction::CreateBroadcast(r1f32, param1, {}));
   1148   builder.AddInstruction(HloInstruction::CreateConcatenate(
   1149       ShapeUtil::MakeShape(F32, {200}), {broadcast, param0}, 0));
   1150 
   1151   auto computation = module().AddEntryComputation(builder.Build());
   1152 
   1153   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
   1154                                  non_bitcasting_callback());
   1155   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
   1156   EXPECT_THAT(computation->root_instruction(), op::Pad(param0, param1));
   1157 }
   1158 
   1159 // Test that a simplification which changes layouts is not performed if layout
   1160 // sensitive is true.
   1161 TEST_F(AlgebraicSimplifierTest, CopyWithDifferentLayout) {
   1162   HloComputation::Builder builder(TestName());
   1163   HloInstruction* param0 =
   1164       builder.AddInstruction(HloInstruction::CreateParameter(
   1165           0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"));
   1166   HloInstruction* copy = builder.AddInstruction(
   1167       HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0));
   1168 
   1169   auto computation = module().AddEntryComputation(builder.Build());
   1170 
   1171   // Set to different layouts.
   1172   *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
   1173   *copy->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 0});
   1174 
   1175   EXPECT_THAT(computation->root_instruction(), op::Copy(param0));
   1176 
   1177   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
   1178                                  non_bitcasting_callback());
   1179   EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie());
   1180 
   1181   // Copy has not been removed.
   1182   EXPECT_THAT(computation->root_instruction(), op::Copy(param0));
   1183 }
   1184 
   1185 // Test that a simplification which preserves layouts is performed if layout
   1186 // sensitive is true.
   1187 TEST_F(AlgebraicSimplifierTest, CopyWithSameLayout) {
   1188   HloComputation::Builder builder(TestName());
   1189   HloInstruction* param0 =
   1190       builder.AddInstruction(HloInstruction::CreateParameter(
   1191           0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"));
   1192   HloInstruction* copy = builder.AddInstruction(
   1193       HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0));
   1194 
   1195   auto computation = module().AddEntryComputation(builder.Build());
   1196 
   1197   // Set to same layouts.
   1198   *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
   1199   *copy->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
   1200 
   1201   EXPECT_THAT(computation->root_instruction(), op::Copy(param0));
   1202 
   1203   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
   1204                                  non_bitcasting_callback());
   1205   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
   1206 
   1207   // Copy has been removed.
   1208   EXPECT_THAT(computation->root_instruction(), param0);
   1209 }
   1210 
   1211 // Test that a reshape which could be replaced with a bitcast is not if
   1212 // add_bitcasts is false.
   1213 TEST_F(AlgebraicSimplifierTest, NoBitcastAdded) {
   1214   HloComputation::Builder builder(TestName());
   1215   HloInstruction* param0 =
   1216       builder.AddInstruction(HloInstruction::CreateParameter(
   1217           0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"));
   1218   HloInstruction* reshape =
   1219       builder.AddInstruction(HloInstruction::CreateReshape(
   1220           ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), param0));
   1221 
   1222   *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
   1223   *reshape->mutable_shape()->mutable_layout() =
   1224       LayoutUtil::MakeLayout({0, 1, 2, 3, 4, 5});
   1225 
   1226   auto computation = module().AddEntryComputation(builder.Build());
   1227 
   1228   EXPECT_THAT(computation->root_instruction(), op::Reshape(param0));
   1229 
   1230   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
   1231                                  non_bitcasting_callback());
   1232   EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie());
   1233 
   1234   // Reshape is not replaced with a bitcast.
   1235   EXPECT_THAT(computation->root_instruction(), op::Reshape(param0));
   1236 }
   1237 
   1238 // Test transforming reshapes to bitcasts under various conditions.
   1239 TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) {
   1240   HloComputation::Builder builder(TestName());
   1241   HloInstruction* param0 =
   1242       builder.AddInstruction(HloInstruction::CreateParameter(
   1243           0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"));
   1244   *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
   1245 
   1246   // Reshape which can be transformed into a bitcast.
   1247   HloInstruction* transformable_reshape =
   1248       builder.AddInstruction(HloInstruction::CreateReshape(
   1249           ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), param0));
   1250   *transformable_reshape->mutable_shape()->mutable_layout() =
   1251       LayoutUtil::MakeLayout({0, 1, 2, 3, 4, 5});
   1252 
   1253   // Reshape does not just add degenerate dimensions.
   1254   HloInstruction* dimensions_wrong_reshape =
   1255       builder.AddInstruction(HloInstruction::CreateReshape(
   1256           ShapeUtil::MakeShape(F32, {1, 4, 1, 1, 1, 1}), param0));
   1257   *dimensions_wrong_reshape->mutable_shape()->mutable_layout() =
   1258       LayoutUtil::MakeLayout({0, 1, 2, 3, 4, 5});
   1259 
   1260   // Reshape has wrong layout.
   1261   HloInstruction* layout_wrong_reshape =
   1262       builder.AddInstruction(HloInstruction::CreateReshape(
   1263           ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), param0));
   1264   *layout_wrong_reshape->mutable_shape()->mutable_layout() =
   1265       LayoutUtil::MakeLayout({5, 4, 3, 2, 1, 0});
   1266 
   1267   // Collect all the reshapes into a tuple so they are not dead.
   1268   builder.AddInstruction(HloInstruction::CreateTuple(
   1269       {transformable_reshape, dimensions_wrong_reshape, layout_wrong_reshape}));
   1270 
   1271   auto computation = module().AddEntryComputation(builder.Build());
   1272 
   1273   EXPECT_THAT(computation->root_instruction(),
   1274               op::Tuple(transformable_reshape, dimensions_wrong_reshape,
   1275                         layout_wrong_reshape));
   1276 
   1277   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
   1278                                  bitcasting_callback());
   1279   simplifier.Run(&module()).ValueOrDie();
   1280 
   1281   // Verify that only the first reshape is replaced.
   1282   EXPECT_THAT(
   1283       computation->root_instruction(),
   1284       op::Tuple(op::Bitcast(), dimensions_wrong_reshape, layout_wrong_reshape));
   1285 }
   1286 
   1287 TEST_F(AlgebraicSimplifierTest, ReshapeAfterEffectiveUnary) {
   1288   HloComputation::Builder builder(TestName());
   1289   HloInstruction* param =
   1290       builder.AddInstruction(HloInstruction::CreateParameter(
   1291           0, ShapeUtil::MakeShape(F32, {2, 3, 4, 5}), "param"));
   1292   HloInstruction* movable_reshape =
   1293       builder.AddInstruction(HloInstruction::CreateReshape(
   1294           ShapeUtil::MakeShape(F32, {1, 2, 3, 4, 5}), param));
   1295   HloInstruction* zero = builder.AddInstruction(
   1296       HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
   1297   builder.AddInstruction(
   1298       HloInstruction::CreateBinary(ShapeUtil::MakeShape(F32, {1, 2, 3, 4, 5}),
   1299                                    HloOpcode::kMaximum, movable_reshape, zero));
   1300   auto computation = module().AddEntryComputation(builder.Build());
   1301 
   1302   EXPECT_THAT(computation->root_instruction(),
   1303               op::Maximum(op::Reshape(param), zero));
   1304 
   1305   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
   1306                                  bitcasting_callback());
   1307 
   1308   simplifier.Run(&module()).ValueOrDie();
   1309   EXPECT_THAT(computation->root_instruction(),
   1310               op::Reshape(op::Maximum(param, zero)));
   1311 }
   1312 
   1313 // Regression test for a bug in the reshape sinking transformation, where
   1314 // moving a reshape to a scalar led to a crash.
   1315 TEST_F(AlgebraicSimplifierTest, ReshapeToScalarNotHoistedAfterEffectiveUnary) {
   1316   HloComputation::Builder builder(TestName());
   1317   HloInstruction* param =
   1318       builder.AddInstruction(HloInstruction::CreateParameter(
   1319           0, ShapeUtil::MakeShape(F32, {1, 1}), "param"));
   1320   HloInstruction* reshape = builder.AddInstruction(
   1321       HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {}), param));
   1322   HloInstruction* zero = builder.AddInstruction(
   1323       HloInstruction::CreateConstant(Literal::CreateR1<float>({1., 2., 3.})));
   1324   builder.AddInstruction(HloInstruction::CreateBinary(
   1325       ShapeUtil::MakeShape(F32, {3}), HloOpcode::kMaximum, reshape, zero));
   1326   auto computation = module().AddEntryComputation(builder.Build());
   1327 
   1328   EXPECT_THAT(computation->root_instruction(),
   1329               op::Maximum(op::Reshape(param), zero));
   1330 
   1331   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
   1332                                  bitcasting_callback());
   1333 
   1334   simplifier.Run(&module()).ValueOrDie();
   1335 
   1336   EXPECT_THAT(computation->root_instruction(),
   1337               op::Maximum(op::Reshape(param), zero));
   1338 }
   1339 
   1340 // Regression test for a bug where if we failed to sink a reshape, we'd set the
   1341 // 'changed' bit in AlgebraicSimplifier to false.
   1342 TEST_F(AlgebraicSimplifierTest, FailureToSinkReshapeDoesntAffectChangedBit) {
   1343   HloComputation::Builder builder(TestName());
   1344 
   1345   // This add (param0 + 0) can be simplified.
   1346   Shape shape = ShapeUtil::MakeShape(F32, {2, 2});
   1347   HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary(
   1348       shape, HloOpcode::kAdd,
   1349       builder.AddInstruction(
   1350           HloInstruction::CreateParameter(0, shape, "param0")),
   1351       builder.AddInstruction(HloInstruction::CreateConstant(
   1352           Literal::CreateR2<float>({{0, 0}, {0, 0}})))));
   1353 
   1354   builder.AddInstruction(
   1355       HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {4}), add));
   1356 
   1357   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
   1358                                  bitcasting_callback());
   1359   module().AddEntryComputation(builder.Build());
   1360   EXPECT_TRUE(simplifier.Run(&module()).ValueOrDie());
   1361 }
   1362 
   1363 // Regression test for a bug where if we failed to sink a reshape, we'd set the
   1364 // 'changed' bit in AlgebraicSimplifier to false.
   1365 TEST_F(AlgebraicSimplifierTest, FailureToSinkBroadcastDoesntAffectChangedBit) {
   1366   HloComputation::Builder builder(TestName());
   1367 
   1368   // This add (param0 + 0) can be simplified.
   1369   Shape shape = ShapeUtil::MakeShape(F32, {2, 2});
   1370   HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary(
   1371       shape, HloOpcode::kAdd,
   1372       builder.AddInstruction(
   1373           HloInstruction::CreateParameter(0, shape, "param0")),
   1374       builder.AddInstruction(HloInstruction::CreateConstant(
   1375           Literal::CreateR2<float>({{0, 0}, {0, 0}})))));
   1376 
   1377   builder.AddInstruction(
   1378       HloInstruction::CreateBroadcast(ShapeUtil::MakeShape(F32, {2, 2, 2}), add,
   1379                                       /*broadcast_dimensions=*/{0, 1}));
   1380 
   1381   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
   1382                                  bitcasting_callback());
   1383   module().AddEntryComputation(builder.Build());
   1384   EXPECT_TRUE(simplifier.Run(&module()).ValueOrDie());
   1385 }
   1386 
   1387 TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) {
   1388   HloComputation::Builder builder(TestName());
   1389   HloInstruction* param =
   1390       builder.AddInstruction(HloInstruction::CreateParameter(
   1391           0, ShapeUtil::MakeShape(F32, {50, 14, 14, 64}), "param"));
   1392   *param->mutable_shape()->mutable_layout() =
   1393       LayoutUtil::MakeLayout({1, 2, 0, 3});
   1394 
   1395   HloInstruction* transpose =
   1396       builder.AddInstruction(HloInstruction::CreateTranspose(
   1397           ShapeUtil::MakeShape(F32, {14, 14, 50, 64}), param, {1, 2, 0, 3}));
   1398   *transpose->mutable_shape()->mutable_layout() =
   1399       LayoutUtil::MakeLayout({0, 1, 2, 3});
   1400 
   1401   auto computation = module().AddEntryComputation(builder.Build());
   1402 
   1403   EXPECT_THAT(computation->root_instruction(), op::Transpose(param));
   1404 
   1405   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
   1406                                  bitcasting_callback());
   1407   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
   1408 
   1409   // Verify that the reshape is replaced.
   1410   EXPECT_THAT(computation->root_instruction(), op::Bitcast(param));
   1411 }
   1412 
   1413 TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast2) {
   1414   HloComputation::Builder builder(TestName());
   1415   HloInstruction* param =
   1416       builder.AddInstruction(HloInstruction::CreateParameter(
   1417           0, ShapeUtil::MakeShape(F32, {5, 2, 3, 4}), "param"));
   1418   *param->mutable_shape()->mutable_layout() =
   1419       LayoutUtil::MakeLayout({1, 2, 3, 0});
   1420 
   1421   HloInstruction* transpose =
   1422       builder.AddInstruction(HloInstruction::CreateTranspose(
   1423           ShapeUtil::MakeShape(F32, {5, 3, 4, 2}), param, {0, 2, 3, 1}));
   1424   *transpose->mutable_shape()->mutable_layout() =
   1425       LayoutUtil::MakeLayout({3, 1, 2, 0});
   1426 
   1427   auto computation = module().AddEntryComputation(builder.Build());
   1428 
   1429   EXPECT_THAT(computation->root_instruction(), op::Transpose(param));
   1430 
   1431   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
   1432                                  bitcasting_callback());
   1433   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
   1434 
   1435   // Verify that the reshape is replaced.
   1436   EXPECT_THAT(computation->root_instruction(), op::Bitcast(param));
   1437 }
   1438 
   1439 TEST_F(AlgebraicSimplifierTest, ReshapesMerged) {
   1440   HloComputation::Builder builder(TestName());
   1441   HloInstruction* param0 =
   1442       builder.AddInstruction(HloInstruction::CreateParameter(
   1443           0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"));
   1444 
   1445   HloInstruction* reshape1 =
   1446       builder.AddInstruction(HloInstruction::CreateReshape(
   1447           ShapeUtil::MakeShape(F32, {2, 1, 2}), param0));
   1448 
   1449   builder.AddInstruction(HloInstruction::CreateReshape(
   1450       ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), reshape1));
   1451 
   1452   auto computation = module().AddEntryComputation(builder.Build());
   1453 
   1454   EXPECT_THAT(computation->root_instruction(),
   1455               op::Reshape(op::Reshape(param0)));
   1456 
   1457   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
   1458                                  non_bitcasting_callback());
   1459   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
   1460 
   1461   EXPECT_THAT(computation->root_instruction(), op::Reshape(param0));
   1462 }
   1463 
   1464 TEST_F(AlgebraicSimplifierTest, CopiesMerged) {
   1465   HloComputation::Builder builder(TestName());
   1466   HloInstruction* param0 =
   1467       builder.AddInstruction(HloInstruction::CreateParameter(
   1468           0, ShapeUtil::MakeShapeWithDescendingLayout(F32, {2, 2, 2}),
   1469           "param0"));
   1470 
   1471   HloInstruction* copy1 = builder.AddInstruction(HloInstruction::CreateUnary(
   1472       ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 1, 2}),
   1473       HloOpcode::kCopy, param0));
   1474 
   1475   builder.AddInstruction(HloInstruction::CreateUnary(
   1476       ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 2, 1}),
   1477       HloOpcode::kCopy, copy1));
   1478 
   1479   auto computation = module().AddEntryComputation(builder.Build());
   1480 
   1481   EXPECT_THAT(computation->root_instruction(), op::Copy(op::Copy(param0)));
   1482 
   1483   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
   1484                                  non_bitcasting_callback());
   1485   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
   1486 
   1487   EXPECT_THAT(computation->root_instruction(), op::Copy(param0));
   1488 }
   1489 
   1490 TEST_F(AlgebraicSimplifierTest, TransposesMerged) {
   1491   HloComputation::Builder builder(TestName());
   1492   HloInstruction* param0 =
   1493       builder.AddInstruction(HloInstruction::CreateParameter(
   1494           0, ShapeUtil::MakeShape(F32, {2, 3, 4}), "param0"));
   1495 
   1496   HloInstruction* transpose1 =
   1497       builder.AddInstruction(HloInstruction::CreateTranspose(
   1498           ShapeUtil::MakeShape(F32, {3, 4, 2}), param0, {1, 2, 0}));
   1499 
   1500   builder.AddInstruction(HloInstruction::CreateTranspose(
   1501       ShapeUtil::MakeShape(F32, {4, 3, 2}), transpose1, {1, 0, 2}));
   1502 
   1503   auto computation = module().AddEntryComputation(builder.Build());
   1504 
   1505   EXPECT_THAT(computation->root_instruction(), op::Transpose(transpose1));
   1506 
   1507   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
   1508                                  non_bitcasting_callback());
   1509   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
   1510 
   1511   EXPECT_THAT(computation->root_instruction(), op::Transpose(param0));
   1512   EXPECT_EQ(std::vector<int64>({2, 1, 0}),
   1513             computation->root_instruction()->dimensions());
   1514 }
   1515 
   1516 // Test merging reshape and broadcast.
   1517 TEST_F(AlgebraicSimplifierTest, ReshapeAndBroadcastMerged) {
   1518   HloComputation::Builder builder(TestName());
   1519   auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
   1520       0, ShapeUtil::MakeShape(F32, {5}), "param0"));
   1521   auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape(
   1522       ShapeUtil::MakeShape(F32, {1, 5, 1}), param0));
   1523   builder.AddInstruction(HloInstruction::CreateBroadcast(
   1524       ShapeUtil::MakeShape(F32, {1, 2, 3, 5, 1}), reshape1, {0, 3, 2}));
   1525 
   1526   auto computation = module().AddEntryComputation(builder.Build());
   1527 
   1528   EXPECT_THAT(computation->root_instruction(),
   1529               op::Broadcast(op::Reshape(param0)));
   1530 
   1531   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
   1532                                  non_bitcasting_callback());
   1533   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
   1534 
   1535   EXPECT_THAT(computation->root_instruction(), op::Broadcast(param0));
   1536 }
   1537 
   1538 // Test merging broadcast and reshape.
   1539 TEST_F(AlgebraicSimplifierTest, BroadcastAndReshapeMerged) {
   1540   HloComputation::Builder builder(TestName());
   1541   auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
   1542       0, ShapeUtil::MakeShape(F32, {2, 3}), "param0"));
   1543   auto broadcast1 = builder.AddInstruction(HloInstruction::CreateBroadcast(
   1544       ShapeUtil::MakeShape(F32, {1, 2, 3, 7, 12, 1}), param0, {1, 2}));
   1545   builder.AddInstruction(HloInstruction::CreateReshape(
   1546       ShapeUtil::MakeShape(F32, {2, 3, 7, 2, 1, 3, 2}), broadcast1));
   1547 
   1548   auto computation = module().AddEntryComputation(builder.Build());
   1549 
   1550   EXPECT_THAT(computation->root_instruction(),
   1551               op::Reshape(op::Broadcast(param0)));
   1552 
   1553   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
   1554                                  non_bitcasting_callback());
   1555   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
   1556 
   1557   EXPECT_THAT(computation->root_instruction(), op::Broadcast(param0));
   1558 }
   1559 
   1560 TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x1_3) {
   1561   HloComputation::Builder builder(TestName());
   1562   auto param = builder.AddInstruction(HloInstruction::CreateParameter(
   1563       0, ShapeUtil::MakeShape(F32, {1}), "param"));
   1564   auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
   1565       ShapeUtil::MakeShape(F32, {3, 1}), param, {1}));
   1566   builder.AddInstruction(
   1567       HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {3}), broadcast));
   1568 
   1569   auto computation = module().AddEntryComputation(builder.Build());
   1570 
   1571   EXPECT_THAT(computation->root_instruction(),
   1572               op::Reshape(op::Broadcast(param)));
   1573 
   1574   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
   1575                                  non_bitcasting_callback());
   1576   EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie());
   1577 
   1578   EXPECT_THAT(computation->root_instruction(),
   1579               op::Reshape(op::Broadcast(param)));
   1580 }
   1581 
   1582 TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4_6x1x1x4) {
   1583   HloComputation::Builder builder(TestName());
   1584   auto param = builder.AddInstruction(HloInstruction::CreateParameter(
   1585       0, ShapeUtil::MakeShape(F32, {4}), "param"));
   1586   auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
   1587       ShapeUtil::MakeShape(F32, {3, 2, 4}), param, {2}));
   1588   builder.AddInstruction(HloInstruction::CreateReshape(
   1589       ShapeUtil::MakeShape(F32, {6, 1, 1, 4}), broadcast));
   1590 
   1591   HloComputation* computation = module().AddEntryComputation(builder.Build());
   1592 
   1593   EXPECT_THAT(computation->root_instruction(),
   1594               op::Reshape(op::Broadcast(param)));
   1595 
   1596   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
   1597                                  non_bitcasting_callback());
   1598   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
   1599 
   1600   EXPECT_THAT(computation->root_instruction(), op::Broadcast(param));
   1601   EXPECT_THAT(computation->root_instruction()->dimensions(),
   1602               ::testing::ElementsAre(3));
   1603 }
   1604 
   1605 TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x2x1_6x1x1x1) {
   1606   HloComputation::Builder builder(TestName());
   1607   auto param = builder.AddInstruction(HloInstruction::CreateParameter(
   1608       0, ShapeUtil::MakeShape(F32, {1}), "param"));
   1609   auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
   1610       ShapeUtil::MakeShape(F32, {3, 2, 1}), param, {2}));
   1611   builder.AddInstruction(HloInstruction::CreateReshape(
   1612       ShapeUtil::MakeShape(F32, {6, 1, 1, 1}), broadcast));
   1613 
   1614   HloComputation* computation = module().AddEntryComputation(builder.Build());
   1615 
   1616   EXPECT_THAT(computation->root_instruction(),
   1617               op::Reshape(op::Broadcast(param)));
   1618 
   1619   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
   1620                                  non_bitcasting_callback());
   1621   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
   1622 
   1623   EXPECT_THAT(computation->root_instruction(), op::Broadcast(param));
   1624   const std::vector<int64> broadcast_dims =
   1625       computation->root_instruction()->dimensions();
   1626   EXPECT_EQ(1, broadcast_dims.size());
   1627   EXPECT_THAT(broadcast_dims[0], ::testing::AnyOf(1, 2, 3));
   1628 }
   1629 
   1630 TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4x2_6x8) {
   1631   HloComputation::Builder builder(TestName());
   1632   auto param = builder.AddInstruction(HloInstruction::CreateParameter(
   1633       0, ShapeUtil::MakeShape(F32, {4}), "param"));
   1634   auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
   1635       ShapeUtil::MakeShape(F32, {3, 2, 4, 2}), param, {2}));
   1636   builder.AddInstruction(HloInstruction::CreateReshape(
   1637       ShapeUtil::MakeShape(F32, {6, 8}), broadcast));
   1638 
   1639   HloComputation* computation = module().AddEntryComputation(builder.Build());
   1640 
   1641   EXPECT_THAT(computation->root_instruction(),
   1642               op::Reshape(op::Broadcast(param)));
   1643 
   1644   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
   1645                                  non_bitcasting_callback());
   1646   EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie());
   1647 
   1648   EXPECT_THAT(computation->root_instruction(),
   1649               op::Reshape(op::Broadcast(param)));
   1650 }
   1651 
   1652 TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) {
   1653   HloComputation::Builder builder(TestName());
   1654   HloInstruction* param =
   1655       builder.AddInstruction(HloInstruction::CreateParameter(
   1656           0, ShapeUtil::MakeShape(F32, {2, 2}), "param"));
   1657   HloInstruction* zero = builder.AddInstruction(
   1658       HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
   1659   PaddingConfig no_padding;
   1660   for (int i = 0; i < 2; ++i) {
   1661     auto dimension = no_padding.add_dimensions();
   1662     dimension->set_edge_padding_low(0);
   1663     dimension->set_edge_padding_high(0);
   1664     dimension->set_interior_padding(0);
   1665   }
   1666   builder.AddInstruction(HloInstruction::CreatePad(
   1667       ShapeUtil::MakeShape(F32, {2, 2}), param, zero, no_padding));
   1668 
   1669   HloModule module(TestName());
   1670   HloComputation* computation = module.AddEntryComputation(builder.Build());
   1671 
   1672   EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero));
   1673 
   1674   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
   1675                                  non_bitcasting_callback());
   1676   ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
   1677 
   1678   EXPECT_THAT(computation->root_instruction(), param);
   1679 }
   1680 
   1681 TEST_F(AlgebraicSimplifierTest, NegativePadding) {
   1682   // Verify that a pad instruction with negative padding is replaced with a
   1683   // pad with non-negative padding followed by a slice.
   1684   HloComputation::Builder builder(TestName());
   1685   HloInstruction* param =
   1686       builder.AddInstruction(HloInstruction::CreateParameter(
   1687           0, ShapeUtil::MakeShape(F32, {10, 10}), "param"));
   1688   HloInstruction* zero = builder.AddInstruction(
   1689       HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
   1690   PaddingConfig padding;
   1691   int64 low_padding[2] = {-1, -2};
   1692   int64 high_padding[2] = {2, -3};
   1693   for (int i = 0; i < 2; ++i) {
   1694     auto dimension = padding.add_dimensions();
   1695     dimension->set_edge_padding_low(low_padding[i]);
   1696     dimension->set_edge_padding_high(high_padding[i]);
   1697     dimension->set_interior_padding(0);
   1698   }
   1699   HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
   1700       ShapeUtil::MakeShape(F32, {11, 5}), param, zero, padding));
   1701 
   1702   HloModule module(TestName());
   1703   HloComputation* computation = module.AddEntryComputation(builder.Build());
   1704 
   1705   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
   1706                                  non_bitcasting_callback());
   1707 
   1708   auto has_negative_padding = [](const HloInstruction* pad) {
   1709     for (auto& padding_dimension : pad->padding_config().dimensions()) {
   1710       if (padding_dimension.edge_padding_low() < 0 ||
   1711           padding_dimension.edge_padding_high() < 0) {
   1712         return true;
   1713       }
   1714     }
   1715     return false;
   1716   };
   1717 
   1718   EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero));
   1719   EXPECT_TRUE(has_negative_padding(pad));
   1720 
   1721   ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
   1722 
   1723   EXPECT_THAT(computation->root_instruction(), op::Slice(op::Pad(param, zero)));
   1724   EXPECT_FALSE(
   1725       has_negative_padding(computation->root_instruction()->operand(0)));
   1726 }
   1727 
   1728 TEST_F(AlgebraicSimplifierTest, RemoveNoopReshape) {
   1729   HloComputation::Builder builder(TestName());
   1730   HloInstruction* param =
   1731       builder.AddInstruction(HloInstruction::CreateParameter(
   1732           0, ShapeUtil::MakeShape(F32, {2, 3}), "param"));
   1733   builder.AddInstruction(
   1734       HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {2, 3}), param));
   1735 
   1736   HloModule module(TestName());
   1737   HloComputation* computation = module.AddEntryComputation(builder.Build());
   1738 
   1739   EXPECT_THAT(computation->root_instruction(), op::Reshape(param));
   1740 
   1741   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
   1742                                  non_bitcasting_callback());
   1743   ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
   1744 
   1745   EXPECT_THAT(computation->root_instruction(), param);
   1746 }
   1747 
   1748 TEST_F(AlgebraicSimplifierTest, RemoveNoopSlice) {
   1749   HloComputation::Builder builder(TestName());
   1750   const int64 dim0 = 2;
   1751   const int64 dim1 = 3;
   1752   HloInstruction* param =
   1753       builder.AddInstruction(HloInstruction::CreateParameter(
   1754           0, ShapeUtil::MakeShape(F32, {dim0, dim1}), "param"));
   1755   builder.AddInstruction(HloInstruction::CreateSlice(
   1756       ShapeUtil::MakeShape(F32, {dim0, dim1}), param, /*start_indices=*/{0, 0},
   1757       /*limit_indices=*/{dim0, dim1}, /*strides=*/{1, 1}));
   1758 
   1759   HloModule module(TestName());
   1760   HloComputation* computation = module.AddEntryComputation(builder.Build());
   1761 
   1762   EXPECT_THAT(computation->root_instruction(), op::Slice(param));
   1763 
   1764   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
   1765                                  non_bitcasting_callback());
   1766   ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
   1767 
   1768   EXPECT_THAT(computation->root_instruction(), param);
   1769 }
   1770 
   1771 TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
   1772   struct ConvTestOptions {
   1773     int in_batch = 10;
   1774     int in_height = 2;
   1775     int in_width = 2;
   1776     int in_channels = 3;
   1777     int f_width = 1;
   1778     int f_height = 1;
   1779     int f_output_channels = 10;
   1780     int row_stride = 1;
   1781     int row_padding = 0;
   1782     int col_stride = 1;
   1783     int col_padding = 0;
   1784     bool input_minor_to_major_layout = false;
   1785     bool filter_minor_to_major_layout = false;
   1786     bool output_minor_to_major_layout = false;
   1787 
   1788     const char* dim_order = "NHWC";         // can use chars NHWC in any order.
   1789     const char* kernel_dim_order = "HWIO";  // can use chars HWIO in any order.
   1790 
   1791     ConvTestOptions& Reset() {
   1792       *this = ConvTestOptions();
   1793       return *this;
   1794     }
   1795   };
   1796 
   1797   ConvTestOptions options;
   1798 
   1799   // Builds a convolution from <options> and runs algebraic simplification on
   1800   // the computation. Returns a string description of the result of
   1801   // simplification.
   1802   auto build_and_simplify = [&options, this]() -> string {
   1803     HloComputation::Builder b(TestName());
   1804 
   1805     Window window;
   1806     auto* f_dim_1 = window.add_dimensions();
   1807     f_dim_1->set_size(options.f_height);
   1808     f_dim_1->set_stride(options.row_stride);
   1809     f_dim_1->set_padding_low(options.row_padding);
   1810     f_dim_1->set_padding_high(options.row_padding);
   1811     f_dim_1->set_window_dilation(1);
   1812     f_dim_1->set_base_dilation(1);
   1813     auto* f_dim_2 = window.add_dimensions();
   1814     f_dim_2->set_size(options.f_width);
   1815     f_dim_2->set_stride(options.col_stride);
   1816     f_dim_2->set_padding_low(options.col_padding);
   1817     f_dim_2->set_padding_high(options.col_padding);
   1818     f_dim_2->set_window_dilation(1);
   1819     f_dim_2->set_base_dilation(1);
   1820 
   1821     ConvolutionDimensionNumbers dnums;
   1822     std::vector<int64> in_dims;
   1823     int in_channel_idx = -1;
   1824     // filled in later
   1825     dnums.add_input_spatial_dimensions(-1);
   1826     dnums.add_output_spatial_dimensions(-1);
   1827     dnums.add_input_spatial_dimensions(-1);
   1828     dnums.add_output_spatial_dimensions(-1);
   1829     for (int i = 0; i < strlen(options.dim_order); ++i) {
   1830       char ch = options.dim_order[i];
   1831       if (ch == 'N') {
   1832         dnums.set_input_batch_dimension(i);
   1833         dnums.set_output_batch_dimension(i);
   1834         in_dims.push_back(options.in_batch);
   1835       } else if (ch == 'H') {
   1836         dnums.set_input_spatial_dimensions(0, i);
   1837         dnums.set_output_spatial_dimensions(0, i);
   1838         in_dims.push_back(options.in_height);
   1839       } else if (ch == 'W') {
   1840         dnums.set_input_spatial_dimensions(1, i);
   1841         dnums.set_output_spatial_dimensions(1, i);
   1842         in_dims.push_back(options.in_width);
   1843       } else if (ch == 'C') {
   1844         dnums.set_input_feature_dimension(i);
   1845         dnums.set_output_feature_dimension(i);
   1846         in_dims.push_back(options.in_channels);
   1847         in_channel_idx = i;
   1848       }
   1849     }
   1850 
   1851     std::vector<int64> f_dims;
   1852     dnums.add_kernel_spatial_dimensions(-1);  // filled in later
   1853     dnums.add_kernel_spatial_dimensions(-1);  // filled in later
   1854     for (int i = 0; i < strlen(options.kernel_dim_order); ++i) {
   1855       char ch = options.kernel_dim_order[i];
   1856       if (ch == 'H') {
   1857         dnums.set_kernel_spatial_dimensions(0, i);
   1858         f_dims.push_back(options.f_height);
   1859       } else if (ch == 'W') {
   1860         dnums.set_kernel_spatial_dimensions(1, i);
   1861         f_dims.push_back(options.f_width);
   1862       } else if (ch == 'I') {
   1863         dnums.set_kernel_input_feature_dimension(i);
   1864         f_dims.push_back(options.in_channels);
   1865       } else if (ch == 'O') {
   1866         dnums.set_kernel_output_feature_dimension(i);
   1867         f_dims.push_back(options.f_output_channels);
   1868       }
   1869     }
   1870 
   1871     auto out_dims = in_dims;
   1872     out_dims[in_channel_idx] = options.f_output_channels;
   1873 
   1874     auto make_shape = [](tensorflow::gtl::ArraySlice<int64> dims,
   1875                          bool minor_to_major_layout) {
   1876       if (minor_to_major_layout) {
   1877         return ShapeUtil::MakeShapeWithLayout(F32, dims, {0, 1, 2, 3});
   1878       } else {
   1879         return ShapeUtil::MakeShape(F32, dims);
   1880       }
   1881     };
   1882     auto in_shape = make_shape(in_dims, options.input_minor_to_major_layout);
   1883     auto f_shape = make_shape(f_dims, options.filter_minor_to_major_layout);
   1884     auto out_shape = make_shape(out_dims, options.output_minor_to_major_layout);
   1885 
   1886     HloInstruction* input =
   1887         b.AddInstruction(HloInstruction::CreateParameter(0, in_shape, "input"));
   1888     HloInstruction* filter =
   1889         b.AddInstruction(HloInstruction::CreateParameter(1, f_shape, "filter"));
   1890 
   1891     b.AddInstruction(HloInstruction::CreateConvolve(out_shape, input, filter,
   1892                                                     window, dnums));
   1893 
   1894     HloModule module(TestName());
   1895     auto* computation = module.AddEntryComputation(b.Build());
   1896 
   1897     AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
   1898                                    bitcasting_callback());
   1899     if (!simplifier.Run(&module).ValueOrDie()) {
   1900       return "NO_CHANGE";
   1901     }
   1902     auto* root = computation->root_instruction();
   1903     if (root->opcode() == HloOpcode::kBitcast &&
   1904         root->operand(0)->opcode() == HloOpcode::kDot) {
   1905       auto lhs_shape = root->operand(0)->operand(0)->shape();
   1906       auto rhs_shape = root->operand(0)->operand(1)->shape();
   1907       return tensorflow::strings::StrCat(
   1908           tensorflow::str_util::Join(lhs_shape.dimensions(), "x"), " DOT ",
   1909           tensorflow::str_util::Join(rhs_shape.dimensions(), "x"));
   1910     }
   1911     return "UNEXPECTED CHANGE";
   1912   };
   1913 
   1914   // Default options are the simplest case and succeed.
   1915   options.Reset();
   1916   EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
   1917 
   1918   // Swapping dim spatial and batch order works.
   1919   options.Reset().dim_order = "NWHC";
   1920   EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
   1921   options.Reset().dim_order = "WHNC";
   1922   EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
   1923   // Channel dimension earlier fails.
   1924   options.Reset().dim_order = "HWCN";
   1925   EXPECT_EQ("NO_CHANGE", build_and_simplify());
   1926   options.Reset().dim_order = "CHWN";
   1927   EXPECT_EQ("NO_CHANGE", build_and_simplify());
   1928 
   1929   // Filtering dims spatial dims can be anywhere, since they are 1x1.
   1930   options.Reset().kernel_dim_order = "WHIO";
   1931   EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
   1932   options.Reset().kernel_dim_order = "IWOH";
   1933   EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
   1934   options.Reset().kernel_dim_order = "IWHO";
   1935   EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
   1936   // But moving output channel before input channel fails.
   1937   options.Reset().kernel_dim_order = "HWOI";
   1938   EXPECT_EQ("NO_CHANGE", build_and_simplify());
   1939   options.Reset().kernel_dim_order = "WHOI";
   1940   EXPECT_EQ("NO_CHANGE", build_and_simplify());
   1941   options.Reset().kernel_dim_order = "OWIH";
   1942   EXPECT_EQ("NO_CHANGE", build_and_simplify());
   1943   options.Reset().kernel_dim_order = "OWHI";
   1944   EXPECT_EQ("NO_CHANGE", build_and_simplify());
   1945 
   1946   // Combine different dim and kernel dim orders.
   1947   options.Reset().kernel_dim_order = "IWHO";
   1948   options.dim_order = "WHNC";
   1949   EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
   1950 
   1951   // Test invalid cases from wrong filter size, strides, or padding.
   1952   options.Reset().f_width = 2;
   1953   EXPECT_EQ("NO_CHANGE", build_and_simplify());
   1954   options.Reset().f_height = 2;
   1955   EXPECT_EQ("NO_CHANGE", build_and_simplify());
   1956   options.Reset().row_stride = 2;
   1957   EXPECT_EQ("NO_CHANGE", build_and_simplify());
   1958   options.Reset().col_stride = 2;
   1959   EXPECT_EQ("NO_CHANGE", build_and_simplify());
   1960   options.Reset().col_padding = 1;
   1961   EXPECT_EQ("NO_CHANGE", build_and_simplify());
   1962   options.Reset().row_padding = 1;
   1963   EXPECT_EQ("NO_CHANGE", build_and_simplify());
   1964 
   1965   // The default dim_order is "NHWC". Col-major layout makes C the most major.
   1966   options.Reset().input_minor_to_major_layout = true;
   1967   options.output_minor_to_major_layout = true;
   1968   EXPECT_EQ("NO_CHANGE", build_and_simplify());
   1969 
   1970   // The input and output have different layouts.
   1971   options.Reset().input_minor_to_major_layout = true;
   1972   EXPECT_EQ("NO_CHANGE", build_and_simplify());
   1973 
   1974   // C is most minor, and I is more major than O.
   1975   options.Reset().input_minor_to_major_layout = true;
   1976   options.filter_minor_to_major_layout = true;
   1977   options.output_minor_to_major_layout = true;
   1978   options.dim_order = "CHWN";
   1979   options.kernel_dim_order = "OIHW";
   1980   EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
   1981 
   1982   // C is not the most minor dimension.
   1983   options.Reset().input_minor_to_major_layout = true;
   1984   options.filter_minor_to_major_layout = true;
   1985   options.output_minor_to_major_layout = true;
   1986   options.dim_order = "HWNC";
   1987   options.kernel_dim_order = "OIHW";
   1988   EXPECT_EQ("NO_CHANGE", build_and_simplify());
   1989 
   1990   // I is more minor than O.
   1991   options.Reset().input_minor_to_major_layout = true;
   1992   options.filter_minor_to_major_layout = true;
   1993   options.output_minor_to_major_layout = true;
   1994   options.dim_order = "CHWN";
   1995   options.kernel_dim_order = "IOHW";
   1996   EXPECT_EQ("NO_CHANGE", build_and_simplify());
   1997 }
   1998 
   1999 // Test that max(min(A, x), y) is transformed to clamp(y, A, x)
        // In this test A is `param0`, and x and y are scalar F32 constants.
   2000 TEST_F(AlgebraicSimplifierTest, MaxMinToClamp) {
   2001   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
   2002   HloComputation::Builder builder(TestName());
   2003   HloInstruction* param0 = builder.AddInstruction(
   2004       HloInstruction::CreateParameter(0, r0f32, "param0"));
          // Naming note: `min_value` is the constant fed to the kMinimum op (x) and
          // `max_value` is the constant fed to the kMaximum op (y).
   2005   HloInstruction* min_value = builder.AddInstruction(
   2006       HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
   2007   HloInstruction* max_value = builder.AddInstruction(
   2008       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
          // Inner op: min(param0, min_value).
   2009   HloInstruction* min = builder.AddInstruction(HloInstruction::CreateBinary(
   2010       r0f32, HloOpcode::kMinimum, param0, min_value));
          // Outer op: max(min(...), max_value). It is added last and is checked
          // below as the root of the computation.
   2011   builder.AddInstruction(
   2012       HloInstruction::CreateBinary(r0f32, HloOpcode::kMaximum, min, max_value));
   2013 
   2014   HloModule module(TestName());
   2015   auto computation = module.AddEntryComputation(builder.Build());
   2016 
          // Sanity-check the graph before simplification.
   2017   EXPECT_THAT(computation->root_instruction(),
   2018               op::Maximum(op::Minimum(param0, min_value), max_value));
   2019 
   2020   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
   2021                                  non_bitcasting_callback());
          // The pass must report a change.
   2022   ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
   2023 
          // The root is now a clamp whose operands are (max_value, param0,
          // min_value), i.e. clamp(y, A, x).
   2024   EXPECT_THAT(computation->root_instruction(),
   2025               op::Clamp(max_value, param0, min_value));
   2026 }
   2027 
   2028 // Test that min(max(A, x), y) is transformed to clamp(x, A, y) for scalar
   2029 // values.
        // In this test A is `param0`, x is `max_value` and y is `min_value`; both
        // x and y are scalar F32 constants.
   2030 TEST_F(AlgebraicSimplifierTest, MinMaxToClamp) {
   2031   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
   2032   HloComputation::Builder builder(TestName());
   2033   HloInstruction* param0 = builder.AddInstruction(
   2034       HloInstruction::CreateParameter(0, r0f32, "param0"));
          // Naming note: `min_value` feeds the kMinimum op and `max_value` feeds the
          // kMaximum op.
   2035   HloInstruction* min_value = builder.AddInstruction(
   2036       HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
   2037   HloInstruction* max_value = builder.AddInstruction(
   2038       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
          // Inner op: max(param0, max_value).
   2039   HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary(
   2040       r0f32, HloOpcode::kMaximum, param0, max_value));
          // Outer op: min(max(...), min_value). It is added last and is checked
          // below as the root of the computation.
   2041   builder.AddInstruction(
   2042       HloInstruction::CreateBinary(r0f32, HloOpcode::kMinimum, max, min_value));
   2043 
   2044   HloModule module(TestName());
   2045   auto computation = module.AddEntryComputation(builder.Build());
   2046 
          // Sanity-check the graph before simplification.
   2047   EXPECT_THAT(computation->root_instruction(),
   2048               op::Minimum(op::Maximum(param0, max_value), min_value));
   2049 
   2050   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
   2051                                  non_bitcasting_callback());
          // The pass must report a change.
   2052   ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
   2053 
          // The root is now clamp(max_value, param0, min_value), i.e. clamp(x, A, y).
   2054   EXPECT_THAT(computation->root_instruction(),
   2055               op::Clamp(max_value, param0, min_value));
   2056 }
   2057 
   2058 // Test that min(max(A, x), y) is transformed to clamp(x, A, y) for
   2059 // broadcasted scalar values.
        // Here A (`param0`) is a 100-element F32 vector while x (`max_value`) and
        // y (`min_value`) stay scalar constants.
   2060 TEST_F(AlgebraicSimplifierTest, MinMaxWithBroadcastToClamp) {
   2061   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
   2062   Shape r1f32 = ShapeUtil::MakeShape(F32, {100});
   2063   HloComputation::Builder builder(TestName());
   2064   HloInstruction* param0 = builder.AddInstruction(
   2065       HloInstruction::CreateParameter(0, r1f32, "param0"));
   2066   HloInstruction* min_value = builder.AddInstruction(
   2067       HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
   2068   HloInstruction* max_value = builder.AddInstruction(
   2069       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
          // Both binary ops are given the vector shape r1f32 even though one operand
          // of each is a scalar constant.
   2070   HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary(
   2071       r1f32, HloOpcode::kMaximum, param0, max_value));
   2072   builder.AddInstruction(
   2073       HloInstruction::CreateBinary(r1f32, HloOpcode::kMinimum, max, min_value));
   2074 
   2075   HloModule module(TestName());
   2076   auto computation = module.AddEntryComputation(builder.Build());
   2077 
          // Sanity-check the graph before simplification.
   2078   EXPECT_THAT(computation->root_instruction(),
   2079               op::Minimum(op::Maximum(param0, max_value), min_value));
   2080 
   2081   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
   2082                                  non_bitcasting_callback());
          // The pass must report a change.
   2083   ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
   2084 
          // The root is now clamp(max_value, param0, min_value).
   2085   EXPECT_THAT(computation->root_instruction(),
   2086               op::Clamp(max_value, param0, min_value));
   2087 }
   2088 
   2089 // Test that min(max(A, non-constant1), non-constant2) is not canonicalized to
   2090 // clamp(non-constant1, A, non-constant2)
        // The two bounds are computation parameters instead of constants, and the
        // test expects the simplifier to leave the graph untouched.
   2091 TEST_F(AlgebraicSimplifierTest, MinMaxNotToClamp) {
   2092   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
   2093   HloComputation::Builder builder(TestName());
   2094   HloInstruction* param0 = builder.AddInstruction(
   2095       HloInstruction::CreateParameter(0, r0f32, "param0"));
          // Unlike the positive clamp tests, both bounds are parameters.
   2096   HloInstruction* min_value = builder.AddInstruction(
   2097       HloInstruction::CreateParameter(1, r0f32, "param1"));
   2098   HloInstruction* max_value = builder.AddInstruction(
   2099       HloInstruction::CreateParameter(2, r0f32, "param2"));
   2100   HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary(
   2101       r0f32, HloOpcode::kMaximum, param0, max_value));
   2102   builder.AddInstruction(
   2103       HloInstruction::CreateBinary(r0f32, HloOpcode::kMinimum, max, min_value));
   2104 
   2105   HloModule module(TestName());
   2106   auto computation = module.AddEntryComputation(builder.Build());
   2107 
          // Graph shape before the pass.
   2108   EXPECT_THAT(computation->root_instruction(),
   2109               op::Minimum(op::Maximum(param0, max_value), min_value));
   2110 
   2111   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
   2112                                  non_bitcasting_callback());
          // The pass must report that nothing changed.
   2113   EXPECT_FALSE(simplifier.Run(&module).ValueOrDie());
   2114 
          // Graph shape after the pass: identical to the one checked above.
   2115   EXPECT_THAT(computation->root_instruction(),
   2116               op::Minimum(op::Maximum(param0, max_value), min_value));
   2117 }
   2118 
   2119 // Test that min(f(max(A, constant1)), constant2) is not transformed to
   2120 // clamp(constant1, A, constant2)
        // In this test f(v) is v + max_value, so the maximum is not a direct operand
        // of the minimum.
   2121 TEST_F(AlgebraicSimplifierTest, MinEquationWithMaxNotToClamp) {
   2122   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
   2123   HloComputation::Builder builder(TestName());
   2124   HloInstruction* param0 = builder.AddInstruction(
   2125       HloInstruction::CreateParameter(0, r0f32, "param0"));
   2126   HloInstruction* min_value = builder.AddInstruction(
   2127       HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
   2128   HloInstruction* max_value = builder.AddInstruction(
   2129       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
   2130   HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary(
   2131       r0f32, HloOpcode::kMaximum, param0, max_value));
          // The intervening op between the maximum and the minimum: max + max_value.
   2132   HloInstruction* fmax = builder.AddInstruction(
   2133       HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, max, max_value));
   2134   builder.AddInstruction(HloInstruction::CreateBinary(
   2135       r0f32, HloOpcode::kMinimum, fmax, min_value));
   2136 
   2137   HloModule module(TestName());
   2138   auto computation = module.AddEntryComputation(builder.Build());
   2139 
          // Graph shape before the pass.
   2140   EXPECT_THAT(computation->root_instruction(),
   2141               op::Minimum(op::Add(op::Maximum(param0, max_value), max_value),
   2142                           min_value));
   2143 
   2144   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
   2145                                  non_bitcasting_callback());
          // The pass must report that nothing changed.
   2146   EXPECT_FALSE(simplifier.Run(&module).ValueOrDie());
   2147 
          // Graph shape after the pass: identical to the one checked above.
   2148   EXPECT_THAT(computation->root_instruction(),
   2149               op::Minimum(op::Add(op::Maximum(param0, max_value), max_value),
   2150                           min_value));
   2151 }
   2152 
   2153 // Test that slice(broadcast(/*scalar value*/)) simplifies to a single
   2154 // broadcast.
        // The resulting broadcast is expected to take the scalar parameter directly
        // and to have the shape of the original slice.
   2155 TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) {
   2156   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
   2157   HloComputation::Builder builder(TestName());
   2158   HloInstruction* scalar_param = builder.AddInstruction(
   2159       HloInstruction::CreateParameter(0, r0f32, "scalar_param"));
   2160 
          // Broadcast the scalar to a {4, 5, 6, 7} array.
          // NOTE(review): the third argument passes the dimension *sizes* of
          // broadcast_shape; confirm this is what CreateBroadcast expects for a
          // scalar operand.
   2161   Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6, 7});
   2162   HloInstruction* broadcast =
   2163       builder.AddInstruction(HloInstruction::CreateBroadcast(
   2164           broadcast_shape, scalar_param,
   2165           AsInt64Slice(broadcast_shape.dimensions())));
   2166 
          // Slice [0:2, 1:3, 2:5, 3:6] with unit strides, giving {2, 2, 3, 3}.
   2167   Shape slice_shape = ShapeUtil::MakeShape(F32, {2, 2, 3, 3});
   2168   HloInstruction* slice = builder.AddInstruction(HloInstruction::CreateSlice(
   2169       slice_shape, broadcast, {0, 1, 2, 3}, {2, 3, 5, 6}, {1, 1, 1, 1}));
   2170 
   2171   HloModule module(TestName());
   2172   auto computation = module.AddEntryComputation(builder.Build());
   2173 
          // Before the pass the slice is the root.
   2174   HloInstruction* root = computation->root_instruction();
   2175   EXPECT_EQ(root, slice);
   2176   EXPECT_TRUE(ShapeUtil::Equal(root->shape(), slice_shape));
   2177 
   2178   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
   2179                                  non_bitcasting_callback());
   2180 
   2181   ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
   2182 
   2183   // Running simplification again should not result in any further changes.
   2184   ASSERT_FALSE(simplifier.Run(&module).ValueOrDie());
   2185 
          // After the pass the root is a broadcast of the scalar parameter with the
          // slice's shape.
   2186   root = computation->root_instruction();
   2187   EXPECT_THAT(root, op::Broadcast(scalar_param));
   2188   EXPECT_TRUE(ShapeUtil::Equal(root->shape(), slice_shape));
   2189 }
   2190 
   2191 // Test that reshape(transpose(broadcast(/*scalar value*/))) simplifies to a
   2192 // single broadcast.
        // The resulting broadcast is expected to take the scalar constant directly
        // and to have the shape of the original reshape.
   2193 TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) {
   2194   HloComputation::Builder builder(TestName());
   2195   HloInstruction* forty_two = builder.AddInstruction(
   2196       HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
   2197 
          // Broadcast the scalar constant to a {4, 5, 6} array.
          // NOTE(review): the third argument passes the dimension *sizes* of
          // broadcast_shape; confirm this is what CreateBroadcast expects for a
          // scalar operand.
   2198   Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6});
   2199   HloInstruction* broadcast =
   2200       builder.AddInstruction(HloInstruction::CreateBroadcast(
   2201           broadcast_shape, forty_two,
   2202           AsInt64Slice(broadcast_shape.dimensions())));
   2203 
          // Reverse the dimension order: {4, 5, 6} -> {6, 5, 4}.
   2204   HloInstruction* transpose =
   2205       builder.AddInstruction(HloInstruction::CreateTranspose(
   2206           ShapeUtil::MakeShape(F32, {6, 5, 4}), broadcast, {2, 1, 0}));
   2207 
          // Reshape {6, 5, 4} -> {30, 1, 4} (same 120 elements).
   2208   Shape reshape_shape = ShapeUtil::MakeShape(F32, {30, 1, 4});
   2209   HloInstruction* reshape = builder.AddInstruction(
   2210       HloInstruction::CreateReshape(reshape_shape, transpose));
   2211 
   2212   HloModule module(TestName());
   2213   auto computation = module.AddEntryComputation(builder.Build());
   2214 
          // Before the pass the reshape is the root.
   2215   HloInstruction* root = computation->root_instruction();
   2216   EXPECT_EQ(root, reshape);
   2217   EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reshape_shape));
   2218 
   2219   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
   2220                                  non_bitcasting_callback());
   2221   ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
   2222 
          // After the pass the root is a broadcast of the constant with the
          // reshape's shape.
   2223   root = computation->root_instruction();
   2224   EXPECT_THAT(root, op::Broadcast(forty_two));
   2225   EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reshape_shape));
   2226 }
   2227 
   2228 // Test that ReduceWindow(Pad(op, x), y) can simplify to ReduceWindow(op, x).
        // After the pass the reduce-window is expected to read the pad's operand
        // directly, with the pad's edge padding added to the window's own padding.
   2229 TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
   2230   HloModule module(TestName());
   2231   HloComputation::Builder builder(TestName());
   2232 
   2233   // Create operand to the pad.
   2234   HloInstruction* operand =
   2235       builder.AddInstruction(HloInstruction::CreateParameter(
   2236           0, ShapeUtil::MakeShape(F32, {1, 2, 3, 4}), "p0"));
   2237 
   2238   // Create the pad.
          // Only dimension 1 (low edge, 1 element) and dimension 3 (high edge,
          // 2 elements) are padded.
   2239   PaddingConfig padding = MakeNoPaddingConfig(4);
   2240   padding.mutable_dimensions(1)->set_edge_padding_low(1);
   2241   padding.mutable_dimensions(3)->set_edge_padding_high(2);
   2242 
          // NOTE(review): the declared pad shape is {1, 3, 3, 5}; high padding of 2
          // on the size-4 dimension would suggest 6 for the last dimension. Confirm
          // whether this mismatch is intentional.
   2243   HloInstruction* pad_value = builder.AddInstruction(
   2244       HloInstruction::CreateConstant(Literal::CreateR0<float>(5.0f)));
   2245   HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
   2246       ShapeUtil::MakeShape(F32, {1, 3, 3, 5}), operand, pad_value, padding));
   2247 
   2248   // Create add computation.
          // The inner `builder` deliberately shadows the outer one within this scope;
          // the scalar add is registered on the module as an embedded computation.
   2249   HloComputation* add_computation = nullptr;
   2250   {
   2251     HloComputation::Builder builder(TestName() + ".add");
   2252     const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
   2253     HloInstruction* p0 = builder.AddInstruction(
   2254         HloInstruction::CreateParameter(0, scalar_shape, "p0"));
   2255     HloInstruction* p1 = builder.AddInstruction(
   2256         HloInstruction::CreateParameter(1, scalar_shape, "p1"));
   2257     builder.AddInstruction(
   2258         HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
   2259     add_computation = module.AddEmbeddedComputation(builder.Build());
   2260   }
   2261 
   2262   // Create the reduce-window.
          // Every dimension gets a size-1 window with padding 10 (low) and 100
          // (high) and no dilation.
   2263   Window window;
   2264   for (int64 i = 0; i < ShapeUtil::Rank(pad->shape()); ++i) {
   2265     auto* dim = window.add_dimensions();
   2266     dim->set_size(1);
   2267     dim->set_padding_low(10);
   2268     dim->set_padding_high(100);
   2269     dim->set_window_dilation(1);
   2270     dim->set_base_dilation(1);
   2271   }
          // Each output dimension is the declared pad-shape dimension plus 110.
   2272   const Shape reduce_window_shape =
   2273       ShapeUtil::MakeShape(F32, {111, 113, 113, 115});
          // The init value is a separate constant with the same value (5.0f) as the
          // pad value.
   2274   HloInstruction* reduce_init_value = builder.AddInstruction(
   2275       HloInstruction::CreateConstant(Literal::CreateR0<float>(5.0f)));
   2276   HloInstruction* reduce_window =
   2277       builder.AddInstruction(HloInstruction::CreateReduceWindow(
   2278           reduce_window_shape, pad, reduce_init_value, window,
   2279           add_computation));
   2280 
   2281   // Build the computation and run the simplifier.
   2282   auto computation = module.AddEntryComputation(builder.Build());
   2283   HloInstruction* root = computation->root_instruction();
   2284   EXPECT_EQ(root, reduce_window);
   2285   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
   2286                                  non_bitcasting_callback());
   2287   ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
   2288 
   2289   // Running simplification again should not result in any further changes.
   2290   ASSERT_FALSE(simplifier.Run(&module).ValueOrDie());
   2291 
   2292   // Verify the result
          // The root must reduce over `operand` directly and keep the original
          // output shape.
   2293   root = computation->root_instruction();
   2294   EXPECT_THAT(root, op::ReduceWindow(operand, op::Constant()));
   2295   EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reduce_window_shape))
   2296       << ShapeUtil::HumanString(root->shape()) << " vs "
   2297       << ShapeUtil::HumanString(reduce_window_shape);
          // Low padding: 10 everywhere, plus the pad's 1 on dimension 1.
   2298   EXPECT_EQ(root->window().dimensions(0).padding_low(), 10);
   2299   EXPECT_EQ(root->window().dimensions(1).padding_low(), 11);
   2300   EXPECT_EQ(root->window().dimensions(2).padding_low(), 10);
   2301   EXPECT_EQ(root->window().dimensions(3).padding_low(), 10);
          // High padding: 100 everywhere, plus the pad's 2 on dimension 3.
   2302   EXPECT_EQ(root->window().dimensions(0).padding_high(), 100);
   2303   EXPECT_EQ(root->window().dimensions(1).padding_high(), 100);
   2304   EXPECT_EQ(root->window().dimensions(2).padding_high(), 100);
   2305   EXPECT_EQ(root->window().dimensions(3).padding_high(), 102);
   2306 }
   2307 
   2308 TEST_F(AlgebraicSimplifierTest, ReversalOfTrivialDimensionsToBitcast) {
          // Reverses only dimensions 2 and 3 of a {448, 2048, 1, 1} array, both of
          // which have size 1, and expects the simplifier to remove the reverse so
          // that the parameter itself becomes the root.
   2309   HloComputation::Builder builder(TestName());
   2310   const Shape shape = ShapeUtil::MakeShape(F32, {448, 2048, 1, 1});
   2311   HloInstruction* a =
   2312       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
   2313   builder.AddInstruction(
   2314       HloInstruction::CreateReverse(shape, a, /*dimensions=*/{2, 3}));
   2315 
   2316   HloModule module(TestName());
   2317   auto computation = module.AddEntryComputation(builder.Build());
   2318 
   2319   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
   2320                                  non_bitcasting_callback());
   2321   ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
   2322 
          // Despite the test name, the check is that the root is the parameter `a`
          // (not a bitcast) with an unchanged shape.
   2323   HloInstruction* root = computation->root_instruction();
   2324   EXPECT_EQ(a, root);
   2325   EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape));
   2326 }
   2327 
   2328 TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) {
   2329   // Dots add computations to the parent module. Test that, when the HloModule's
   2330   // computations are updated, then iterator invalidation doesn't occur
   2331   // when running on subsequent computations.
          //
          // The module is built from two computations: an embedded one containing
          // the dot, and an entry computation that calls it.
   2332   Shape r1f32 = ShapeUtil::MakeShape(F32, {1});
   2333   HloComputation::Builder builder(TestName() + ".Dot");
   2334   HloInstruction* x =
   2335       builder.AddInstruction(HloInstruction::CreateParameter(0, r1f32, "x"));
   2336   HloInstruction* y =
   2337       builder.AddInstruction(HloInstruction::CreateParameter(1, r1f32, "y"));
          // NOTE(review): x is declared with a rank-1 shape, yet the lhs contracting
          // dimension is set to 1; confirm this is intended.
   2338   DotDimensionNumbers dot_dnums;
   2339   dot_dnums.add_lhs_contracting_dimensions(1);
   2340   dot_dnums.add_rhs_contracting_dimensions(0);
   2341   builder.AddInstruction(HloInstruction::CreateDot(r1f32, x, y, dot_dnums));
   2342   std::unique_ptr<HloComputation> dot_computation(builder.Build());
   2343 
          // Entry computation: call the dot computation on two one-element constants.
   2344   HloComputation::Builder call_builder(TestName() + ".Call");
   2345   HloInstruction* zero = call_builder.AddInstruction(
   2346       HloInstruction::CreateConstant(Literal::CreateR1<float>({0.0f})));
   2347   HloInstruction* one = call_builder.AddInstruction(
   2348       HloInstruction::CreateConstant(Literal::CreateR1<float>({1.0f})));
          // The call captures the raw pointer before ownership of dot_computation is
          // moved into the module below.
   2349   call_builder.AddInstruction(
   2350       HloInstruction::CreateCall(r1f32, {zero, one}, dot_computation.get()));
   2351 
          // This test uses the fixture's module() accessor rather than a local
          // HloModule.
   2352   module().AddEmbeddedComputation(std::move(dot_computation));
   2353   module().AddEntryComputation(call_builder.Build());
   2354   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
   2355                                  non_bitcasting_callback());
          // Only the fact that the pass runs and reports a change is checked.
   2356   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
   2357 }
   2358 
   2359 // Test that a constant with tuple shape becomes a tuple of constants.
   2360 TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) {
   2361   HloComputation::Builder builder(TestName());
   2362   const float constant_scalar = 7.3f;
   2363   std::initializer_list<float> constant_vector = {1.1f, 2.0f, 3.3f};
          // A single tuple-shaped literal with two elements: an F32 scalar and a
          // three-element F32 vector.
   2364   std::unique_ptr<Literal> value =
   2365       Literal::MakeTuple({Literal::CreateR0<float>(constant_scalar).get(),
   2366                           Literal::CreateR1<float>(constant_vector).get()});
          // The tuple-shaped constant is the only instruction in the computation.
   2367   builder.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
   2368 
   2369   auto computation = module().AddEntryComputation(builder.Build());
   2370 
   2371   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
   2372                                  non_bitcasting_callback());
   2373   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
          // The root must now be a tuple instruction with two constant operands.
   2374   EXPECT_THAT(computation->root_instruction(),
   2375               op::Tuple(op::Constant(), op::Constant()));
   2376 }
   2377 
   2378 // A dynamic-slice is trivial if its start indices are all zeroes and the size
   2379 // of its input equals the size of its output.  In this case, the dynamic slice
   2380 // is equal to its input.
        // The test expects such a dynamic-slice to be replaced by its operand.
   2381 TEST_F(AlgebraicSimplifierTest, TrivialDynamicSlice) {
   2382   HloComputation::Builder builder(TestName());
   2383 
          // Operand, result and slice_sizes all use {10, 100, 1000}; the start
          // indices are the constant {0, 0, 0}.
   2384   Shape shape = ShapeUtil::MakeShape(F32, {10, 100, 1000});
   2385   builder.AddInstruction(HloInstruction::CreateDynamicSlice(
   2386       shape,
   2387       builder.AddInstruction(
   2388           HloInstruction::CreateParameter(0, shape, "slice_from")),
   2389       builder.AddInstruction(
   2390           HloInstruction::CreateConstant(Literal::CreateR1<int>({0, 0, 0}))),
   2391       /*slice_sizes=*/{10, 100, 1000}));
   2392 
   2393   auto computation = module().AddEntryComputation(builder.Build());
   2394   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
   2395                                  non_bitcasting_callback());
   2396   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
          // The root must now be a parameter instruction.
   2397   EXPECT_THAT(computation->root_instruction(), op::Parameter());
   2398 }
   2399 
   2400 // A dynamic-update-slice is trivial if its start indices are all zeroes and the
   2401 // size of its "update" equals the size of its output.  In this case, the
   2402 // dynamic-update-slice is equal to its update.
        // Here the update is itself a (non-trivial) dynamic-slice, which is expected
        // to become the root once the dynamic-update-slice is removed.
   2403 TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) {
   2404   HloComputation::Builder builder(TestName());
   2405 
   2406   Shape full_shape = ShapeUtil::MakeShape(F32, {10, 100, 1000});
   2407   Shape slice_shape = ShapeUtil::MakeShape(F32, {10, 1, 1000});
   2408 
          // The update: a {10, 1, 1000} dynamic-slice of a {10, 100, 1000} parameter
          // whose start indices come from another parameter.
   2409   HloInstruction* slice =
   2410       builder.AddInstruction(HloInstruction::CreateDynamicSlice(
   2411           slice_shape,
   2412           builder.AddInstruction(
   2413               HloInstruction::CreateParameter(0, full_shape, "slice_from")),
   2414           builder.AddInstruction(HloInstruction::CreateParameter(
   2415               1, ShapeUtil::MakeShape(U32, {3}), "slice_indices")),
   2416           /*slice_sizes=*/{10, 1, 1000}));
   2417 
          // The dynamic-update-slice under test: result, operand ("to_update") and
          // update (`slice`) all have slice_shape, and the start indices are the
          // constant {0, 0, 0}.
   2418   builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
   2419       slice_shape,
   2420       builder.AddInstruction(
   2421           HloInstruction::CreateParameter(2, slice_shape, "to_update")),
   2422       slice,
   2423       builder.AddInstruction(
   2424           HloInstruction::CreateConstant(Literal::CreateR1<int>({0, 0, 0})))));
   2425 
   2426   auto computation = module().AddEntryComputation(builder.Build());
   2427   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
   2428                                  non_bitcasting_callback());
   2429   ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
          // The root must now be a dynamic-slice of two parameters.
   2430   EXPECT_THAT(computation->root_instruction(),
   2431               op::DynamicSlice(op::Parameter(), op::Parameter()));
   2432 }
   2433 
   2434 struct PadReduceWindowEffectiveBroadcastCase {
          // One parameter set for PadReduceWindowEffectiveBroadcastTest. The three
          // vectors give per-spatial-dimension values; the test extends each of them
          // with two further (non-spatial) entries before use.

          // Sizes of the input's spatial dimensions.
   2435   std::vector<int64> input_spatials;
          // Per-spatial-dimension amounts handed to MakeSymmetricPadding.
   2436   std::vector<int64> symmetric_pad_spatials;
          // Per-spatial-dimension amounts handed to MakeWindow.
   2437   std::vector<int64> reduce_window_spatials;
   2438   // Whether to use `B F S0 S1` form vs `B S0 S1 F` form.
   2439   //
   2440   // This doesn't test any different functionality but is useful for making sure
   2441   // kBroadcast nodes are well formed.
   2442   bool prepend_a;
          // Expected outcome: true if the simplified root should be a broadcast,
          // false if it should still be a reduce-window.
   2443   bool should_become_broadcast;
   2444 
          // Renders all fields as one string: the three vectors comma-joined and
          // the two flags, all separated by ';'.
   2445   string ToTestCaseName() const {
   2446     return tensorflow::strings::StrCat(
   2447         tensorflow::str_util::Join(input_spatials, ","), ";",
   2448         tensorflow::str_util::Join(symmetric_pad_spatials, ","), ";",
   2449         tensorflow::str_util::Join(reduce_window_spatials, ","), ";", prepend_a,
   2450         ";", should_become_broadcast);
   2451   }
   2452 };
   2453 
   2454 void PrintTo(const PadReduceWindowEffectiveBroadcastCase& c, std::ostream* os) {
          // Writes the case's ToTestCaseName() string to `os`. Presumably picked up
          // by googletest when printing parameter values -- confirm.
   2455   *os << c.ToTestCaseName();
   2456 }
   2457 
   2458 class PadReduceWindowEffectiveBroadcastTest
        // Value-parameterized fixture: the AlgebraicSimplifierTest fixture plus a
        // PadReduceWindowEffectiveBroadcastCase parameter available via GetParam().
   2459     : public AlgebraicSimplifierTest,
   2460       public ::testing::WithParamInterface<
   2461           PadReduceWindowEffectiveBroadcastCase> {};
   2462 
   2463 TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) {
          // Builds reduce-window(pad(input)) from the test parameter, runs the
          // simplifier, and checks whether the root became a broadcast or stayed a
          // reduce-window, as the parameter's should_become_broadcast flag says.
   2464   const auto& param = GetParam();
   2465 
   2466   // a and b are parallel bounds we can either turn into a B F S0 S1 or
   2467   // `B S0 S1 F` kind of pattern.
          // Concretely: with prepend_a the result is {a, spatials..., b}; otherwise it
          // is {spatials..., a, b}.
   2468   auto decorate_spatials = [&param](tensorflow::gtl::ArraySlice<int64> spatials,
   2469                                     int64 a, int64 b) {
   2470     std::vector<int64> result;
   2471     if (param.prepend_a) {
   2472       result.push_back(a);
   2473     }
   2474     for (int64 s : spatials) {
   2475       result.push_back(s);
   2476     }
   2477     if (!param.prepend_a) {
   2478       result.push_back(a);
   2479     }
   2480     result.push_back(b);
   2481     return result;
   2482   };
   2483 
          // Input parameter: the spatial sizes plus two extra dimensions of size 128
          // and 2048.
   2484   HloComputation::Builder builder(TestName());
   2485   const Shape input_shape = ShapeUtil::MakeShape(
   2486       F32, decorate_spatials(param.input_spatials, 128, 2048));
   2487   HloInstruction* input = builder.AddInstruction(
   2488       HloInstruction::CreateParameter(0, input_shape, "input"));
   2489 
          // Pad only the spatial dimensions (the two extra dimensions get 0), using
          // a 0.0f pad value; the padded shape comes from shape inference.
   2490   PaddingConfig padding = window_util::MakeSymmetricPadding(
   2491       decorate_spatials(param.symmetric_pad_spatials, 0, 0));
   2492   TF_ASSERT_OK_AND_ASSIGN(
   2493       const Shape pad_shape,
   2494       ShapeInference::InferPadShape(input->shape(),
   2495                                     ShapeUtil::MakeShape(F32, {}), padding));
   2496   HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
   2497       pad_shape, input,
   2498       builder.AddInstruction(
   2499           HloInstruction::CreateConstant(Literal::CreateR0(0.0f))),
   2500       padding));
   2501 
          // Scalar add used as the reduction function. The inner `builder` shadows
          // the outer one within this scope.
   2502   HloComputation* add_computation = nullptr;
   2503   {
   2504     HloComputation::Builder builder(TestName() + ".add");
   2505     const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
   2506     HloInstruction* p0 = builder.AddInstruction(
   2507         HloInstruction::CreateParameter(0, scalar_shape, "p0"));
   2508     HloInstruction* p1 = builder.AddInstruction(
   2509         HloInstruction::CreateParameter(1, scalar_shape, "p1"));
   2510     builder.AddInstruction(
   2511         HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
   2512     add_computation = module().AddEmbeddedComputation(builder.Build());
   2513   }
   2514 
          // Reduce-window over the padded value: the parameter's window sizes on the
          // spatial dimensions and 1 on the two extra dimensions, with a 0.0f init
          // value. The output shape again comes from shape inference.
   2515   Window window = window_util::MakeWindow(
   2516       decorate_spatials(param.reduce_window_spatials, 1, 1));
   2517   auto zero = builder.AddInstruction(
   2518       HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
   2519   TF_ASSERT_OK_AND_ASSIGN(const Shape output_shape,
   2520                           ShapeInference::InferReduceWindowShape(
   2521                               pad->shape(), zero->shape(), window,
   2522                               add_computation->ComputeProgramShape()));
   2523   builder.AddInstruction(HloInstruction::CreateReduceWindow(
   2524       output_shape, pad, zero, window, add_computation));
   2525 
   2526   auto computation = module().AddEntryComputation(builder.Build());
   2527   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
   2528                                  non_bitcasting_callback());
          // Run() must succeed and return true for every case, including those that
          // are expected to remain a reduce-window.
   2529   TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module()));
   2530   ASSERT_TRUE(run_successful);
   2531 
          // The root's shape must be unchanged by the rewrite.
   2532   EXPECT_TRUE(
   2533       ShapeUtil::Equal(computation->root_instruction()->shape(), output_shape));
   2534 
          // Either the root became a broadcast (of any operand), or it is still a
          // reduce-window whose init value is `zero`.
   2535   if (param.should_become_broadcast) {
   2536     EXPECT_THAT(computation->root_instruction(), op::Broadcast(::testing::_));
   2537   } else {
   2538     EXPECT_THAT(computation->root_instruction(),
   2539                 op::ReduceWindow(::testing::_, zero));
   2540   }
   2541 }
   2542 
   2543 const std::vector<PadReduceWindowEffectiveBroadcastCase>&
   2544 PadReduceWindowEffectiveBroadcastCases() {
   2545   static auto* cases = new std::vector<PadReduceWindowEffectiveBroadcastCase>{
   2546       {/*input_spatials=*/{1, 1}, /*symmetric_pad_amount=*/{6, 6},
   2547        /*reduce_window_spatials=*/{7, 7}, /*prepend_a=*/true,
   2548        /*should_become_broadcast=*/true},  //
   2549       {/*input_spatials=*/{1, 1}, /*symmetric_pad_amount=*/{6, 6},
   2550        /*reduce_window_spatials=*/{7, 7}, /*prepend_a=*/false,
   2551        /*should_become_broadcast=*/true},  //
   2552       {/*input_spatials=*/{2, 2}, /*symmetric_pad_amount=*/{6, 6},
   2553        /*reduce_window_spatials=*/{7, 7}, /*prepend_a=*/true,
   2554        /*should_become_broadcast=*/false},  //
   2555       {/*input_spatials=*/{1, 1}, /*symmetric_pad_amount=*/{2, 2},
   2556        /*reduce_window_spatials=*/{5, 5}, /*prepend_a=*/true,
   2557        /*should_become_broadcast=*/true},  //
   2558       {/*input_spatials=*/{1, 1}, /*symmetric_pad_amount=*/{2, 2},
   2559        /*reduce_window_spatials=*/{1, 1}, /*prepend_a=*/true,
   2560        /*should_become_broadcast=*/false},  //
   2561       {/*input_spatials=*/{5, 1}, /*symmetric_pad_amount=*/{0, 2},
   2562        /*reduce_window_spatials=*/{2, 5}, /*prepend_a=*/true,
   2563        /*should_become_broadcast=*/false},  //
   2564   };
   2565   return *cases;
   2566 }
   2567 
   2568 INSTANTIATE_TEST_CASE_P(
   2569     PadReduceWindowEffectiveBroadcastInstantiation,
   2570     PadReduceWindowEffectiveBroadcastTest,
   2571     ::testing::ValuesIn(PadReduceWindowEffectiveBroadcastCases()));
   2572 
   2573 class DotStrengthReductionTest
   2574     : public AlgebraicSimplifierTest,
   2575       public ::testing::WithParamInterface<
   2576           ::testing::tuple<int, int, int, bool, bool>> {};
   2577 TEST_P(DotStrengthReductionTest, DotStrengthReduction) {
   2578   int m, k, n;
   2579   bool transpose_lhs, transpose_rhs;
   2580   std::tie(m, k, n, transpose_lhs, transpose_rhs) = GetParam();
   2581 
   2582   Shape dot_shape = ShapeUtil::MakeShape(F32, {m, n});
   2583   Shape lhs_shape = ShapeUtil::MakeShape(F32, {m, k});
   2584   Shape transposed_lhs_shape = ShapeUtil::MakeShape(F32, {k, m});
   2585   Shape rhs_shape = ShapeUtil::MakeShape(F32, {k, n});
   2586   Shape transposed_rhs_shape = ShapeUtil::MakeShape(F32, {n, k});
   2587   HloComputation::Builder builder(TestName());
   2588 
   2589   auto lhs = builder.AddInstruction(HloInstruction::CreateParameter(
   2590       0, transpose_lhs ? transposed_lhs_shape : lhs_shape, "lhs"));
   2591   if (transpose_lhs) {
   2592     lhs = builder.AddInstruction(
   2593         HloInstruction::CreateTranspose(lhs_shape, lhs, {1, 0}));
   2594   }
   2595   auto rhs = builder.AddInstruction(HloInstruction::CreateParameter(
   2596       1, transpose_rhs ? transposed_rhs_shape : rhs_shape, "rhs"));
   2597   if (transpose_rhs) {
   2598     rhs = builder.AddInstruction(
   2599         HloInstruction::CreateTranspose(rhs_shape, rhs, {1, 0}));
   2600   }
   2601   DotDimensionNumbers dot_dnums;
   2602   dot_dnums.add_lhs_contracting_dimensions(1);
   2603   dot_dnums.add_rhs_contracting_dimensions(0);
   2604   builder.AddInstruction(
   2605       HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums));
   2606   auto computation = module().AddEntryComputation(builder.Build());
   2607   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
   2608                                  non_bitcasting_callback());
   2609   TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(&module()));
   2610   const bool dot_should_be_transformed = m == 1 || k == 1 || n == 1;
   2611   const bool computation_should_be_modified =
   2612       dot_should_be_transformed || (transpose_lhs && transpose_rhs);
   2613   EXPECT_EQ(changed, computation_should_be_modified);
   2614   bool has_no_dot = true;
   2615   for (const auto& hlo : computation->instructions()) {
   2616     if (hlo->opcode() == HloOpcode::kDot) {
   2617       has_no_dot = false;
   2618       break;
   2619     }
   2620   }
   2621   EXPECT_EQ(has_no_dot, dot_should_be_transformed);
   2622 }
   2623 
   2624 INSTANTIATE_TEST_CASE_P(
   2625     DotStrengthReductionTestInstantiation, DotStrengthReductionTest,
   2626     ::testing::Combine(::testing::Values(1, 2), ::testing::Values(1, 2),
   2627                        ::testing::Values(1, 2), ::testing::Bool(),
   2628                        ::testing::Bool()));
   2629 
   2630 struct DotOfConcatTestSpec {
   2631   int64 m;
   2632   int64 k;
   2633   int64 n;
   2634 };
   2635 
   2636 class DotOfConcatSimplificationTest
   2637     : public HloVerifiedTestBase,
   2638       public ::testing::WithParamInterface<DotOfConcatTestSpec> {};
   2639 
   2640 // Test that we transform
   2641 //  dot(const, concat(A, B, C))
   2642 // to
   2643 //  add(dot(const_0, A), dot(const_1, B),  dot(const_2, C))
   2644 TEST_P(DotOfConcatSimplificationTest, ConstantLHS) {
   2645   HloComputation::Builder builder(TestName());
   2646 
   2647   DotOfConcatTestSpec spec = GetParam();
   2648 
   2649   ASSERT_GE(spec.k, 3);
   2650 
   2651   int64 k0 = spec.k / 3;
   2652   int64 k1 = spec.k / 3;
   2653   int64 k2 = spec.k - k0 - k1;
   2654 
   2655   Shape lhs_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.k});
   2656   auto* lhs = builder.AddInstruction(
   2657       HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
   2658           /*from=*/10.0, /*to=*/10000.0, /*rows=*/spec.m, /*cols=*/spec.k)));
   2659 
   2660   Shape rhs0_shape = ShapeUtil::MakeShape(F32, {k0, spec.n});
   2661   Shape rhs1_shape = ShapeUtil::MakeShape(F32, {k1, spec.n});
   2662   Shape rhs2_shape = ShapeUtil::MakeShape(F32, {k2, spec.n});
   2663 
   2664   HloInstruction* rhs0 = builder.AddInstruction(
   2665       HloInstruction::CreateParameter(0, rhs0_shape, "rhs0"));
   2666   HloInstruction* rhs1 = builder.AddInstruction(
   2667       HloInstruction::CreateParameter(1, rhs1_shape, "rhs1"));
   2668   HloInstruction* rhs2 = builder.AddInstruction(
   2669       HloInstruction::CreateParameter(2, rhs2_shape, "rhs2"));
   2670 
   2671   Shape rhs_shape = ShapeUtil::MakeShape(F32, {spec.k, spec.n});
   2672   HloInstruction* rhs = builder.AddInstruction(
   2673       HloInstruction::CreateConcatenate(rhs_shape, {rhs0, rhs1, rhs2}, 0));
   2674 
   2675   DotDimensionNumbers dot_dnums;
   2676   dot_dnums.add_lhs_contracting_dimensions(1);
   2677   dot_dnums.add_rhs_contracting_dimensions(0);
   2678 
   2679   Shape dot_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.n});
   2680   builder.AddInstruction(
   2681       HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums));
   2682 
   2683   auto computation = module().AddEntryComputation(builder.Build());
   2684   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
   2685                                  non_bitcasting_callback());
   2686   TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module()));
   2687   ASSERT_TRUE(run_successful);
   2688 
   2689   EXPECT_TRUE(
   2690       ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape));
   2691 
   2692   auto match_dot_0 = op::Dot(op::Slice(op::Constant()), op::Parameter(0));
   2693   auto match_dot_1 = op::Dot(op::Slice(op::Constant()), op::Parameter(1));
   2694   auto match_dot_2 = op::Dot(op::Slice(op::Constant()), op::Parameter(2));
   2695   EXPECT_THAT(computation->root_instruction(),
   2696               op::Add(op::Add(match_dot_0, match_dot_1), match_dot_2));
   2697 }
   2698 
   2699 // Test that we transform
   2700 //  dot(concat(A, B, C), const)
   2701 // to
   2702 //  add(dot(A, const_0), dot(B, const_1),  dot(C, const_2))
   2703 TEST_P(DotOfConcatSimplificationTest, ConstantRHS) {
   2704   HloComputation::Builder builder(TestName());
   2705 
   2706   DotOfConcatTestSpec spec = GetParam();
   2707 
   2708   ASSERT_GE(spec.k, 4);
   2709 
   2710   int64 k0 = spec.k / 4;
   2711   int64 k1 = spec.k / 4;
   2712   int64 k2 = spec.k / 4;
   2713   int64 k3 = spec.k - k0 - k1 - k2;
   2714 
   2715   Shape lhs0_shape = ShapeUtil::MakeShape(F32, {spec.m, k0});
   2716   Shape lhs1_shape = ShapeUtil::MakeShape(F32, {spec.m, k1});
   2717   Shape lhs2_shape = ShapeUtil::MakeShape(F32, {spec.m, k2});
   2718   Shape lhs3_shape = ShapeUtil::MakeShape(F32, {spec.m, k3});
   2719 
   2720   HloInstruction* lhs0 = builder.AddInstruction(
   2721       HloInstruction::CreateParameter(0, lhs0_shape, "lhs0"));
   2722   HloInstruction* lhs1 = builder.AddInstruction(
   2723       HloInstruction::CreateParameter(1, lhs1_shape, "lhs1"));
   2724   HloInstruction* lhs2 = builder.AddInstruction(
   2725       HloInstruction::CreateParameter(2, lhs2_shape, "lhs2"));
   2726   HloInstruction* lhs3 = builder.AddInstruction(
   2727       HloInstruction::CreateParameter(3, lhs3_shape, "lhs3"));
   2728 
   2729   Shape lhs_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.k});
   2730   HloInstruction* lhs =
   2731       builder.AddInstruction(HloInstruction::CreateConcatenate(
   2732           lhs_shape, {lhs0, lhs1, lhs2, lhs3}, 1));
   2733 
   2734   Shape rhs_shape = ShapeUtil::MakeShape(F32, {spec.k, spec.n});
   2735   auto* rhs = builder.AddInstruction(
   2736       HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
   2737           /*from=*/10.0, /*to=*/10000.0, /*rows=*/spec.k, /*cols=*/spec.n)));
   2738 
   2739   DotDimensionNumbers dot_dnums;
   2740   dot_dnums.add_lhs_contracting_dimensions(1);
   2741   dot_dnums.add_rhs_contracting_dimensions(0);
   2742 
   2743   Shape dot_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.n});
   2744   builder.AddInstruction(
   2745       HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums));
   2746 
   2747   auto computation = module().AddEntryComputation(builder.Build());
   2748   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
   2749                                  non_bitcasting_callback());
   2750   TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module()));
   2751   ASSERT_TRUE(run_successful);
   2752   EXPECT_TRUE(
   2753       ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape));
   2754 
   2755   auto match_dot_0 = op::Dot(op::Parameter(0), op::Slice(op::Constant()));
   2756   auto match_dot_1 = op::Dot(op::Parameter(1), op::Slice(op::Constant()));
   2757   auto match_dot_2 = op::Dot(op::Parameter(2), op::Slice(op::Constant()));
   2758   auto match_dot_3 = op::Dot(op::Parameter(3), op::Slice(op::Constant()));
   2759   EXPECT_THAT(computation->root_instruction(),
   2760               op::Add(op::Add(op::Add(match_dot_0, match_dot_1), match_dot_2),
   2761                       match_dot_3));
   2762 }
   2763 
   2764 DotOfConcatTestSpec kDotOfConcatTestSpecs[] = {
   2765     {/*m=*/3, /*k=*/9, /*n=*/3},    //
   2766     {/*m=*/3, /*k=*/20, /*n=*/3},   //
   2767     {/*m=*/1, /*k=*/18, /*n=*/5},   //
   2768     {/*m=*/20, /*k=*/20, /*n=*/1},  //
   2769     {/*m=*/1, /*k=*/16, /*n=*/1},   //
   2770 };
   2771 
   2772 INSTANTIATE_TEST_CASE_P(DotOfConcatSimplificationTestInstantiation,
   2773                         DotOfConcatSimplificationTest,
   2774                         ::testing::ValuesIn(kDotOfConcatTestSpecs));
   2775 }  // namespace
   2776 }  // namespace xla
   2777