/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/layout_assignment.h"

#include <initializer_list>
#include <memory>
#include <utility>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"

namespace xla {
namespace {

namespace m = xla::match;
using ::testing::ElementsAre;

class LayoutAssignmentTest : public HloTestBase {
 protected:
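  // Runs layout assignment on the given module using
  // `entry_computation_layout` (and, if provided, `channel_constraints` for
  // cross-computation channels) and expects it to succeed.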
  void AssignLayouts(HloModule* m, ComputationLayout* entry_computation_layout,
                     ChannelLayoutConstraints* channel_constraints = nullptr) {
    LayoutAssignment layout_assignment(
        entry_computation_layout, LayoutAssignment::InstructionCanChangeLayout,
        /*channel_constraints=*/channel_constraints);
    EXPECT_IS_OK(layout_assignment.Run(m).status());
  }

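  // Returns the minor_to_major order of the layout assigned to the named
  // instruction in `m`.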
  std::vector<int64> LayoutOf(HloModule* m, absl::string_view name) {
    auto minor_to_major =
        FindInstruction(m, name)->shape().layout().minor_to_major();
    return std::vector<int64>(minor_to_major.begin(), minor_to_major.end());
  }

  void ExpectLayoutIs(const Shape& shape,
                      absl::Span<const int64> minor_to_major) {
    const Layout expected = LayoutUtil::MakeLayout(minor_to_major);
    EXPECT_TRUE(LayoutUtil::Equal(shape.layout(), expected))
        << "Expected layout " << expected << ", actual " << shape.layout();
  }

  void ExpectTupleLayoutIs(
      const Shape& shape,
      std::initializer_list<absl::Span<const int64>> minor_to_majors) {
    int i = 0;
    for (const absl::Span<const int64> minor_to_major : minor_to_majors) {
      const Layout expected = LayoutUtil::MakeLayout(minor_to_major);
      const Layout& actual = ShapeUtil::GetTupleElementShape(shape, i).layout();
      EXPECT_TRUE(LayoutUtil::Equal(actual, expected))
          << "Expected tuple element " << i << " layout " << expected
          << ", actual " << actual;
      ++i;
    }
  }
};
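
// A note on notation: a Layout's minor_to_major list names dimensions from
// fastest-varying to slowest-varying. For example (an illustrative sketch,
// not used by the tests below):
//
//   Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0});
//   // Dimension 1 is minor-most, so element (i, j) lives at linear
//   // index i * 3 + j.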

TEST_F(LayoutAssignmentTest, ComputationLayout) {
  // Verify the layouts of the root and parameter instructions of a computation
  // match the ComputationLayout for two different layouts.
  std::vector<std::vector<int64>> minor_to_majors = {{0, 1}, {1, 0}};
  for (auto& minor_to_major : minor_to_majors) {
    auto builder = HloComputation::Builder(TestName());
    Shape ashape = ShapeUtil::MakeShape(F32, {42, 12});
    auto param0 = builder.AddInstruction(
        HloInstruction::CreateParameter(0, ashape, "param0"));
    auto param1 = builder.AddInstruction(
        HloInstruction::CreateParameter(1, ashape, "param1"));
    auto add = builder.AddInstruction(
        HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, param0, param1));
    auto m = CreateNewVerifiedModule();
    HloComputation* computation = m->AddEntryComputation(builder.Build());

    Layout layout = LayoutUtil::MakeLayout(minor_to_major);
    Shape shape(ashape);
    *shape.mutable_layout() = layout;
    const ShapeLayout shape_layout(shape);

    ComputationLayout computation_layout(computation->ComputeProgramShape());
    *computation_layout.mutable_parameter_layout(0) = shape_layout;
    *computation_layout.mutable_parameter_layout(1) = shape_layout;
    *computation_layout.mutable_result_layout() = shape_layout;
    AssignLayouts(m.get(), &computation_layout);
    EXPECT_TRUE(LayoutUtil::Equal(layout, param0->shape().layout()));
    EXPECT_TRUE(LayoutUtil::Equal(layout, param1->shape().layout()));
    EXPECT_TRUE(LayoutUtil::Equal(layout, add->shape().layout()));
  }
}

TEST_F(LayoutAssignmentTest, ComputationLayoutMixedLayout) {
  // Verify that the layouts of the root and parameter instructions of a
  // computation match a ComputationLayout with mixed layouts.
  auto builder = HloComputation::Builder(TestName());
  Shape ashape = ShapeUtil::MakeShape(F32, {42, 12});
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ashape, "param0"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, ashape, "param1"));
  builder.AddInstruction(
      HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, param0, param1));
  auto m = CreateNewVerifiedModule();
  HloComputation* computation = m->AddEntryComputation(builder.Build());

  Layout col_major_layout = LayoutUtil::MakeLayout({1, 0});
  Shape col_major_shape(ashape);
  *col_major_shape.mutable_layout() = col_major_layout;
  const ShapeLayout col_major(col_major_shape);

  Layout row_major_layout = LayoutUtil::MakeLayout({0, 1});
  Shape row_major_shape(ashape);
  *row_major_shape.mutable_layout() = row_major_layout;
  const ShapeLayout row_major(row_major_shape);

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) = col_major;
  *computation_layout.mutable_parameter_layout(1) = row_major;
  *computation_layout.mutable_result_layout() = col_major;

  AssignLayouts(m.get(), &computation_layout);
  EXPECT_TRUE(LayoutUtil::Equal(col_major_layout, param0->shape().layout()));
  EXPECT_TRUE(LayoutUtil::Equal(row_major_layout, param1->shape().layout()));
  EXPECT_TRUE(LayoutUtil::Equal(
      col_major_layout, computation->root_instruction()->shape().layout()));
}

TEST_F(LayoutAssignmentTest, FusionInstruction) {
  // Verify that the layouts of the fused parameters in a fusion instruction
  // match those of the fusion operands. Other fused instructions should have
  // no layout.
  std::vector<std::vector<int64>> minor_to_majors = {{0, 1}, {1, 0}};
  for (auto& minor_to_major : minor_to_majors) {
    auto builder = HloComputation::Builder(TestName());
    auto constant_literal1 = LiteralUtil::CreateR2WithLayout<float>(
        {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout(minor_to_major));
    auto constant_literal2 = LiteralUtil::CreateR2WithLayout<float>(
        {{5.0, 6.0}, {7.0, 8.0}}, LayoutUtil::MakeLayout(minor_to_major));
    Shape ashape = constant_literal1.shape();

    auto constant1 = builder.AddInstruction(
        HloInstruction::CreateConstant(std::move(constant_literal1)));
    auto constant2 = builder.AddInstruction(
        HloInstruction::CreateConstant(std::move(constant_literal2)));
    auto add = builder.AddInstruction(HloInstruction::CreateBinary(
        ashape, HloOpcode::kAdd, constant1, constant2));
    auto negate1 = builder.AddInstruction(
        HloInstruction::CreateUnary(ashape, HloOpcode::kNegate, add));
    auto negate2 = builder.AddInstruction(
        HloInstruction::CreateUnary(ashape, HloOpcode::kNegate, negate1));

    auto m = CreateNewVerifiedModule();
    HloComputation* computation = m->AddEntryComputation(builder.Build());

    auto fusion = computation->CreateFusionInstruction(
        {negate2, negate1, add}, HloInstruction::FusionKind::kLoop);

    Layout layout = LayoutUtil::MakeLayout(minor_to_major);
    Shape shape(ashape);
    *shape.mutable_layout() = layout;
    const ShapeLayout shape_layout(shape);

    ComputationLayout computation_layout(computation->ComputeProgramShape());
    *computation_layout.mutable_result_layout() = shape_layout;

    AssignLayouts(m.get(), &computation_layout);

    EXPECT_TRUE(LayoutUtil::Equal(
        layout, fusion->fused_parameter(0)->shape().layout()));
    EXPECT_TRUE(LayoutUtil::Equal(
        layout, fusion->fused_parameter(1)->shape().layout()));
    EXPECT_TRUE(LayoutUtil::Equal(
        layout, fusion->fused_expression_root()->shape().layout()));

    // The inner fused node should not have a layout.
    EXPECT_FALSE(LayoutUtil::HasLayout(
        fusion->fused_expression_root()->operand(0)->shape()));
  }
}

TEST_F(LayoutAssignmentTest, TupleLayout) {
  // Verify the layouts of a tuple are assigned properly (the element layouts
  // match their source).
  auto builder = HloComputation::Builder(TestName());
  auto constant0 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
  auto constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
  auto tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({constant0, constant1}));

  // To avoid having to construct a tuple layout in the ComputationLayout
  // below, make the result of the computation an array.
  auto get_element0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(constant0->shape(), tuple, 0));
  auto negate = builder.AddInstruction(HloInstruction::CreateUnary(
      constant0->shape(), HloOpcode::kNegate, get_element0));

  auto m = CreateNewVerifiedModule();
  m->AddEntryComputation(builder.Build());

  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());

  AssignLayouts(m.get(), &computation_layout);

  EXPECT_TRUE(
      LayoutUtil::LayoutsInShapesEqual(constant0->shape(), constant1->shape()));

  EXPECT_TRUE(LayoutUtil::HasLayout(tuple->shape()));
  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(
      negate->shape(), computation_layout.result_layout().shape()));
  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(
      ShapeUtil::GetTupleElementShape(tuple->shape(), 1), constant1->shape()));
}

TEST_F(LayoutAssignmentTest, TupleSelect) {
  // Verify that the layout of a select with tuple operands is assigned
  // properly.
  auto builder = HloComputation::Builder(TestName());
  auto constant0 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
  auto constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
  auto tuple0 = builder.AddInstruction(
      HloInstruction::CreateTuple({constant0, constant1}));
  auto tuple1 = builder.AddInstruction(
      HloInstruction::CreateTuple({constant0, constant1}));

  auto pred = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));

  auto select = builder.AddInstruction(HloInstruction::CreateTernary(
      tuple0->shape(), HloOpcode::kTupleSelect, pred, tuple0, tuple1));

  auto m = CreateNewVerifiedModule();
  m->AddEntryComputation(builder.Build());

  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());
  Shape result_shape =
      ShapeUtil::MakeTupleShape({constant0->shape(), constant1->shape()});
  TF_CHECK_OK(computation_layout.mutable_result_layout()->CopyLayoutFromShape(
      result_shape));

  AssignLayouts(m.get(), &computation_layout);

  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(result_shape, select->shape()));
}

TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) {
  // Construct the following computation, which has conflicting layouts for
  // two elements of a tuple that share the same source logical buffer:
  //
  // %constant = Constant(...)
  // %inner_tuple = Tuple(%constant)
  // %nested_tuple = Tuple(%inner_tuple, %inner_tuple)
  //
  // The result layout is col-major for the first element and row-major for
  // the second. This creates a conflict: the element of the inner_tuple would
  // need to be both col-major and row-major. Layout assignment resolves the
  // conflict by deep-copying the tuple and assigning the layouts of the
  // copied arrays as needed.
  auto builder = HloComputation::Builder(TestName());
  auto constant = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
  auto inner_tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({constant}));
  auto nested_tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({inner_tuple, inner_tuple}));

  auto m = CreateNewVerifiedModule();
  m->AddEntryComputation(builder.Build());

  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());
  Shape result_shape = nested_tuple->shape();
  *ShapeUtil::GetMutableSubshape(&result_shape, /*index=*/{0, 0}) =
      ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0});
  *ShapeUtil::GetMutableSubshape(&result_shape, /*index=*/{1, 0}) =
      ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1});
  TF_CHECK_OK(computation_layout.mutable_result_layout()->CopyLayoutFromShape(
      result_shape));

  AssignLayouts(m.get(), &computation_layout);

  // Layout assignment should have deep-copied the result of the computation to
  // address the layout conflict. This results in several Tuple() and
  // GetTupleElement() instructions. Running algebraic simplification should
  // clean up the code to something like:
  //
  //  %constant = Constant(...) layout={1,0}
  //  %tuple.0 = Tuple(%constant) layout=({1,0})
  //  %copy = Copy(%constant) layout={0,1}  # layout transposed
  //  %tuple.1 = Tuple(%copy) layout=({0,1})
  //  %tuple.2 = Tuple(%tuple.0, %tuple.1) layout=(({1,0}), ({0,1}))
  //
  AlgebraicSimplifierOptions options(
      [](const Shape&, const Shape&) { return false; });
  options.set_is_layout_sensitive(true);
  EXPECT_TRUE(AlgebraicSimplifier(options).Run(m.get()).ValueOrDie());
  HloInstruction* root = m->entry_computation()->root_instruction();
  // Verify layout of the root and the root's operands.
  EXPECT_TRUE(ShapeUtil::Equal(result_shape, root->shape()));
  EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::GetSubshape(result_shape, {0}),
                               root->operand(0)->shape()));
  EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::GetSubshape(result_shape, {1}),
                               root->operand(1)->shape()));

  // Verify the structure of the HLO graph.
  EXPECT_THAT(root,
              GmockMatch(m::Tuple(m::Tuple(m::Op().Is(constant)),
                                  m::Tuple(m::Copy(m::Op().Is(constant))))));
}

TEST_F(LayoutAssignmentTest, ElementwiseAndReshape) {
  // param -> log -> reshape -> tanh
  auto builder = HloComputation::Builder(TestName());
  Shape ashape = ShapeUtil::MakeShape(F32, {1, 2, 3, 1});
  Shape bshape = ShapeUtil::MakeShape(F32, {3, 1, 2});
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ashape, "param"));
  auto log = builder.AddInstruction(
      HloInstruction::CreateUnary(ashape, HloOpcode::kLog, param));
  auto reshape =
      builder.AddInstruction(HloInstruction::CreateReshape(bshape, log));
  auto tanh = builder.AddInstruction(
      HloInstruction::CreateUnary(bshape, HloOpcode::kTanh, reshape));

  auto m = CreateNewVerifiedModule();
  HloComputation* computation = m->AddEntryComputation(builder.Build(tanh));

  Shape ashape_with_layout(ashape);
  Shape bshape_with_layout(bshape);
  *ashape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 2, 1, 3});
  *bshape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({2, 1, 0});

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(ashape_with_layout);
  *computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout);
  AssignLayouts(m.get(), &computation_layout);

  auto log_minor_to_major =
      AsInt64Slice(log->shape().layout().minor_to_major());
  EXPECT_GT(PositionInContainer(log_minor_to_major, 1),
            PositionInContainer(log_minor_to_major, 2));

  auto reshape_minor_to_major =
      AsInt64Slice(reshape->shape().layout().minor_to_major());
  EXPECT_GT(PositionInContainer(reshape_minor_to_major, 0),
            PositionInContainer(reshape_minor_to_major, 2));
}

// Test whether LayoutAssignment assigns layouts to elementwise operations to
// keep linear indices valid across them, and to transpositions to make them
// bitcasts.
TEST_F(LayoutAssignmentTest, ElementwiseAndTranspose) {
  // param -> log -> transpose -> tanh
  auto builder = HloComputation::Builder(TestName());
  Shape ashape = ShapeUtil::MakeShape(F32, {42, 12});
  Shape bshape = ShapeUtil::MakeShape(F32, {12, 42});
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ashape, "param"));
  auto log = builder.AddInstruction(
      HloInstruction::CreateUnary(ashape, HloOpcode::kLog, param));
  auto transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(bshape, log, {1, 0}));
  auto tanh = builder.AddInstruction(
      HloInstruction::CreateUnary(bshape, HloOpcode::kTanh, transpose));
  auto m = CreateNewVerifiedModule();
  auto computation = m->AddEntryComputation(builder.Build(tanh));

  Shape ashape_with_layout(ashape);
  Shape bshape_with_layout(bshape);
  *ashape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({1, 0});
  *bshape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 1});

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(ashape_with_layout);
  *computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout);
  AssignLayouts(m.get(), &computation_layout);

  EXPECT_TRUE(
      LayoutUtil::Equal(ashape_with_layout.layout(), log->shape().layout()));
  EXPECT_TRUE(LayoutUtil::Equal(bshape_with_layout.layout(),
                                transpose->shape().layout()));
  EXPECT_TRUE(
      LayoutUtil::Equal(bshape_with_layout.layout(), tanh->shape().layout()));
}
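
// Why the transpose in the test above can become a bitcast: with the {42,12}
// operand laid out {1,0}, element (i, j) sits at linear index i * 12 + j;
// with the {12,42} result laid out {0,1}, element (j, i) sits at j + 12 * i,
// the same offset. The assigned layouts preserve the physical memory order,
// so no data movement is needed.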

// Test whether LayoutAssignment assigns layouts to transpositions to make them
// bitcasts.
TEST_F(LayoutAssignmentTest, BroadcastAndTranspose) {
  // param -> broadcast -> transpose
  auto builder = HloComputation::Builder(TestName());
  Shape ashape = ShapeUtil::MakeShape(F32, {3, 4});
  Shape bshape = ShapeUtil::MakeShape(F32, {2, 3, 4});
  Shape cshape = ShapeUtil::MakeShape(F32, {4, 3, 2});
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ashape, "param"));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(bshape, param, {1, 2}));
  auto transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(cshape, broadcast, {2, 1, 0}));
  auto m = CreateNewVerifiedModule();
  HloComputation* computation =
      m->AddEntryComputation(builder.Build(transpose));

  Shape input_shape_with_layout(ashape);
  Shape output_shape_with_layout(cshape);
  *input_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({1, 0});
  *output_shape_with_layout.mutable_layout() =
      LayoutUtil::MakeLayout({2, 1, 0});

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(input_shape_with_layout);
  *computation_layout.mutable_result_layout() =
      ShapeLayout(output_shape_with_layout);
  AssignLayouts(m.get(), &computation_layout);

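  // The result layout pins the transpose's output to {2,1,0}. For the
  // transpose (permutation {2,1,0}) to be a bitcast, the broadcast must keep
  // the same physical memory order, which works out to minor_to_major
  // {0,1,2}.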
  EXPECT_THAT(broadcast->shape().layout().minor_to_major(),
              ElementsAre(0, 1, 2));
}

TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) {
  // param[4] -> broadcast[3x4] ------> transpose[4x3] ----------------> tuple
  //                            \                                      /
  //                             \-> tanh[3x4] -> broadcast2[2x3x4] --/
  //
  // The layout of `transpose` is set to {1,0} because it provides a buffer to
  // the computation result, which has a fixed layout. Therefore, `broadcast`
  // (the operand of the transpose) is expected to have layout {0,1} so that
  // the transpose is a bitcast. Furthermore, `tanh` is expected to have the
  // same layout as `broadcast` (i.e. {0,1}) because `tanh` is elementwise.
  Shape f32_4 = ShapeUtil::MakeShape(F32, {4});
  Shape f32_34 = ShapeUtil::MakeShape(F32, {3, 4});
  Shape f32_43 = ShapeUtil::MakeShape(F32, {4, 3});
  Shape f32_234 = ShapeUtil::MakeShape(F32, {2, 3, 4});

  auto builder = HloComputation::Builder(TestName());
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_4, "param"));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32_34, param, {1}));
  auto transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(f32_43, broadcast, {1, 0}));
  auto tanh = builder.AddInstruction(
      HloInstruction::CreateUnary(f32_34, HloOpcode::kTanh, broadcast));
  auto broadcast2 = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32_234, tanh, {1, 2}));
  auto tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({transpose, broadcast2}));
  auto m = CreateNewVerifiedModule();
  HloComputation* computation = m->AddEntryComputation(builder.Build(tuple));

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  Shape param_shape_with_layout(f32_4);
  Shape transpose_shape_with_layout(f32_43);
  Shape broadcast2_shape_with_layout(f32_234);
  *param_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0});
  *transpose_shape_with_layout.mutable_layout() =
      LayoutUtil::MakeLayout({1, 0});
  *broadcast2_shape_with_layout.mutable_layout() =
      LayoutUtil::MakeLayout({2, 1, 0});

  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(param_shape_with_layout);
  *computation_layout.mutable_result_layout() =
      ShapeLayout(ShapeUtil::MakeTupleShape(
          {transpose_shape_with_layout, broadcast2_shape_with_layout}));
  AssignLayouts(m.get(), &computation_layout);

  EXPECT_THAT(broadcast->shape().layout().minor_to_major(), ElementsAre(0, 1));
  EXPECT_THAT(transpose->shape().layout().minor_to_major(), ElementsAre(1, 0));
  EXPECT_THAT(tanh->shape().layout().minor_to_major(), ElementsAre(0, 1));
}

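// A LayoutAssignment subclass that forces each operand of an instruction
// (when the ranks match) to use the layout of the instruction's output
// buffer. The MakeOperandsTheSame test below uses it to verify that layout
// assignment inserts copies when such mandatory operand constraints conflict
// with pinned parameter layouts.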
class OperandsMustBeTheSameLayoutAssignment : public LayoutAssignment {
 public:
  explicit OperandsMustBeTheSameLayoutAssignment(
      ComputationLayout* entry_computation_layout)
      : LayoutAssignment(entry_computation_layout) {}

 protected:
  Status PropagateBufferConstraint(
      const BufferLayoutConstraint& buffer_constraint,
      LayoutConstraints* constraints) override {
    const LogicalBuffer& buffer = buffer_constraint.buffer();
    const HloInstruction* instruction = buffer.instruction();

    // Force the operands' layout to the output layout.
    for (int64 operand_no = 0; operand_no < instruction->operand_count();
         ++operand_no) {
      const HloInstruction* operand = instruction->operand(operand_no);
      if (instruction->shape().rank() != operand->shape().rank()) {
        continue;
      }
      TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
          buffer_constraint.layout(), instruction, operand_no,
          /*mandatory=*/true));
    }
    return PropagateBufferConstraintToUses(buffer_constraint, constraints);
  }
};

TEST_F(LayoutAssignmentTest, MakeOperandsTheSame) {
  // param0 -> concatenate -> reshape
  // param1   -^
  auto builder = HloComputation::Builder(TestName());
  Shape ashape = ShapeUtil::MakeShape(F32, {50, 1});
  Shape bshape = ShapeUtil::MakeShape(F32, {50, 2});
  Shape cshape = ShapeUtil::MakeShape(F32, {100});
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ashape, "param"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, ashape, "param"));
  auto concatenate = builder.AddInstruction(
      HloInstruction::CreateConcatenate(bshape, {param0, param1}, 1));
  auto reshape = builder.AddInstruction(
      HloInstruction::CreateReshape(cshape, concatenate));
  auto m = CreateNewVerifiedModule();
  HloComputation* computation = m->AddEntryComputation(builder.Build(reshape));

  Shape param0_shape_with_layout(ashape);
  Shape param1_shape_with_layout(ashape);
  *param0_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 1});
  *param1_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({1, 0});

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(param0_shape_with_layout);
  *computation_layout.mutable_parameter_layout(1) =
      ShapeLayout(param1_shape_with_layout);
  OperandsMustBeTheSameLayoutAssignment layout_assignment(&computation_layout);
  EXPECT_IS_OK(layout_assignment.Run(m.get()).status());

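  // param0 is pinned to layout {0,1}, but the concatenate's operands are
  // forced to match its {1,0} output layout, so a copy of param0 must be
  // inserted to satisfy both constraints.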
  EXPECT_EQ(HloOpcode::kCopy, concatenate->operand(0)->opcode());
  EXPECT_THAT(concatenate->operand(0)->shape().layout().minor_to_major(),
              ElementsAre(1, 0));
  EXPECT_THAT(concatenate->operand(1)->shape().layout().minor_to_major(),
              ElementsAre(1, 0));
  EXPECT_THAT(concatenate->shape().layout().minor_to_major(),
              ElementsAre(1, 0));
}

// Test layout assignment of a transpose into a bitcast based on its operand.
TEST_F(LayoutAssignmentTest, TransposeToBitcastFromOperand) {
  auto builder = HloComputation::Builder(TestName());
  Shape input_shape_with_layout =
      ShapeUtil::MakeShapeWithLayout(F32, {3, 5, 6, 7}, {2, 0, 3, 1});
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape_with_layout, "param"));
  auto transpose = builder.AddInstruction(HloInstruction::CreateTranspose(
      ShapeUtil::MakeShape(F32, {6, 7, 3, 5}), param, {2, 3, 0, 1}));
  auto m = CreateNewVerifiedModule();
  HloComputation* computation =
      m->AddEntryComputation(builder.Build(transpose));
  ComputationLayout computation_layout(computation->ComputeProgramShape());
  AssignLayouts(m.get(), &computation_layout);
  EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(),
                                            transpose->shape(), {2, 3, 0, 1}));
}

// Test layout assignment of a transpose into a bitcast based on its user.
TEST_F(LayoutAssignmentTest, TransposeToBitcastToUser) {
  auto builder = HloComputation::Builder(TestName());
  Shape input_shape = ShapeUtil::MakeShape(F32, {3, 5, 6, 7});
  auto constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(input_shape, constant, {}));
  auto transpose = builder.AddInstruction(HloInstruction::CreateTranspose(
      ShapeUtil::MakeShape(F32, {6, 7, 3, 5}), broadcast, {2, 3, 0, 1}));
  auto m = CreateNewVerifiedModule();
  HloComputation* computation =
      m->AddEntryComputation(builder.Build(transpose));
  ComputationLayout computation_layout(computation->ComputeProgramShape());
  AssignLayouts(m.get(), &computation_layout);
  EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(),
                                            transpose->shape(), {2, 3, 0, 1}));
}

// TransposeIsBitcast shouldn't be called without layout information.
TEST_F(LayoutAssignmentTest, TransposeIsBitcastFail) {
  auto builder = HloComputation::Builder(TestName());
  Shape input_shape = ShapeUtil::MakeShape(F32, {2, 2, 2});
  Shape input_shape_with_layout(input_shape);
  *input_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({2, 1, 0});
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape_with_layout, "param"));
  auto hlo = builder.AddInstruction(
      HloInstruction::CreateTranspose(input_shape, param, {0, 2, 1}));
  // Clear the default layout assigned to the instruction.
  LayoutUtil::ClearLayout(hlo->mutable_shape());
  EXPECT_DEATH(ShapeUtil::TransposeIsBitcast(hlo->operand(0)->shape(),
                                             hlo->shape(), hlo->dimensions()),
               "LayoutUtil::HasLayout");
}

// ReshapeIsBitcast shouldn't be called without layout information.
TEST_F(LayoutAssignmentTest, ReshapeIsBitcastFail) {
  auto builder = HloComputation::Builder(TestName());
  Shape input_shape = ShapeUtil::MakeShape(F32, {2, 2, 2});
  Shape input_shape_with_layout(input_shape);
  *input_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({2, 1, 0});
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape_with_layout, "param"));
  auto hlo =
      builder.AddInstruction(HloInstruction::CreateReshape(input_shape, param));
  // Clear the default layout assigned to the instruction.
  LayoutUtil::ClearLayout(hlo->mutable_shape());
  EXPECT_DEATH(
      ShapeUtil::ReshapeIsBitcast(hlo->operand(0)->shape(), hlo->shape()),
      "LayoutUtil::HasLayout");
}

// Check that the computation below doesn't crash the compiler.
//
// Within a fusion computation, only the parameters and result get assigned a
// layout. When we run the algebraic simplifier on this computation post
// layout assignment, it should not call TransposeIsBitcast on the `transpose`
// node inside the fusion computation, as TransposeIsBitcast checks that both
// input_shape and output_shape have layouts.
TEST_F(LayoutAssignmentTest, TransposeWithinFusionDoesNotCrash) {
  const char* module_str = R"(
    HloModule test_module

    fused_computation {
      param_1 = f32[2,2,2]{2,1,0} parameter(1)
      transpose = f32[2,2,2]{2,1,0} transpose(param_1), dimensions={0,2,1}
      reduce_1 = f32[] parameter(0)
      broadcast_1 = f32[2,2,2]{2,1,0} broadcast(reduce_1), dimensions={}
      ROOT divide_1 = f32[2,2,2]{2,1,0} divide(transpose, broadcast_1)
    }

    ENTRY entry_computation {
      fusion.1 = f32[2,2,2]{2,1,0} parameter(1)
      reduce.1 = f32[] parameter(0)
      fusion.2 = f32[2,2,2]{2,1,0} fusion(reduce.1, fusion.1), kind=kLoop, calls=fused_computation
      ROOT tuple.1 = (f32[2,2,2]{2,1,0}) tuple(fusion.2)
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  std::unique_ptr<HloModule> compiled_module =
      backend()
          .compiler()
          ->RunHloPasses(m->Clone(), backend().default_stream_executor(),
                         /*device_allocator=*/nullptr)
          .ConsumeValueOrDie();

  EXPECT_EQ(Status::OK(), backend()
                              .compiler()
                              ->RunBackend(std::move(compiled_module),
                                           backend().default_stream_executor(),
                                           /*device_allocator=*/nullptr)
                              .status());
}

// A GTE inside of a fusion node inherits the layout of its operand (which
// should, if we keep following operands, eventually be a parameter).
TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) {
  const char* module_str = R"(
    HloModule test_module

    fused_computation {
      fparam = (f32[2,2,2], (f32[2,2,2], f32[2,2,2])) parameter(0)
      gte0 = f32[2,2,2] get-tuple-element(fparam), index=0
      gte1 = (f32[2,2,2], f32[2,2,2]) get-tuple-element(fparam), index=1
      gte1a = f32[2,2,2] get-tuple-element(gte1), index=0
      gte1b = f32[2,2,2] get-tuple-element(gte1), index=1
      add = f32[2,2,2] add(gte1a, gte1b)
      ROOT fresult = f32[2,2,2] add(gte0, add)
    }

    ENTRY entry_computation {
      param = (f32[2,2,2], (f32[2,2,2], f32[2,2,2])) parameter(0)
      ROOT fusion =
        f32[2,2,2] fusion(param), kind=kLoop, calls=fused_computation
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());
  Shape param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 1, 2}),
       ShapeUtil::MakeTupleShape({
           ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {1, 2, 0}),
           ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {2, 0, 1}),
       })});
  TF_ASSERT_OK(
      computation_layout.mutable_parameter_layout(0)->CopyLayoutFromShape(
          param_shape));
  computation_layout.mutable_result_layout()->ResetLayout(
      LayoutUtil::MakeLayout({2, 1, 0}));
  AssignLayouts(m.get(), &computation_layout);

  EXPECT_THAT(LayoutOf(m.get(), "gte0"), ElementsAre(0, 1, 2));
  EXPECT_THAT(LayoutOf(m.get(), "gte1a"), ElementsAre(1, 2, 0));
  EXPECT_THAT(LayoutOf(m.get(), "gte1b"), ElementsAre(2, 0, 1));
  EXPECT_THAT(LayoutOf(m.get(), "fresult"), ElementsAre(2, 1, 0));
  EXPECT_THAT(FindInstruction(m.get(), "gte1")
                  ->shape()
                  .tuple_shapes(0)
                  .layout()
                  .minor_to_major(),
              ElementsAre(1, 2, 0));
  EXPECT_THAT(FindInstruction(m.get(), "gte1")
                  ->shape()
                  .tuple_shapes(1)
                  .layout()
                  .minor_to_major(),
              ElementsAre(2, 0, 1));
}

TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) {
  auto builder = HloComputation::Builder(TestName());
  auto m = CreateNewVerifiedModule();
  Shape shape = ShapeUtil::MakeShape(F32, {128, 8});
  Shape tshape = ShapeUtil::MakeTupleShape({shape, shape});
  Shape result_tshape = ShapeUtil::MakeTupleShape({shape});

  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "param0"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, shape, "param1"));
  auto pred = builder.AddInstruction(HloInstruction::CreateParameter(
      2, ShapeUtil::MakeShape(PRED, {}), "param2"));
  auto tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({param0, param1}));

  auto true_builder = HloComputation::Builder(TestName() + "_TrueBranch");
  {
    auto param = true_builder.AddInstruction(
        HloInstruction::CreateParameter(0, tshape, "param"));
    auto gte0 = true_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(shape, param, 0));
    auto gte1 = true_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(shape, param, 1));
    auto add = true_builder.AddInstruction(
        HloInstruction::CreateBinary(shape, HloOpcode::kAdd, gte0, gte1));
    true_builder.AddInstruction(HloInstruction::CreateTuple({add}));
  }
  HloComputation* true_computation =
      m->AddEmbeddedComputation(true_builder.Build());

  auto false_builder = HloComputation::Builder(TestName() + "_FalseBranch");
  {
    Shape xshape = ShapeUtil::MakeShapeWithLayout(F32, {128, 8}, {0, 1});
    false_builder.AddInstruction(
        HloInstruction::CreateParameter(0, tshape, "param"));
    // Use an infeed here, since layout assignment does not mess with infeed
    // layouts.
    auto token = false_builder.AddInstruction(HloInstruction::CreateToken());
    auto infeed = false_builder.AddInstruction(
        HloInstruction::CreateInfeed(xshape, token, ""));
    auto infeed_data = false_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(xshape, infeed, 0));
    false_builder.AddInstruction(HloInstruction::CreateTuple({infeed_data}));
  }
  HloComputation* false_computation =
      m->AddEmbeddedComputation(false_builder.Build());
  builder.AddInstruction(HloInstruction::CreateConditional(
      result_tshape, pred, tuple, true_computation, tuple, false_computation));

  HloComputation* computation = m->AddEntryComputation(builder.Build());
  ComputationLayout computation_layout(computation->ComputeProgramShape());

  AssignLayouts(m.get(), &computation_layout);

  const HloInstruction* true_root = true_computation->root_instruction();
  const HloInstruction* false_root = false_computation->root_instruction();
  EXPECT_EQ(true_root->opcode(), HloOpcode::kTuple);
  EXPECT_EQ(false_root->opcode(), HloOpcode::kTuple);

  const HloInstruction* true_result = true_root->operand(0);
  const HloInstruction* false_result = false_root->operand(0);
  EXPECT_TRUE(LayoutUtil::Equal(true_result->shape().layout(),
                                false_result->shape().layout()));
  EXPECT_EQ(false_result->opcode(), HloOpcode::kCopy);
}

TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) {
  auto builder = HloComputation::Builder(TestName());
  auto constant0 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
  builder.AddInstruction(HloInstruction::CreateUnary(
      constant0->shape(), HloOpcode::kBitcast, constant0));
  auto m = CreateNewVerifiedModule();
  m->AddEntryComputation(builder.Build());

  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());
  LayoutAssignment layout_assignment(&computation_layout);
  Status error_status = layout_assignment.Run(m.get()).status();
  EXPECT_FALSE(error_status.ok());
  EXPECT_THAT(
      error_status.error_message(),
      ::testing::HasSubstr(
          "Unexpected bitcast operation seen during layout assignment"));
}

TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) {
  // Pin non-matching layouts to the parameter and root.
  const char* module_str = R"(
    HloModule test_module

    ENTRY entry_computation {
      param = (f32[2,2]) parameter(0)
      gte = f32[2,2] get-tuple-element(param), index=0
      token0 = token[] after-all()
      recv = (f32[2,2], u32[], token[]) recv(token0), channel_id=1, sharding={maximal device=1}
      recv-done = (f32[2,2], token[]) recv-done(recv), channel_id=1,
        sharding={maximal device=1}
      ROOT root = f32[2,2] get-tuple-element(recv-done), index=0
      send = (f32[2,2], u32[], token[]) send(gte, token0), channel_id=1,
        sharding={maximal device=0}
      send-done = token[] send-done(send), channel_id=1, sharding={maximal device=0}
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());
  Shape param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})});
  TF_ASSERT_OK(
      computation_layout.mutable_parameter_layout(0)->CopyLayoutFromShape(
          param_shape));
  computation_layout.mutable_result_layout()->ResetLayout(
      LayoutUtil::MakeLayout({1, 0}));

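  // The channel constraints tie the layouts of the matching send/recv pair
  // together: both end up as {1,0} (the layout pinned to the result), even
  // though the send's operand `gte` keeps the parameter's {0,1} layout.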
  ChannelLayoutConstraints channel_constraints;
  AssignLayouts(m.get(), &computation_layout, &channel_constraints);

  EXPECT_THAT(LayoutOf(m.get(), "gte"), ElementsAre(0, 1));
  EXPECT_THAT(LayoutOf(m.get(), "root"), ElementsAre(1, 0));
  EXPECT_TRUE(ShapeUtil::Equal(
      ShapeUtil::GetSubshape(FindInstruction(m.get(), "send")->shape(), {0}),
      ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0})));
}

TEST_F(LayoutAssignmentTest, AllReduceLayoutMismatch) {
  // Pin non-matching layouts to the parameter and root.
  const char* module_str = R"(
    HloModule test_module

    add {
      lhs = f32[] parameter(0)
      rhs = f32[] parameter(1)
      ROOT add = f32[] add(lhs, rhs)
    }

    ENTRY entry_computation {
      param = (f32[2,2]) parameter(0)
      gte = f32[2,2] get-tuple-element(param), index=0
      ar.0 = f32[2,2] all-reduce(gte),
        all_reduce_id=1, replica_groups={{0}}, to_apply=add,
        sharding={maximal device=0}
      const = f32[2,2] constant({{0,1},{2,3}})
      ROOT ar.1 = f32[2,2] all-reduce(const),
        all_reduce_id=1, replica_groups={{0}}, to_apply=add,
        sharding={maximal device=1}
    })";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());
  Shape param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})});
  TF_ASSERT_OK(
      computation_layout.mutable_parameter_layout(0)->CopyLayoutFromShape(
          param_shape));
  computation_layout.mutable_result_layout()->ResetLayout(
      LayoutUtil::MakeLayout({1, 0}));

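  // all-reduce instructions that share an all_reduce_id are constrained to
  // use the same layout; here both ar.0 and ar.1 end up {0,1}, and the
  // pinned {1,0} result layout is satisfied by a copy at the root.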
  ChannelLayoutConstraints channel_constraints;
  AssignLayouts(m.get(), &computation_layout, &channel_constraints);

  EXPECT_THAT(LayoutOf(m.get(), "gte"), ElementsAre(0, 1));
  EXPECT_THAT(LayoutOf(m.get(), "ar.0"), ElementsAre(0, 1));
  EXPECT_THAT(LayoutOf(m.get(), "ar.1"), ElementsAre(0, 1));
  const HloInstruction* root = m->entry_computation()->root_instruction();
  EXPECT_THAT(root->shape().layout().minor_to_major(), ElementsAre(1, 0));
}

TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) {
  const char* module_str = R"(
    HloModule CopySliceOperandToAvoidImplicitLayoutChange

    ENTRY CopySliceOperandToAvoidImplicitLayoutChange {
      par0 = f32[3,4]{1,0} parameter(0)
      par1 = f32[4,5]{0,1} parameter(1)
      slice0 = f32[3,4] slice(par1), slice={[1:4],[1:5]}
      ROOT add0 = f32[3,4]{1,0} add(par0,slice0)
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  auto compiled_module =
      backend()
          .compiler()
          ->RunHloPasses(m->Clone(), backend().default_stream_executor(),
                         /*device_allocator=*/nullptr)
          .ConsumeValueOrDie();
  HloInstruction* root =
      compiled_module->entry_computation()->root_instruction();
  Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {1, 0});
  EXPECT_THAT(
      root,
      GmockMatch(m::Add(
          m::Parameter(),
          m::Slice(m::Copy(m::Parameter(1)).WithShapeEqualTo(&shape_copy)))));
}

TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) {
  const char* module_str = R"(
    HloModule CopyDSliceOperandToAvoidImplicitLayoutChange

    ENTRY CopyDSliceOperandToAvoidImplicitLayoutChange {
      par0 = f32[3,4]{1,0} parameter(0)
      par1 = f32[4,5]{0,1} parameter(1)
      par2 = s32[] parameter(2)
      par3 = s32[] parameter(3)
      dslice0 = f32[3,4] dynamic-slice(par1, par2, par3), dynamic_slice_sizes={3,4}
      ROOT add0 = f32[3,4]{1,0} add(par0,dslice0)
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  auto compiled_module =
      backend()
          .compiler()
          ->RunHloPasses(m->Clone(), backend().default_stream_executor(),
                         /*device_allocator=*/nullptr)
          .ConsumeValueOrDie();
  HloInstruction* root =
      compiled_module->entry_computation()->root_instruction();
  Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {1, 0});
  EXPECT_THAT(root,
              GmockMatch(m::Add(
                  m::Parameter(),
                  m::DynamicSlice(
                      m::Copy(m::Parameter(1)).WithShapeEqualTo(&shape_copy),
                      m::Parameter(2), m::Parameter(3)))));
}

TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) {
  const char* module_str = R"(
    HloModule CopyConcatOperandToAvoidImplicitLayoutChange

    ENTRY CopyConcatOperandToAvoidImplicitLayoutChange {
      par0 = f32[3,8]{1,0} parameter(0)
      par1 = f32[3,5]{0,1} parameter(1)
      par2 = f32[3,3]{1,0} parameter(2)
      concat0 = f32[3,8] concatenate(f32[3,5] par1, f32[3,3] par2),
        dimensions={1}
      ROOT add0 = f32[3,8]{1,0} add(par0,concat0)
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  auto compiled_module =
      backend()
          .compiler()
          ->RunHloPasses(m->Clone(), backend().default_stream_executor(),
                         /*device_allocator=*/nullptr)
          .ConsumeValueOrDie();
  HloInstruction* root =
      compiled_module->entry_computation()->root_instruction();
  Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {3, 5}, {1, 0});
  EXPECT_THAT(
      root,
      GmockMatch(m::Add(
          m::Parameter(),
          m::Concatenate(m::Copy(m::Parameter(1)).WithShapeEqualTo(&shape_copy),
                         m::Parameter(2)))));
}

TEST_F(LayoutAssignmentTest,
       ConvolutionOperandWithImplicitLayoutChangeNotCopied) {
  const char* module_str = R"(
    HloModule ConvolutionOperandWithImplicitLayoutChangeNotCopied

    ENTRY ConvolutionOperandWithImplicitLayoutChangeNotCopied {
      par0 = f32[128,3,230,230]{2,3,1,0} parameter(0)
      par1 = f32[7,7,3,64]{3,2,0,1} parameter(1)
      ROOT convolution0 = f32[128,64,112,112]{3,2,1,0} convolution(par0, par1),
        window={size=7x7 stride=2x2}, dim_labels=bf01_01io->bf01,
        feature_group_count=1
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  auto compiled_module =
      backend()
          .compiler()
          ->RunHloPasses(m->Clone(), backend().default_stream_executor(),
                         /*device_allocator=*/nullptr)
          .ConsumeValueOrDie();
  HloInstruction* root =
      compiled_module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              GmockMatch(m::Convolution(m::Parameter(0), m::Parameter(1))));
}

TEST_F(LayoutAssignmentTest, PropagatingLayoutFromResultToOperand) {
  const char* module_str = R"(
    HloModule PropagatingLayoutFromResultToOperand

    ENTRY PropagatingLayoutFromResultToOperand {
      par0 = f32[4,5]{1,0} parameter(0)
      ROOT slice0 = f32[3,4]{0,1} slice(par0), slice={[1:4],[1:5]}
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  auto compiled_module =
      backend()
          .compiler()
          ->RunHloPasses(m->Clone(), backend().default_stream_executor(),
                         /*device_allocator=*/nullptr)
          .ConsumeValueOrDie();
  HloInstruction* root =
      compiled_module->entry_computation()->root_instruction();
  Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {0, 1});
  EXPECT_THAT(root,
              GmockMatch(m::Slice(
                  m::Copy(m::Parameter(0)).WithShapeEqualTo(&shape_copy))));
}

TEST_F(LayoutAssignmentTest, TupleCopyOnLayoutMismatch) {
  // The first infeed uses layout {0,1}, while the second uses layout {1,0}.
  // The mismatch forces a copy of the tuple. The tuple contains a token, so
  // layout assignment will fail if it tries to copy the whole tuple.
  const char* module_str = R"(
    HloModule TupleCopyOnLayoutMismatch

    condition.1 (tup: (s32[], token[], f32[512,1024]{0,1})) -> pred[] {
      tup.1 = (s32[], token[], f32[512,1024]{0,1}) parameter(0)
      counter.1 = s32[] get-tuple-element(tup.1), index=0
      five = s32[] constant(5)
      ROOT lt = pred[] compare(counter.1, five), direction=LT
    }

    body.2 (tup: (s32[], token[], f32[512,1024]{0,1})) -> (s32[], token[], f32[512,1024]{0,1}) {
      tup.2 = (s32[], token[], f32[512,1024]{0,1}) parameter(0)
      counter.2 = s32[] get-tuple-element(tup.2), index=0
      tok.2 = token[] get-tuple-element(tup.2), index=1

      ifeed.2 = (f32[512,1024]{1,0}, token[]) infeed(tok.2)
      next_tok = token[] get-tuple-element(ifeed.2), index=1
      next_buf = f32[512,1024]{1,0} get-tuple-element(ifeed.2), index=0

      one = s32[] constant(1)
      next_counter = s32[] add(counter.2, one)
      ROOT tup = (s32[], token[], f32[512,1024]{0,1}) tuple(next_counter, next_tok, next_buf)
    }

    ENTRY main () -> f32[512,1024]{0,1} {
      start_tok = token[] after-all()

      ifeed.3 = (f32[512,1024]{0,1}, token[]) infeed(start_tok)
      itok = token[] get-tuple-element(ifeed.3), index=1
      ibuf = f32[512,1024]{0,1} get-tuple-element(ifeed.3), index=0

      zero = s32[] constant(0)
      itup = (s32[], token[], f32[512,1024]{0,1}) tuple(zero, itok, ibuf)

      loop = (s32[], token[], f32[512,1024]{0,1}) while(itup), condition=condition.1, body=body.2
      ROOT result = f32[512,1024]{0,1} get-tuple-element(loop), index=2
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());

  // Sanity check to verify that there's a layout mismatch.
  EXPECT_THAT(LayoutOf(m.get(), "ibuf"), ElementsAre(0, 1));
  EXPECT_THAT(LayoutOf(m.get(), "next_buf"), ElementsAre(1, 0));

  AssignLayouts(m.get(), &computation_layout);

  // Make sure that layout assignment did not magically eliminate the mismatch,
  // in which case the test didn't prove anything.
  EXPECT_THAT(LayoutOf(m.get(), "ibuf"), ElementsAre(0, 1));
  EXPECT_THAT(LayoutOf(m.get(), "next_buf"), ElementsAre(1, 0));
}
   1135 
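// A custom call without operand_layout_constraints is free to take any
// layouts, so its operand and result should simply inherit the layouts chosen
// for the entry computation, with no copies inserted.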
TEST_F(LayoutAssignmentTest, CustomCallNotLayoutConstrained) {
  const char* module_str = R"(
HloModule CustomCallNotLayoutConstrained

ENTRY %CustomCallWithNotLayoutConstrained (p: f32[42,2,3]) -> f32[1,2,3,4] {
  %p = f32[42,2,3] parameter(0)
  ROOT %custom-call = f32[1,2,3,4] custom-call(f32[42,2,3] %p), custom_call_target="baz"
}
)";
  // Try a couple of different layouts. In each case the custom call's operand
  // and result layouts should match those of the computation.
  {
    TF_ASSERT_OK_AND_ASSIGN(
        std::unique_ptr<VerifiedHloModule> m,
        ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
    ComputationLayout computation_layout = m->entry_computation_layout();
    *computation_layout.mutable_parameter_layout(0) =
        ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {42, 2, 3}, {0, 2, 1}));
    *computation_layout.mutable_result_layout() = ShapeLayout(
        ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {3, 2, 0, 1}));
    AssignLayouts(m.get(), &computation_layout);

    HloInstruction* root = m->entry_computation()->root_instruction();
    ASSERT_THAT(root, GmockMatch(m::CustomCall(m::Parameter())));
    ExpectLayoutIs(root->shape(), {3, 2, 0, 1});
    ExpectLayoutIs(root->operand(0)->shape(), {0, 2, 1});
  }
  {
    TF_ASSERT_OK_AND_ASSIGN(
        std::unique_ptr<VerifiedHloModule> m,
        ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
    ComputationLayout computation_layout = m->entry_computation_layout();
    *computation_layout.mutable_parameter_layout(0) =
        ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {42, 2, 3}, {0, 1, 2}));
    *computation_layout.mutable_result_layout() = ShapeLayout(
        ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {0, 2, 3, 1}));
    AssignLayouts(m.get(), &computation_layout);

    HloInstruction* root = m->entry_computation()->root_instruction();
    ASSERT_THAT(root, GmockMatch(m::CustomCall(m::Parameter())));
    ExpectLayoutIs(root->shape(), {0, 2, 3, 1});
    ExpectLayoutIs(root->operand(0)->shape(), {0, 1, 2});
  }
}

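// A custom call with operand_layout_constraints has fixed operand and result
// layouts; disagreements with the computation layout must be resolved by
// copies around the call rather than by relayouting the call itself.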
TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrained) {
  const char* module_str = R"(
HloModule CustomCallLayoutConstrained

ENTRY %CustomCallWithLayoutConstraints (p0: f32[4,4], p1: f32[2,3]) -> f32[1,2,3,4] {
  %p0 = f32[4,4] parameter(0)
  %p1 = f32[2,3] parameter(1)
  ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(f32[4,4] %p0, f32[2,3] %p1), custom_call_target="baz", operand_layout_constraints={f32[4,4]{0,1}, f32[2,3]{1,0}}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<VerifiedHloModule> m,
      ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
  ComputationLayout computation_layout = m->entry_computation_layout();
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}));
  *computation_layout.mutable_parameter_layout(1) =
      ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0}));
  *computation_layout.mutable_result_layout() = ShapeLayout(
      ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3}));
  AssignLayouts(m.get(), &computation_layout);

  // The custom call should be partially encapsulated in kCopy instructions
  // because of the layout mismatches.
  ASSERT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Copy(m::CustomCall(m::Copy(), m::Parameter()))));

  const HloInstruction* custom_call =
      m->entry_computation()->root_instruction()->operand(0);
  ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1});
  ExpectLayoutIs(custom_call->operand(0)->shape(), {0, 1});
  ExpectLayoutIs(custom_call->operand(1)->shape(), {1, 0});
}

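// Degenerate case: a layout-constrained custom call with no operands. Only
// the pinned result layout needs to be reconciled, via a single copy.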
TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedZeroOperands) {
  const char* module_str = R"(
HloModule CustomCallLayoutConstrainedZeroOperands

ENTRY %CustomCallLayoutConstrainedZeroOperands () -> f32[1,2,3,4] {
  ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(), custom_call_target="baz", operand_layout_constraints={}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<VerifiedHloModule> m,
      ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
  ComputationLayout computation_layout = m->entry_computation_layout();
  *computation_layout.mutable_result_layout() = ShapeLayout(
      ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3}));
  AssignLayouts(m.get(), &computation_layout);

  ASSERT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Copy(m::CustomCall())));

  const HloInstruction* custom_call =
      m->entry_computation()->root_instruction()->operand(0);
  ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1});
}

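// Operand layout constraints apply to tuple-shaped operands as well: each
// leaf layout of the constraint ({1,0} and {0,1} here) must be honored.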
TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedTupleOperand) {
  const char* module_str = R"(
HloModule CustomCallLayoutConstrainedTupleOperand

ENTRY %CustomCallLayoutConstrainedTupleOperand (p0: f32[4,4], p1: f32[2,3]) -> f32[1,2,3,4] {
  %p0 = f32[4,4] parameter(0)
  %p1 = f32[2,3] parameter(1)
  %tuple = (f32[4,4], f32[2,3]) tuple(%p0, %p1)
  ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(%tuple), custom_call_target="baz", operand_layout_constraints={(f32[4,4]{1,0}, f32[2,3]{0,1})}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<VerifiedHloModule> m,
      ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
  ComputationLayout computation_layout = m->entry_computation_layout();
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}));
  *computation_layout.mutable_parameter_layout(1) =
      ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0}));
  *computation_layout.mutable_result_layout() = ShapeLayout(
      ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3}));
  AssignLayouts(m.get(), &computation_layout);

  HloInstruction* root = m->entry_computation()->root_instruction();
  ExpectLayoutIs(root->shape(), {2, 1, 0, 3});

  ASSERT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Copy(m::CustomCall(m::Tuple()))));

  const HloInstruction* custom_call =
      m->entry_computation()->root_instruction()->operand(0);
  ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1});
  ExpectTupleLayoutIs(custom_call->operand(0)->shape(), {{1, 0}, {0, 1}});
}

TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedTupleResult) {
  const char* module_str = R"(
HloModule CustomCallLayoutConstrainedTupleResult

ENTRY %CustomCallLayoutConstrainedTupleResult (p0: f32[4,4]) -> (f32[4,4]{1,0}, f32[2,3]{0,1}) {
  %p0 = f32[4,4] parameter(0)
  ROOT %custom-call = (f32[4,4]{1,0}, f32[2,3]{0,1}) custom-call(%p0), custom_call_target="baz", operand_layout_constraints={f32[4,4]{1,0}}
}
)";
  // The custom call's tuple result layout is fixed by its declared shape, but
  // the entry computation requests layout {1,0} for both tuple elements, so
  // layout assignment must reconcile the second element's {0,1} layout with a
  // copy rather than by changing the custom call.
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<VerifiedHloModule> m,
      ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
  ComputationLayout computation_layout = m->entry_computation_layout();
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}));
  *computation_layout.mutable_result_layout() =
      ShapeLayout(ShapeUtil::MakeTupleShape(
          {ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}),
           ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0})}));
  AssignLayouts(m.get(), &computation_layout);

  ExpectTupleLayoutIs(m->result_shape(), {{1, 0}, {1, 0}});

  const HloInstruction* custom_call = FindInstruction(m.get(), "custom-call");
  ExpectTupleLayoutIs(custom_call->shape(), {{1, 0}, {0, 1}});
}

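// Runs layout assignment using the module's own entry computation layout,
// first defaulting the result layout if it has not been set.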
Status AssignLayoutsToComputation(
    HloModule* m, ChannelLayoutConstraints* channel_constraints = nullptr) {
  if (!m->entry_computation_layout().result_layout().LayoutIsSet()) {
    m->mutable_entry_computation_layout()
        ->mutable_result_layout()
        ->SetToDefaultLayout();
  }
  LayoutAssignment layout_assignment(
      m->mutable_entry_computation_layout(),
      LayoutAssignment::InstructionCanChangeLayout, channel_constraints);
  return layout_assignment.Run(m).status();
}

   1315 
   1316 TEST_F(LayoutAssignmentTest, OverwriteDiamondShapedConstraintsX) {
   1317   // Check that we handle a diamond-shaped graph correctly.
   1318   //      transpose
   1319   //       /    \
   1320   //     add    |
   1321   //       \    /
   1322   //        tuple
   1323 
   1324   auto b = HloComputation::Builder(TestName());
   1325   Shape ashape = ShapeUtil::MakeShape(F32, {12, 8});
   1326   Shape bshape = ShapeUtil::MakeShape(F32, {8, 12});
   1327   auto param0 =
   1328       b.AddInstruction(HloInstruction::CreateParameter(0, bshape, "input"));
   1329   auto param1 =
   1330       b.AddInstruction(HloInstruction::CreateParameter(1, ashape, "input"));
   1331   auto transpose =
   1332       b.AddInstruction(HloInstruction::CreateTranspose(ashape, param0, {1, 0}));
   1333   auto add = b.AddInstruction(
   1334       HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, transpose, param1));
   1335   b.AddInstruction(HloInstruction::CreateTuple({add, transpose}));
   1336   auto m = CreateNewVerifiedModule();
   1337   m->AddEntryComputation(b.Build());
   1338   Shape ashape_major = ShapeUtil::MakeShapeWithLayout(F32, {12, 8}, {1, 0});
   1339   Shape ashape_minor = ShapeUtil::MakeShapeWithLayout(F32, {12, 8}, {0, 1});
   1340   *m->mutable_entry_computation_layout()->mutable_result_layout() =
   1341       ShapeLayout(ShapeUtil::MakeTupleShape({ashape_major, ashape_minor}));
   1342   const Layout r2_dim0major = LayoutUtil::MakeLayout({1, 0});
   1343   ForceParameterLayout(m.get(), 0, r2_dim0major);
   1344   ForceParameterLayout(m.get(), 1, r2_dim0major);
   1345   TF_ASSERT_OK(AssignLayoutsToComputation(m.get()));
   1346 
   1347   EXPECT_THAT(add->shape().layout().minor_to_major(), ElementsAre(1, 0));
   1348   EXPECT_THAT(add->operand(0)->shape().layout().minor_to_major(),
   1349               ElementsAre(1, 0));
   1350   EXPECT_THAT(add->operand(1)->shape().layout().minor_to_major(),
   1351               ElementsAre(1, 0));
   1352 
   1353   EXPECT_THAT(transpose->shape().layout().minor_to_major(), ElementsAre(0, 1));
   1354 }
   1355 
   1356 }  // namespace
   1357 }  // namespace xla
   1358