Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/reshape_mover.h"
     17 
     18 #include "tensorflow/compiler/xla/layout_util.h"
     19 #include "tensorflow/compiler/xla/literal_util.h"
     20 #include "tensorflow/compiler/xla/ptr_util.h"
     21 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     22 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     23 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
     24 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
     25 #include "tensorflow/compiler/xla/shape_util.h"
     26 #include "tensorflow/compiler/xla/test.h"
     27 #include "tensorflow/compiler/xla/test_helpers.h"
     28 #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
     29 #include "tensorflow/compiler/xla/types.h"
     30 #include "tensorflow/compiler/xla/xla_data.pb.h"
     31 #include "tensorflow/core/lib/strings/str_util.h"
     32 
     33 namespace op = xla::testing::opcode_matchers;
     34 
     35 namespace xla {
     36 namespace {
     37 using ReshapeMoverTest = HloVerifiedTestBase;
     38 
     39 TEST_F(ReshapeMoverTest, ReshapesWithDifferentInputShapesNotMoved) {
     40   HloComputation::Builder builder(TestName());
     41   auto root_shape = ShapeUtil::MakeShape(F32, {8, 7});
     42   auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
     43       0, ShapeUtil::MakeShape(F32, {1, 8, 1, 7}), "param0"));
     44   auto param1 = builder.AddInstruction(HloInstruction::CreateParameter(
     45       1, ShapeUtil::MakeShape(F32, {1, 8, 7, 1}), "param1"));
     46   auto reshape0 =
     47       builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0));
     48   auto reshape1 =
     49       builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param1));
     50   builder.AddInstruction(HloInstruction::CreateBinary(
     51       root_shape, HloOpcode::kAdd, reshape0, reshape1));
     52 
     53   auto computation = module().AddEntryComputation(builder.Build());
     54 
     55   EXPECT_THAT(computation->root_instruction(),
     56               op::Add(op::Reshape(param0), op::Reshape(param1)));
     57 
     58   EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie());
     59 
     60   EXPECT_THAT(computation->root_instruction(),
     61               op::Add(op::Reshape(param0), op::Reshape(param1)));
     62 }
     63 
     64 // For a graph that looks like:
     65 //
     66 // +- reshape0 - rng0
     67 // |
     68 // +- const1
     69 // |
     70 // add
     71 //
     72 // where rng0 has a different shape than reshape0.
     73 //
     74 // Verifies that the reshape is not moved, since rng0 is trivially reshapable
     75 // and therefore there is no nontrivial reshapes to move.
     76 TEST_F(ReshapeMoverTest, 1ConstantAnd1ReshapesOnRngNotMoved) {
     77   HloComputation::Builder builder(TestName());
     78   auto root_shape = ShapeUtil::MakeShape(F32, {8, 7});
     79   auto rng0 = builder.AddInstruction(
     80       HloInstruction::CreateRng(ShapeUtil::MakeShape(F32, {1, 8, 1, 7, 1}),
     81                                 RandomDistribution::RNG_UNIFORM, {}));
     82   auto reshape0 =
     83       builder.AddInstruction(HloInstruction::CreateReshape(root_shape, rng0));
     84 
     85   auto const1 = builder.AddInstruction(
     86       HloInstruction::CreateConstant(Literal::CreateFromShape(root_shape)));
     87 
     88   builder.AddInstruction(HloInstruction::CreateBinary(
     89       root_shape, HloOpcode::kAdd, reshape0, const1));
     90 
     91   auto computation = module().AddEntryComputation(builder.Build());
     92 
     93   EXPECT_THAT(computation->root_instruction(),
     94               op::Add(op::Reshape(rng0), const1));
     95 
     96   EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie());
     97 
     98   EXPECT_THAT(computation->root_instruction(),
     99               op::Add(op::Reshape(rng0), const1));
    100 }
    101 
    102 TEST_F(ReshapeMoverTest, ScalarReshapesNotMoved) {
    103   HloComputation::Builder builder(TestName());
    104   auto root_shape = ShapeUtil::MakeShape(F32, {});
    105   auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
    106       0, ShapeUtil::MakeShape(F32, {1, 1, 1}), "param0"));
    107   auto param1 = builder.AddInstruction(HloInstruction::CreateParameter(
    108       1, ShapeUtil::MakeShape(F32, {1, 1, 1}), "param1"));
    109   auto reshape0 =
    110       builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0));
    111   auto reshape1 =
    112       builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param1));
    113   builder.AddInstruction(HloInstruction::CreateBinary(
    114       root_shape, HloOpcode::kAdd, reshape0, reshape1));
    115 
    116   auto computation = module().AddEntryComputation(builder.Build());
    117 
    118   EXPECT_THAT(computation->root_instruction(),
    119               op::Add(op::Reshape(param0), op::Reshape(param1)));
    120 
    121   EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie());
    122 
    123   EXPECT_THAT(
    124       computation->root_instruction(),
    125       op::Add(op::Reshape(op::Parameter()), op::Reshape(op::Parameter())));
    126 }
    127 
    128 TEST_F(ReshapeMoverTest, EquivalentReshapesMoved) {
    129   HloComputation::Builder builder(TestName());
    130   auto root_shape = ShapeUtil::MakeShape(F32, {8, 7});
    131   auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
    132       0, ShapeUtil::MakeShape(F32, {1, 8, 1, 7}), "param0"));
    133   auto param1 = builder.AddInstruction(HloInstruction::CreateParameter(
    134       1, ShapeUtil::MakeShape(F32, {1, 8, 1, 7}), "param1"));
    135   auto reshape0 =
    136       builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0));
    137   auto reshape1 =
    138       builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param1));
    139   builder.AddInstruction(HloInstruction::CreateBinary(
    140       root_shape, HloOpcode::kAdd, reshape0, reshape1));
    141 
    142   auto computation = module().AddEntryComputation(builder.Build());
    143 
    144   EXPECT_THAT(computation->root_instruction(),
    145               op::Add(op::Reshape(param0), op::Reshape(param1)));
    146   EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie());
    147 
    148   EXPECT_THAT(computation->root_instruction(),
    149               op::Reshape(op::Add(param0, param1)));
    150   EXPECT_EQ(root_shape.DebugString(),
    151             computation->root_instruction()->shape().DebugString());
    152 }
    153 
    154 // For a graph that looks like:
    155 //
    156 // +- reshape2 - param2
    157 // |
    158 // +- reshape1 - param1
    159 // |
    160 // +- constant0
    161 // |
    162 // select
    163 //
    164 // Verifies that the reshape1 and reshape2 sink past select:
    165 //
    166 // +- param2
    167 // |
    168 // +- param1
    169 // |
    170 // +- reshape3(constant0)
    171 // |
    172 // select
    173 // |
    174 // reshape4
    175 TEST_F(ReshapeMoverTest, 1ConstantAnd2ReshapesMoved) {
    176   HloComputation::Builder builder(TestName());
    177   auto root_shape = ShapeUtil::MakeShape(F32, {2, 3});
    178   auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
    179       Literal::CreateR2<bool>({{true, true, false}, {false, false, true}})));
    180 
    181   auto param1 = builder.AddInstruction(HloInstruction::CreateParameter(
    182       0, ShapeUtil::MakeShape(F32, {1, 3, 1, 2}), "param1"));
    183   auto reshape1 =
    184       builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param1));
    185 
    186   auto param2 = builder.AddInstruction(HloInstruction::CreateParameter(
    187       1, ShapeUtil::MakeShape(F32, {1, 3, 1, 2}), "param2"));
    188   auto reshape2 =
    189       builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param2));
    190 
    191   builder.AddInstruction(HloInstruction::CreateTernary(
    192       root_shape, HloOpcode::kSelect, const0, reshape1, reshape2));
    193 
    194   auto computation = module().AddEntryComputation(builder.Build());
    195 
    196   EXPECT_THAT(computation->root_instruction(),
    197               op::Select(const0, reshape1, reshape2));
    198 
    199   EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie());
    200 
    201   EXPECT_THAT(computation->root_instruction(),
    202               op::Reshape(op::Select(op::Reshape(const0), param1, param2)));
    203 
    204   EXPECT_EQ(root_shape.DebugString(),
    205             computation->root_instruction()->shape().DebugString());
    206 }
    207 
    208 // For a graph that looks like:
    209 //
    210 // +- reshape0 - param0
    211 // |
    212 // +- param1
    213 // |
    214 // add
    215 //
    216 // Verifies that the reshape0 does not sink below add, because param1 is not
    217 // trivially reshapable nor is a Reshape/Transpose.
    218 TEST_F(ReshapeMoverTest, 1ParameterAnd1ReshapeNotMoved) {
    219   HloComputation::Builder builder(TestName());
    220   auto root_shape = ShapeUtil::MakeShape(F32, {8, 7});
    221   auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
    222       0, ShapeUtil::MakeShape(F32, {1, 8, 1, 7}), "param0"));
    223   auto reshape0 =
    224       builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0));
    225   auto param1 = builder.AddInstruction(
    226       HloInstruction::CreateParameter(1, root_shape, "param1"));
    227   builder.AddInstruction(HloInstruction::CreateBinary(
    228       root_shape, HloOpcode::kAdd, reshape0, param1));
    229 
    230   auto computation = module().AddEntryComputation(builder.Build());
    231 
    232   EXPECT_THAT(computation->root_instruction(),
    233               op::Add(op::Reshape(param0), param1));
    234   EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie());
    235 
    236   EXPECT_THAT(computation->root_instruction(),
    237               op::Add(op::Reshape(param0), param1));
    238   EXPECT_EQ(root_shape.DebugString(),
    239             computation->root_instruction()->shape().DebugString());
    240 }
    241 
    242 // For a graph that looks like:
    243 //
    244 // +- pred
    245 // |
    246 // +- reshape0 - const0
    247 // |
    248 // +- reshape1 - const1
    249 // |
    250 // select
    251 //
    252 // Verifies that we don't unnecessarily sink reshapes, which are in fact
    253 // trivial reshapes.
    254 TEST_F(ReshapeMoverTest, 2TrivialConstantReshapeNotMoved) {
    255   HloComputation::Builder builder(TestName());
    256   auto root_shape = ShapeUtil::MakeShape(F32, {3, 2});
    257   auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
    258       Literal::CreateR2<float>({{1, 2, 3}, {4, 5, 6}})));
    259   auto reshape0 =
    260       builder.AddInstruction(HloInstruction::CreateReshape(root_shape, const0));
    261 
    262   auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
    263       Literal::CreateR2<float>({{1, 2, 3}, {4, 5, 6}})));
    264   auto reshape1 =
    265       builder.AddInstruction(HloInstruction::CreateReshape(root_shape, const1));
    266 
    267   auto pred = builder.AddInstruction(HloInstruction::CreateParameter(
    268       0, ShapeUtil::MakeShape(PRED, {3, 2}), "pred"));
    269 
    270   builder.AddInstruction(HloInstruction::CreateTernary(
    271       root_shape, HloOpcode::kSelect, pred, reshape0, reshape1));
    272 
    273   auto computation = module().AddEntryComputation(builder.Build());
    274 
    275   EXPECT_THAT(computation->root_instruction(),
    276               op::Select(pred, op::Reshape(const0), op::Reshape(const1)));
    277 
    278   EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie());
    279 
    280   EXPECT_THAT(computation->root_instruction(),
    281               op::Select(pred, op::Reshape(const0), op::Reshape(const1)));
    282   EXPECT_EQ(root_shape.DebugString(),
    283             computation->root_instruction()->shape().DebugString());
    284 }
    285 
    286 // For a graph that looks like:
    287 //
    288 // +- reshape0 - param0
    289 // |
    290 // +- const1
    291 // |
    292 // add
    293 //
    294 // where there is only 1 non-trivial reshape (reshape0), we sink the reshape
    295 // here for canonicalization benefit:
    296 //
    297 // +- param0
    298 // |
    299 // +- reshape1 - const1
    300 // |
    301 // add
    302 // |
    303 // reshape2
    304 //
    305 // (note that reshape1 here is trivial).
    306 TEST_F(ReshapeMoverTest, 1NonTrivialReshapeMoved) {
    307   HloComputation::Builder builder(TestName());
    308   auto root_shape = ShapeUtil::MakeShape(F32, {2, 3});
    309   auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
    310       0, ShapeUtil::MakeShape(F32, {1, 3, 1, 2}), "param0"));
    311   auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
    312       Literal::CreateR2<float>({{1, 2, 3}, {4, 5, 6}})));
    313   auto reshape0 =
    314       builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0));
    315   builder.AddInstruction(HloInstruction::CreateBinary(
    316       root_shape, HloOpcode::kAdd, reshape0, const1));
    317 
    318   auto computation = module().AddEntryComputation(builder.Build());
    319 
    320   EXPECT_THAT(computation->root_instruction(),
    321               op::Add(op::Reshape(param0), const1));
    322 
    323   EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie());
    324 
    325   EXPECT_THAT(computation->root_instruction(),
    326               op::Reshape(op::Add(param0, op::Reshape(const1))));
    327   EXPECT_EQ(root_shape.DebugString(),
    328             computation->root_instruction()->shape().DebugString());
    329 }
    330 
    331 // For a graph that looks like:
    332 //
    333 // +- reshape0 - param0 (shape A)
    334 // |
    335 // +- reshape1 - const1 (shape B)
    336 // |
    337 // add
    338 //
    339 // There is 1 non-trivial reshape (reshape0). It's not clear whether reshape1
    340 // should be trivial or not; conceptually it's trivial, but handling it would
    341 // complicate the rest of our logic.
    342 //
    343 // For now we treat it as non-trivial, so we verify that we don't sink the
    344 // reshapes in this case.
    345 TEST_F(ReshapeMoverTest, 1NonTrivialReshapeWith1ReshapedConstNotMoved) {
    346   HloComputation::Builder builder(TestName());
    347   auto root_shape = ShapeUtil::MakeShape(F32, {1, 1, 3});
    348   auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
    349       0, ShapeUtil::MakeShape(F32, {1, 3}), "param0"));
    350   auto const1 = builder.AddInstruction(
    351       HloInstruction::CreateConstant(Literal::CreateR1<float>({9, 8, 7})));
    352   auto reshape0 =
    353       builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0));
    354   auto reshape1 =
    355       builder.AddInstruction(HloInstruction::CreateReshape(root_shape, const1));
    356 
    357   builder.AddInstruction(HloInstruction::CreateBinary(
    358       root_shape, HloOpcode::kAdd, reshape0, reshape1));
    359 
    360   auto computation = module().AddEntryComputation(builder.Build());
    361 
    362   EXPECT_THAT(computation->root_instruction(),
    363               op::Add(op::Reshape(param0), op::Reshape(const1)));
    364 
    365   EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie());
    366 
    367   EXPECT_THAT(computation->root_instruction(),
    368               op::Add(op::Reshape(param0), op::Reshape(const1)));
    369   EXPECT_EQ(root_shape.DebugString(),
    370             computation->root_instruction()->shape().DebugString());
    371 }
    372 
    373 TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossFusion) {
    374   HloComputation::Builder builder(TestName());
    375   auto root_shape = ShapeUtil::MakeShape(F32, {8, 7});
    376   auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
    377       0, ShapeUtil::MakeShape(F32, {1, 8, 1, 7}), "param0"));
    378   auto param1 = builder.AddInstruction(HloInstruction::CreateParameter(
    379       1, ShapeUtil::MakeShape(F32, {1, 8, 1, 7}), "param1"));
    380   auto reshape0 =
    381       builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0));
    382   auto reshape1 =
    383       builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param1));
    384   auto add = builder.AddInstruction(HloInstruction::CreateBinary(
    385       root_shape, HloOpcode::kAdd, reshape0, reshape1));
    386 
    387   auto computation = module().AddEntryComputation(builder.Build());
    388   computation->CreateFusionInstruction({add},
    389                                        HloInstruction::FusionKind::kLoop);
    390 
    391   EXPECT_THAT(computation->root_instruction(),
    392               op::Fusion(op::Reshape(param0), op::Reshape(param1)));
    393 
    394   EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie());
    395 
    396   EXPECT_THAT(computation->root_instruction(),
    397               op::Reshape(op::Fusion(param0, param1)));
    398   EXPECT_EQ(root_shape.DebugString(),
    399             computation->root_instruction()->shape().DebugString());
    400 }
    401 
    402 TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossSelect) {
    403   HloComputation::Builder builder(TestName());
    404   auto root_shape = ShapeUtil::MakeShape(F32, {8, 7});
    405   auto pred_shape = ShapeUtil::MakeShape(PRED, {8, 7});
    406   auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
    407       0, ShapeUtil::MakeShape(F32, {1, 8, 1, 7}), "param0"));
    408   auto param1 = builder.AddInstruction(HloInstruction::CreateParameter(
    409       1, ShapeUtil::MakeShape(F32, {1, 8, 1, 7}), "param1"));
    410   auto pred = builder.AddInstruction(HloInstruction::CreateParameter(
    411       2, ShapeUtil::MakeShape(PRED, {1, 8, 1, 7}), "pred"));
    412   auto reshape0 =
    413       builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0));
    414   auto reshape1 =
    415       builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param1));
    416   auto reshape_pred =
    417       builder.AddInstruction(HloInstruction::CreateReshape(pred_shape, pred));
    418   builder.AddInstruction(HloInstruction::CreateTernary(
    419       root_shape, HloOpcode::kSelect, reshape_pred, reshape0, reshape1));
    420 
    421   auto computation = module().AddEntryComputation(builder.Build());
    422 
    423   EXPECT_THAT(
    424       computation->root_instruction(),
    425       op::Select(op::Reshape(pred), op::Reshape(param0), op::Reshape(param1)));
    426 
    427   EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie());
    428 
    429   EXPECT_THAT(computation->root_instruction(),
    430               op::Reshape(op::Select(pred, param0, param1)));
    431   EXPECT_EQ(root_shape.DebugString(),
    432             computation->root_instruction()->shape().DebugString());
    433 }
    434 
    435 TEST_F(ReshapeMoverTest, ScalarReshapeNotMovedAcrossSelect) {
    436   HloComputation::Builder builder(TestName());
    437   auto root_shape = ShapeUtil::MakeShape(F32, {});
    438   auto pred_shape = ShapeUtil::MakeShape(PRED, {});
    439   auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
    440       0, ShapeUtil::MakeShape(F32, {}), "param0"));
    441   auto param1 = builder.AddInstruction(HloInstruction::CreateParameter(
    442       1, ShapeUtil::MakeShape(F32, {}), "param1"));
    443   auto pred = builder.AddInstruction(HloInstruction::CreateParameter(
    444       2, ShapeUtil::MakeShape(PRED, {1, 1, 1}), "pred"));
    445   auto reshape_pred =
    446       builder.AddInstruction(HloInstruction::CreateReshape(pred_shape, pred));
    447   auto select = builder.AddInstruction(HloInstruction::CreateTernary(
    448       root_shape, HloOpcode::kSelect, reshape_pred, param0, param1));
    449 
    450   auto computation = module().AddEntryComputation(builder.Build());
    451   EXPECT_THAT(computation->root_instruction(),
    452               op::Select(op::Reshape(pred), param0, param1));
    453 
    454   EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie());
    455 
    456   EXPECT_THAT(computation->root_instruction(),
    457               op::Select(op::Reshape(pred), param0, param1));
    458   EXPECT_EQ(select, computation->root_instruction());
    459 }
    460 
    461 // Tree looks like:
    462 //
    463 // param0 [1,128,1]
    464 //  |
    465 // reshape [128,1]          constant [128,1024]
    466 //   \                         /
    467 //     multiply w/implicit broadcast [128,1024]
    468 //
    469 // The reshape mover would like to sink the reshape below the multiply.
    470 //
    471 // Previously we would attempt to insert a reshape of the constant to [1,128,1]
    472 // (which is unsound, because it has a different number of elements) as
    473 // preparation for sinking the reshape.
    474 //
    475 // To eliminate the unsoundness, we outlaw reshape sinking when one of the
    476 // operands is implicitly broadcast in the elementwise consumer.
    477 //
    478 // TODO(b/37799338) However, it would be possible in this case to do a more
    479 // in-depth analysis to get reshape movement to occur:
    480 //
    481 // 1. Note that the broadcast dimension (logical dimension 1) in the operands
    482 //    would map back to logical dimension 2 in the param0 node.
    483 // 2. Match rank of the constant to the param0 node (by prepending a trivial 1
    484 //    dimension).
    485 // 3. Reshape to [128,1024] at the root.
    486 //
    487 // But this is not currently done.
    488 TEST_F(ReshapeMoverTest, ImplicitlyBroadcastReshapeIsNotMovedBug37787999) {
    489   HloComputation::Builder builder(TestName());
    490   auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
    491       0, ShapeUtil::MakeShape(F32, {1, 128, 1}), "param0"));
    492   auto reshape = builder.AddInstruction(HloInstruction::CreateReshape(
    493       ShapeUtil::MakeShape(F32, {128, 1}), param0));
    494   Array2D<float> a(128, 1024);
    495   auto literal = Literal::CreateR2FromArray2D<float>(a);
    496   auto constant = builder.AddInstruction(
    497       HloInstruction::CreateConstant(std::move(literal)));
    498   auto multiply = builder.AddInstruction(HloInstruction::CreateBinary(
    499       constant->shape(), HloOpcode::kMultiply, constant, reshape));
    500 
    501   auto computation = module().AddEntryComputation(builder.Build());
    502   EXPECT_THAT(computation->root_instruction(),
    503               op::Multiply(op::Constant(), op::Reshape(param0)));
    504 
    505   EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie());
    506 
    507   EXPECT_THAT(computation->root_instruction(),
    508               op::Multiply(op::Constant(), op::Reshape(param0)));
    509   EXPECT_EQ(multiply, computation->root_instruction());
    510 }
    511 
    512 // Tree looks like this:
    513 //
    514 // add1
    515 // |
    516 // +- reshape2 - param2
    517 // |
    518 // +- reshape3 - add0
    519 //               |
    520 //               + reshape0 - param0
    521 //               |
    522 //               + reshape1 - param1
    523 //
    524 // We expect reshape{0,1} AND reshape{2,3} to be lifted.
    525 TEST_F(ReshapeMoverTest, MultiplePasses) {
    526   auto shape1 = ShapeUtil::MakeShape(F32, {1, 8, 1, 7});
    527   auto shape2 = ShapeUtil::MakeShape(F32, {8, 7, 1});
    528   auto shape3 = ShapeUtil::MakeShape(F32, {8, 7});
    529   HloComputation::Builder builder(TestName());
    530   auto param0 = builder.AddInstruction(
    531       HloInstruction::CreateParameter(0, shape1, "param0"));
    532   auto param1 = builder.AddInstruction(
    533       HloInstruction::CreateParameter(1, shape1, "param1"));
    534   auto param2 = builder.AddInstruction(
    535       HloInstruction::CreateParameter(2, shape2, "param2"));
    536   auto reshape0 =
    537       builder.AddInstruction(HloInstruction::CreateReshape(shape2, param0));
    538   auto reshape1 =
    539       builder.AddInstruction(HloInstruction::CreateReshape(shape2, param1));
    540   auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
    541       shape2, HloOpcode::kAdd, reshape0, reshape1));
    542   auto reshape2 =
    543       builder.AddInstruction(HloInstruction::CreateReshape(shape3, param2));
    544   auto reshape3 =
    545       builder.AddInstruction(HloInstruction::CreateReshape(shape3, add0));
    546   builder.AddInstruction(HloInstruction::CreateBinary(shape3, HloOpcode::kAdd,
    547                                                       reshape2, reshape3));
    548 
    549   auto computation = module().AddEntryComputation(builder.Build());
    550 
    551   EXPECT_THAT(
    552       computation->root_instruction(),
    553       op::Add(op::Reshape(param2),
    554               op::Reshape(op::Add(op::Reshape(param0), op::Reshape(param1)))));
    555 
    556   EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie());
    557 
    558   EXPECT_THAT(
    559       computation->root_instruction(),
    560       op::Reshape(op::Add(param2, op::Reshape(op::Add(param0, param1)))));
    561 }
    562 
    563 }  // namespace
    564 }  // namespace xla
    565