/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/layout_assignment.h"

#include <initializer_list>
#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/gtl/array_slice.h"

namespace op = xla::testing::opcode_matchers;

namespace xla {
namespace {

using ::testing::ElementsAre;

class LayoutAssignmentTest : public HloTestBase {
 protected:
  // Runs the LayoutAssignment pass on `module`, constraining the entry
  // computation to `entry_computation_layout`, and expects it to succeed.
  void AssignLayouts(HloModule* module,
                     ComputationLayout* entry_computation_layout) {
    LayoutAssignment layout_assignment(entry_computation_layout);
    EXPECT_IS_OK(layout_assignment.Run(module).status());
  }
};

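// Layouts in these tests are specified as minor-to-major dimension orders, as
// in XLA's Layout proto. For a rank-2 array, {1, 0} makes dimension 1 the
// most minor (row-major storage), while {0, 1} makes dimension 0 the most
// minor (column-major storage).
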
TEST_F(LayoutAssignmentTest, ComputationLayout) {
  // Verify that the layouts of the root and parameter instructions of a
  // computation match the ComputationLayout, for two different layouts.
  std::vector<std::vector<int64>> minor_to_majors = {{0, 1}, {1, 0}};
  for (auto& minor_to_major : minor_to_majors) {
    auto builder = HloComputation::Builder(TestName());
    Shape ashape = ShapeUtil::MakeShape(F32, {42, 12});
    auto param0 = builder.AddInstruction(
        HloInstruction::CreateParameter(0, ashape, "param0"));
    auto param1 = builder.AddInstruction(
        HloInstruction::CreateParameter(1, ashape, "param1"));
    auto add = builder.AddInstruction(
        HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, param0, param1));
    auto module = CreateNewModule();
    HloComputation* computation = module->AddEntryComputation(builder.Build());

    Layout layout = LayoutUtil::MakeLayout(minor_to_major);
    Shape shape(ashape);
    *shape.mutable_layout() = layout;
    const ShapeLayout shape_layout(shape);

    ComputationLayout computation_layout(computation->ComputeProgramShape());
    *computation_layout.mutable_parameter_layout(0) = shape_layout;
    *computation_layout.mutable_parameter_layout(1) = shape_layout;
    *computation_layout.mutable_result_layout() = shape_layout;
    AssignLayouts(module.get(), &computation_layout);
    EXPECT_TRUE(LayoutUtil::Equal(layout, param0->shape().layout()));
    EXPECT_TRUE(LayoutUtil::Equal(layout, param1->shape().layout()));
    EXPECT_TRUE(LayoutUtil::Equal(layout, add->shape().layout()));
  }
}

TEST_F(LayoutAssignmentTest, ComputationLayoutMixedLayout) {
  // Verify that the layouts of the root and parameter instructions of a
  // computation match a ComputationLayout with mixed layouts.
  auto builder = HloComputation::Builder(TestName());
  Shape ashape = ShapeUtil::MakeShape(F32, {42, 12});
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ashape, "param0"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, ashape, "param1"));
  builder.AddInstruction(
      HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, param0, param1));
  auto module = CreateNewModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  Layout row_major_layout = LayoutUtil::MakeLayout({1, 0});
  Shape row_major_shape(ashape);
  *row_major_shape.mutable_layout() = row_major_layout;
  const ShapeLayout row_major(row_major_shape);

  Layout col_major_layout = LayoutUtil::MakeLayout({0, 1});
  Shape col_major_shape(ashape);
  *col_major_shape.mutable_layout() = col_major_layout;
  const ShapeLayout col_major(col_major_shape);

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) = row_major;
  *computation_layout.mutable_parameter_layout(1) = col_major;
  *computation_layout.mutable_result_layout() = row_major;

  AssignLayouts(module.get(), &computation_layout);
  EXPECT_TRUE(LayoutUtil::Equal(row_major_layout, param0->shape().layout()));
  EXPECT_TRUE(LayoutUtil::Equal(col_major_layout, param1->shape().layout()));
  EXPECT_TRUE(LayoutUtil::Equal(
      row_major_layout, computation->root_instruction()->shape().layout()));
}

TEST_F(LayoutAssignmentTest, FusionInstruction) {
  // Verify that the layouts of the fused parameters in a fusion instruction
  // match those of the fusion operands. Other fused instructions should have
  // no layout.
  std::vector<std::vector<int64>> minor_to_majors = {{0, 1}, {1, 0}};
  for (auto& minor_to_major : minor_to_majors) {
    auto builder = HloComputation::Builder(TestName());
    auto constant_literal1 = Literal::CreateR2WithLayout<float>(
        {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout(minor_to_major));
    auto constant_literal2 = Literal::CreateR2WithLayout<float>(
        {{5.0, 6.0}, {7.0, 8.0}}, LayoutUtil::MakeLayout(minor_to_major));
    Shape ashape = constant_literal1->shape();

    auto constant1 = builder.AddInstruction(
        HloInstruction::CreateConstant(std::move(constant_literal1)));
    auto constant2 = builder.AddInstruction(
        HloInstruction::CreateConstant(std::move(constant_literal2)));
    auto add = builder.AddInstruction(HloInstruction::CreateBinary(
        ashape, HloOpcode::kAdd, constant1, constant2));
    auto negate1 = builder.AddInstruction(
        HloInstruction::CreateUnary(ashape, HloOpcode::kNegate, add));
    auto negate2 = builder.AddInstruction(
        HloInstruction::CreateUnary(ashape, HloOpcode::kNegate, negate1));

    auto module = CreateNewModule();
    HloComputation* computation = module->AddEntryComputation(builder.Build());

    auto fusion = computation->CreateFusionInstruction(
        {negate2, negate1, add}, HloInstruction::FusionKind::kLoop);

    Layout layout = LayoutUtil::MakeLayout(minor_to_major);
    Shape shape(ashape);
    *shape.mutable_layout() = layout;
    const ShapeLayout shape_layout(shape);

    ComputationLayout computation_layout(computation->ComputeProgramShape());
    *computation_layout.mutable_result_layout() = shape_layout;

    AssignLayouts(module.get(), &computation_layout);

    EXPECT_TRUE(LayoutUtil::Equal(
        layout, fusion->fused_parameter(0)->shape().layout()));
    EXPECT_TRUE(LayoutUtil::Equal(
        layout, fusion->fused_parameter(1)->shape().layout()));
    EXPECT_TRUE(LayoutUtil::Equal(
        layout, fusion->fused_expression_root()->shape().layout()));

    // Inner fused nodes should not have a layout.
    EXPECT_FALSE(LayoutUtil::HasLayout(
        fusion->fused_expression_root()->operand(0)->shape()));
  }
}

TEST_F(LayoutAssignmentTest, TupleLayout) {
  // Verify that the layouts of a tuple are assigned properly (the element
  // layouts match their sources).
  auto builder = HloComputation::Builder(TestName());
  auto constant0 = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
  auto constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
  auto tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({constant0, constant1}));

  // To avoid having to construct a tuple layout in the ComputationLayout
  // below, make the result of the computation an array.
  auto get_element0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(constant0->shape(), tuple, 0));
  auto negate = builder.AddInstruction(HloInstruction::CreateUnary(
      constant0->shape(), HloOpcode::kNegate, get_element0));

  auto module = CreateNewModule();
  module->AddEntryComputation(builder.Build());

  ComputationLayout computation_layout(
      module->entry_computation()->ComputeProgramShape());

  AssignLayouts(module.get(), &computation_layout);

  EXPECT_TRUE(
      LayoutUtil::LayoutsInShapesEqual(constant0->shape(), constant1->shape()));

  EXPECT_TRUE(LayoutUtil::HasLayout(tuple->shape()));
  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(
      negate->shape(), computation_layout.result_layout().shape()));
  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(
      ShapeUtil::GetTupleElementShape(tuple->shape(), 1), constant1->shape()));
}

TEST_F(LayoutAssignmentTest, TupleSelect) {
  // Verify that the layout of a select with tuple operands is assigned
  // properly.
  auto builder = HloComputation::Builder(TestName());
  auto constant0 = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
  auto constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
  auto tuple0 = builder.AddInstruction(
      HloInstruction::CreateTuple({constant0, constant1}));
  auto tuple1 = builder.AddInstruction(
      HloInstruction::CreateTuple({constant0, constant1}));

  auto pred = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));

  auto select = builder.AddInstruction(HloInstruction::CreateTernary(
      tuple0->shape(), HloOpcode::kSelect, pred, tuple0, tuple1));

  auto module = CreateNewModule();
  module->AddEntryComputation(builder.Build());

  ComputationLayout computation_layout(
      module->entry_computation()->ComputeProgramShape());
  Shape result_shape =
      ShapeUtil::MakeTupleShape({constant0->shape(), constant1->shape()});
  TF_CHECK_OK(computation_layout.mutable_result_layout()->CopyLayoutFromShape(
      result_shape));

  AssignLayouts(module.get(), &computation_layout);

  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(result_shape, select->shape()));
}

TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) {
  // Construct the following computation, which has conflicting layouts for
  // two elements of a tuple that share the same source logical buffer:
  //
  // %constant = Constant(...)
  // %inner_tuple = Tuple(%constant)
  // %nested_tuple = Tuple(%inner_tuple, %inner_tuple)
  //
  // The result layout is row-major for the first element and column-major for
  // the second. This creates a conflict in which the element of the
  // inner_tuple needs to be both row- and column-major. The conflict is
  // resolved by deep-copying the tuple and assigning the layouts of the
  // copied arrays as needed.
  auto builder = HloComputation::Builder(TestName());
  auto constant = builder.AddInstruction(HloInstruction::CreateConstant(
      Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
  auto inner_tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({constant}));
  auto nested_tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({inner_tuple, inner_tuple}));

  auto module = CreateNewModule();
  module->AddEntryComputation(builder.Build());

  ComputationLayout computation_layout(
      module->entry_computation()->ComputeProgramShape());
  Shape result_shape = nested_tuple->shape();
  *ShapeUtil::GetMutableSubshape(&result_shape, /*index=*/{0, 0}) =
      ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0});
  *ShapeUtil::GetMutableSubshape(&result_shape, /*index=*/{1, 0}) =
      ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1});
  TF_CHECK_OK(computation_layout.mutable_result_layout()->CopyLayoutFromShape(
      result_shape));

  AssignLayouts(module.get(), &computation_layout);

  // Layout assignment should have deep-copied the result of the computation to
  // address the layout conflict. This results in several Tuple() and
  // GetTupleElement() instructions. Running algebraic simplification should
  // clean up the code to something like:
  //
  //  %constant = Constant(...) layout={1,0}
  //  %tuple.0 = Tuple(%constant) layout=({1,0})
  //  %copy = Copy(%constant) layout={0,1}  # layout transposed
  //  %tuple.1 = Tuple(%copy) layout=({0,1})
  //  %tuple.2 = Tuple(%tuple.0, %tuple.1) layout=(({1,0}), ({0,1}))
  //
  // Run the simplifier in layout-sensitive mode; the callback disallows
  // replacing instructions with bitcasts.
  EXPECT_TRUE(
      AlgebraicSimplifier(/*is_layout_sensitive=*/true,
                          [](const Shape&, const Shape&) { return false; })
          .Run(module.get())
          .ValueOrDie());
  HloInstruction* root = module->entry_computation()->root_instruction();
  // Verify the layout of the root and the root's operands.
  EXPECT_TRUE(ShapeUtil::Equal(result_shape, root->shape()));
  EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::GetSubshape(result_shape, {0}),
                               root->operand(0)->shape()));
  EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::GetSubshape(result_shape, {1}),
                               root->operand(1)->shape()));

  // Verify the structure of the HLO graph.
  EXPECT_THAT(root,
              op::Tuple(op::Tuple(constant), op::Tuple(op::Copy(constant))));
}

TEST_F(LayoutAssignmentTest, ElementwiseAndReshape) {
  // param -> log -> reshape -> tanh
  auto builder = HloComputation::Builder(TestName());
  Shape ashape = ShapeUtil::MakeShape(F32, {1, 2, 3, 1});
  Shape bshape = ShapeUtil::MakeShape(F32, {3, 1, 2});
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ashape, "param"));
  auto log = builder.AddInstruction(
      HloInstruction::CreateUnary(ashape, HloOpcode::kLog, param));
  auto reshape =
      builder.AddInstruction(HloInstruction::CreateReshape(bshape, log));
  auto tanh = builder.AddInstruction(
      HloInstruction::CreateUnary(bshape, HloOpcode::kTanh, reshape));

  auto module = CreateNewModule();
  HloComputation* computation =
      module->AddEntryComputation(builder.Build(tanh));

  Shape ashape_with_layout(ashape);
  Shape bshape_with_layout(bshape);
  *ashape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 2, 1, 3});
  *bshape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({2, 1, 0});

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(ashape_with_layout);
  *computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout);
  AssignLayouts(module.get(), &computation_layout);

  // A larger position in minor_to_major means the dimension is more major.
  // Check that dimension 1 stays more major than dimension 2 in the layout
  // assigned to the log, as in the parameter layout.
  auto log_minor_to_major =
      AsInt64Slice(log->shape().layout().minor_to_major());
  EXPECT_GT(PositionInContainer(log_minor_to_major, 1),
            PositionInContainer(log_minor_to_major, 2));

  auto reshape_minor_to_major =
      AsInt64Slice(reshape->shape().layout().minor_to_major());
  EXPECT_GT(PositionInContainer(reshape_minor_to_major, 0),
            PositionInContainer(reshape_minor_to_major, 2));
}

// Test whether LayoutAssignment assigns layouts to elementwise operations to
// keep linear indices valid across them, and to transpositions to make them
// bitcasts.
TEST_F(LayoutAssignmentTest, ElementwiseAndTranspose) {
  // param -> log -> transpose -> tanh
  auto builder = HloComputation::Builder(TestName());
  Shape ashape = ShapeUtil::MakeShape(F32, {42, 12});
  Shape bshape = ShapeUtil::MakeShape(F32, {12, 42});
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ashape, "param"));
  auto log = builder.AddInstruction(
      HloInstruction::CreateUnary(ashape, HloOpcode::kLog, param));
  auto transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(bshape, log, {1, 0}));
  auto tanh = builder.AddInstruction(
      HloInstruction::CreateUnary(bshape, HloOpcode::kTanh, transpose));
  auto module = CreateNewModule();
  auto computation = module->AddEntryComputation(builder.Build(tanh));

  Shape ashape_with_layout(ashape);
  Shape bshape_with_layout(bshape);
  *ashape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({1, 0});
  *bshape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 1});

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(ashape_with_layout);
  *computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout);
  AssignLayouts(module.get(), &computation_layout);

  EXPECT_TRUE(
      LayoutUtil::Equal(ashape_with_layout.layout(), log->shape().layout()));
  EXPECT_TRUE(LayoutUtil::Equal(bshape_with_layout.layout(),
                                transpose->shape().layout()));
  EXPECT_TRUE(
      LayoutUtil::Equal(bshape_with_layout.layout(), tanh->shape().layout()));
}

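// In the test below, the result layout is fixed to {2, 1, 0} and the
// transpose permutation is {2, 1, 0}; for the transpose to become a bitcast,
// layout assignment must give the broadcast the reverse layout, {0, 1, 2},
// which is what the test expects.
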
// Test whether LayoutAssignment assigns layouts to transpositions to make them
// bitcasts.
TEST_F(LayoutAssignmentTest, BroadcastAndTranspose) {
  // param -> broadcast -> transpose
  auto builder = HloComputation::Builder(TestName());
  Shape ashape = ShapeUtil::MakeShape(F32, {3, 4});
  Shape bshape = ShapeUtil::MakeShape(F32, {2, 3, 4});
  Shape cshape = ShapeUtil::MakeShape(F32, {4, 3, 2});
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ashape, "param"));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(bshape, param, {1, 2}));
  auto transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(cshape, broadcast, {2, 1, 0}));
  auto module = CreateNewModule();
  HloComputation* computation =
      module->AddEntryComputation(builder.Build(transpose));

  Shape input_shape_with_layout(ashape);
  Shape output_shape_with_layout(cshape);
  *input_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({1, 0});
  *output_shape_with_layout.mutable_layout() =
      LayoutUtil::MakeLayout({2, 1, 0});

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(input_shape_with_layout);
  *computation_layout.mutable_result_layout() =
      ShapeLayout(output_shape_with_layout);
  AssignLayouts(module.get(), &computation_layout);

  EXPECT_THAT(broadcast->shape().layout().minor_to_major(),
              ElementsAre(0, 1, 2));
}

TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) {
  // param[4] -> broadcast[3x4] -> transpose[4x3] --------------------> tuple
  //                            \                                     /
  //                             \-> tanh[3x4] -> broadcast2[2x3x4] -/
  //
  // The layout of `transpose` is set to {1,0} because it provides a buffer to
  // the computation result, which has a fixed layout. Therefore, `broadcast`
  // (the operand of the transpose) is expected to have layout {0,1} so that
  // the transpose is a bitcast. Furthermore, `tanh` is expected to have the
  // same layout as `broadcast` (i.e. {0,1}) because `tanh` is elementwise.
  Shape f32_4 = ShapeUtil::MakeShape(F32, {4});
  Shape f32_34 = ShapeUtil::MakeShape(F32, {3, 4});
  Shape f32_43 = ShapeUtil::MakeShape(F32, {4, 3});
  Shape f32_234 = ShapeUtil::MakeShape(F32, {2, 3, 4});

  auto builder = HloComputation::Builder(TestName());
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_4, "param"));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32_34, param, {3}));
  auto transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(f32_43, broadcast, {1, 0}));
  auto tanh = builder.AddInstruction(
      HloInstruction::CreateUnary(f32_34, HloOpcode::kTanh, broadcast));
  auto broadcast2 = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32_234, tanh, {2}));
  auto tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({transpose, broadcast2}));
  auto module = CreateNewModule();
  HloComputation* computation =
      module->AddEntryComputation(builder.Build(tuple));

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  Shape param_shape_with_layout(f32_4);
  Shape transpose_shape_with_layout(f32_43);
  Shape broadcast2_shape_with_layout(f32_234);
  *param_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0});
  *transpose_shape_with_layout.mutable_layout() =
      LayoutUtil::MakeLayout({1, 0});
  *broadcast2_shape_with_layout.mutable_layout() =
      LayoutUtil::MakeLayout({2, 1, 0});

  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(param_shape_with_layout);
  *computation_layout.mutable_result_layout() =
      ShapeLayout(ShapeUtil::MakeTupleShape(
          {transpose_shape_with_layout, broadcast2_shape_with_layout}));
  AssignLayouts(module.get(), &computation_layout);

  EXPECT_THAT(broadcast->shape().layout().minor_to_major(), ElementsAre(0, 1));
  EXPECT_THAT(transpose->shape().layout().minor_to_major(), ElementsAre(1, 0));
  EXPECT_THAT(tanh->shape().layout().minor_to_major(), ElementsAre(0, 1));
}

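// Test-only LayoutAssignment subclass that forces each operand of an
// instruction (when the operand's rank matches the instruction's) to take the
// same layout as the instruction's output buffer.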
class OperandsMustBeTheSameLayoutAssignment : public LayoutAssignment {
 public:
  explicit OperandsMustBeTheSameLayoutAssignment(
      ComputationLayout* entry_computation_layout)
      : LayoutAssignment(entry_computation_layout) {}

 protected:
  Status PropagateBufferConstraint(
      const BufferLayoutConstraint& buffer_constraint,
      LayoutConstraints* constraints) override {
    const LogicalBuffer& buffer = buffer_constraint.buffer();
    const HloInstruction* instruction = buffer.instruction();

    // Force the operands' layout to the output layout.
    for (int64 operand_no = 0; operand_no < instruction->operand_count();
         ++operand_no) {
      const HloInstruction* operand = instruction->operand(operand_no);
      if (ShapeUtil::Rank(instruction->shape()) !=
          ShapeUtil::Rank(operand->shape())) {
        continue;
      }
      TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
          buffer_constraint.layout(), instruction, operand_no,
          /*mandatory=*/true));
    }
    return PropagateBufferConstraintToUses(buffer_constraint, constraints);
  }
};

TEST_F(LayoutAssignmentTest, MakeOperandsTheSame) {
  // param0 -> concatenate -> reshape
  // param1   -^
  auto builder = HloComputation::Builder(TestName());
  Shape ashape = ShapeUtil::MakeShape(F32, {50, 1});
  Shape bshape = ShapeUtil::MakeShape(F32, {50, 2});
  Shape cshape = ShapeUtil::MakeShape(F32, {100});
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ashape, "param0"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, ashape, "param1"));
  auto concatenate = builder.AddInstruction(
      HloInstruction::CreateConcatenate(bshape, {param0, param1}, 1));
  auto reshape = builder.AddInstruction(
      HloInstruction::CreateReshape(cshape, concatenate));
  auto module = CreateNewModule();
  HloComputation* computation =
      module->AddEntryComputation(builder.Build(reshape));

  Shape param0_shape_with_layout(ashape);
  Shape param1_shape_with_layout(ashape);
  *param0_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 1});
  *param1_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({1, 0});

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(param0_shape_with_layout);
  *computation_layout.mutable_parameter_layout(1) =
      ShapeLayout(param1_shape_with_layout);
  OperandsMustBeTheSameLayoutAssignment layout_assignment(&computation_layout);
  EXPECT_IS_OK(layout_assignment.Run(module.get()).status());

  EXPECT_EQ(HloOpcode::kCopy, concatenate->operand(0)->opcode());
  EXPECT_THAT(concatenate->operand(0)->shape().layout().minor_to_major(),
              ElementsAre(1, 0));
  EXPECT_THAT(concatenate->operand(1)->shape().layout().minor_to_major(),
              ElementsAre(1, 0));
  EXPECT_THAT(concatenate->shape().layout().minor_to_major(),
              ElementsAre(1, 0));
}

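// A transpose is a bitcast when the operand and result layouts describe the
// same physical order of elements in memory, so the result can reuse the
// operand's buffer without moving any data.
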
// Test layout assignment of a transpose into a bitcast based on its operand.
TEST_F(LayoutAssignmentTest, TransposeToBitcastFromOperand) {
  auto builder = HloComputation::Builder(TestName());
  Shape input_shape_with_layout =
      ShapeUtil::MakeShapeWithLayout(F32, {3, 5, 6, 7}, {2, 0, 3, 1});
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape_with_layout, "param"));
  auto transpose = builder.AddInstruction(HloInstruction::CreateTranspose(
      ShapeUtil::MakeShape(F32, {6, 7, 3, 5}), param, {2, 3, 0, 1}));
  auto module = CreateNewModule();
  HloComputation* computation =
      module->AddEntryComputation(builder.Build(transpose));
  ComputationLayout computation_layout(computation->ComputeProgramShape());
  AssignLayouts(module.get(), &computation_layout);
  EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(),
                                            transpose->shape(), {2, 3, 0, 1}));
}

// Test layout assignment of a transpose into a bitcast based on its user.
TEST_F(LayoutAssignmentTest, TransposeToBitcastToUser) {
  auto builder = HloComputation::Builder(TestName());
  Shape input_shape = ShapeUtil::MakeShape(F32, {3, 5, 6, 7});
  auto constant = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(input_shape, constant, {}));
  auto transpose = builder.AddInstruction(HloInstruction::CreateTranspose(
      ShapeUtil::MakeShape(F32, {6, 7, 3, 5}), broadcast, {2, 3, 0, 1}));
  auto module = CreateNewModule();
  HloComputation* computation =
      module->AddEntryComputation(builder.Build(transpose));
  ComputationLayout computation_layout(computation->ComputeProgramShape());
  AssignLayouts(module.get(), &computation_layout);
  EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(),
                                            transpose->shape(), {2, 3, 0, 1}));
}

// A GTE inside a fusion node inherits the layout of its operand (which
// should, if we keep following operands, eventually be a parameter).
TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) {
  const char* module_str = R"(
    HloModule test_module

    fused_computation {
      fparam = (f32[2,2,2], (f32[2,2,2], f32[2,2,2])) parameter(0)
      gte0 = f32[2,2,2] get-tuple-element(fparam), index=0
      gte1 = (f32[2,2,2], f32[2,2,2]) get-tuple-element(fparam), index=1
      gte1a = f32[2,2,2] get-tuple-element(gte1), index=0
      gte1b = f32[2,2,2] get-tuple-element(gte1), index=1
      add = f32[2,2,2] add(gte1a, gte1b)
      ROOT fresult = f32[2,2,2] add(gte0, add)
    }

    ENTRY entry_computation {
      param = (f32[2,2,2], (f32[2,2,2], f32[2,2,2])) parameter(0)
      ROOT fusion =
        f32[2,2,2] fusion(param), kind=kLoop, calls=fused_computation
    }
  )";

  auto module = tools::Parse(module_str).ValueOrDie();
  ComputationLayout computation_layout(
      module->entry_computation()->ComputeProgramShape());
  Shape param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 1, 2}),
       ShapeUtil::MakeTupleShape({
           ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {1, 2, 0}),
           ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {2, 0, 1}),
       })});
  TF_ASSERT_OK(
      computation_layout.mutable_parameter_layout(0)->CopyLayoutFromShape(
          param_shape));
  computation_layout.mutable_result_layout()->ResetLayout(
      LayoutUtil::MakeLayout({2, 1, 0}));
  AssignLayouts(module.get(), &computation_layout);

  HloComputation* fused_computation = *std::find_if(
      module->computations().begin(), module->computations().end(),
      [](const HloComputation* c) { return c->name() == "fused_computation"; });

  // Returns the instruction in the fused computation with the given name.
  auto fused_instr = [&](const string& name) {
    auto it = std::find_if(
        fused_computation->instructions().begin(),
        fused_computation->instructions().end(),
        [&](const HloInstruction* i) { return i->name() == name; });
    CHECK(it != fused_computation->instructions().end());
    return *it;
  };

  EXPECT_THAT(fused_instr("gte0")->shape().layout().minor_to_major(),
              ElementsAre(0, 1, 2));
  EXPECT_THAT(
      fused_instr("gte1")->shape().tuple_shapes(0).layout().minor_to_major(),
      ElementsAre(1, 2, 0));
  EXPECT_THAT(
      fused_instr("gte1")->shape().tuple_shapes(1).layout().minor_to_major(),
      ElementsAre(2, 0, 1));
  EXPECT_THAT(fused_instr("gte1a")->shape().layout().minor_to_major(),
              ElementsAre(1, 2, 0));
  EXPECT_THAT(fused_instr("gte1b")->shape().layout().minor_to_major(),
              ElementsAre(2, 0, 1));
  EXPECT_THAT(fused_instr("fresult")->shape().layout().minor_to_major(),
              ElementsAre(2, 1, 0));
}

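// Verify that the branch computations of a conditional are assigned matching
// result layouts, with a copy inserted in a branch where needed to make the
// layouts agree.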
TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) {
  auto builder = HloComputation::Builder(TestName());
  auto module = CreateNewModule();
  Shape shape = ShapeUtil::MakeShape(F32, {128, 8});
  Shape tshape = ShapeUtil::MakeTupleShape({shape, shape});
  Shape result_tshape = ShapeUtil::MakeTupleShape({shape});

  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "param0"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, shape, "param1"));
  auto pred = builder.AddInstruction(HloInstruction::CreateParameter(
      2, ShapeUtil::MakeShape(PRED, {}), "param2"));
  auto tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({param0, param1}));

  auto true_builder = HloComputation::Builder(TestName() + "_TrueBranch");
  {
    auto param = true_builder.AddInstruction(
        HloInstruction::CreateParameter(0, tshape, "param"));
    auto gte0 = true_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(shape, param, 0));
    auto gte1 = true_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(shape, param, 1));
    auto add = true_builder.AddInstruction(
        HloInstruction::CreateBinary(shape, HloOpcode::kAdd, gte0, gte1));
    true_builder.AddInstruction(HloInstruction::CreateTuple({add}));
  }
  HloComputation* true_computation =
      module->AddEmbeddedComputation(true_builder.Build());

  auto false_builder = HloComputation::Builder(TestName() + "_FalseBranch");
  {
    Shape xshape = ShapeUtil::MakeShapeWithLayout(F32, {128, 8}, {0, 1});
    false_builder.AddInstruction(
        HloInstruction::CreateParameter(0, tshape, "param"));
    // Use an infeed to produce this branch's result: layout assignment does
    // not change the layout of an infeed.
    auto infeed =
        false_builder.AddInstruction(HloInstruction::CreateInfeed(xshape, ""));
    false_builder.AddInstruction(HloInstruction::CreateTuple({infeed}));
  }
  HloComputation* false_computation =
      module->AddEmbeddedComputation(false_builder.Build());
  builder.AddInstruction(HloInstruction::CreateConditional(
      result_tshape, pred, tuple, true_computation, tuple, false_computation));

  HloComputation* computation = module->AddEntryComputation(builder.Build());
  ComputationLayout computation_layout(computation->ComputeProgramShape());

  AssignLayouts(module.get(), &computation_layout);

  const HloInstruction* true_root = true_computation->root_instruction();
  const HloInstruction* false_root = false_computation->root_instruction();
  EXPECT_EQ(HloOpcode::kTuple, true_root->opcode());
  EXPECT_EQ(HloOpcode::kTuple, false_root->opcode());

  const HloInstruction* true_result = true_root->operand(0);
  const HloInstruction* false_result = false_root->operand(0);
  EXPECT_TRUE(LayoutUtil::Equal(true_result->shape().layout(),
                                false_result->shape().layout()));
  EXPECT_EQ(HloOpcode::kCopy, false_result->opcode());
}

}  // namespace
}  // namespace xla