/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h"

#include <initializer_list>
#include <memory>
#include <utility>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"

namespace op = xla::testing::opcode_matchers;

namespace xla {
namespace {

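// Tests for cpu::CpuLayoutAssignment. They check the layouts the CPU backend
// assigns to dot operands (in particular to constant RHS operands and to dot
// output fusions), and that no redundant kCopy instructions are introduced.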
class CpuLayoutAssignmentTest : public HloTestBase {
 protected:
  // Runs the CPU layout assignment pass on `module` with the given entry
  // computation layout, using a fake TargetMachineFeatures that reports
  // Eigen's expected tensor alignment for every buffer, and expects the pass
  // to succeed.
  void AssignLayouts(HloModule* module,
                     ComputationLayout* entry_computation_layout) {
    cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features(
        [](int64 shape_size) {
          return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
        });
    cpu::CpuLayoutAssignment layout_assignment(
        entry_computation_layout, LayoutAssignment::InstructionCanChangeLayout,
        &target_machine_features);
    EXPECT_IS_OK(layout_assignment.Run(module).status());
  }
};

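// Layouts below are written in XLA's minor-to-major order: {1, 0} is
// row-major and {0, 1} is column-major for a rank-2 shape.
//
// A dot with a constant RHS: layout assignment should give the constant a
// column-major layout and insert no copies.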
TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensor) {
  auto builder = HloComputation::Builder(TestName());
  Shape lhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 12}, {0, 1});
  Shape rhs_shape = ShapeUtil::MakeShape(F32, {12, 24});
  Shape result_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 24}, {0, 1});
  auto dot_lhs = builder.AddInstruction(
      HloInstruction::CreateParameter(0, lhs_shape, "param0"));
  auto dot_rhs = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape)));
  auto result = builder.AddInstruction(
      CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(LayoutUtil::GetWithDefaultLayout(lhs_shape));
  *computation_layout.mutable_result_layout() =
      ShapeLayout(LayoutUtil::GetWithDefaultLayout(result_shape));
  AssignLayouts(module.get(), &computation_layout);

  EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({1, 0}),
                                dot_lhs->shape().layout()));
  EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({0, 1}),
                                dot_rhs->shape().layout()));
  EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({1, 0}),
                                result->shape().layout()));
  for (const auto& instruction : computation->instructions()) {
    EXPECT_NE(instruction->opcode(), HloOpcode::kCopy);
  }
}

TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor0) {
  // Two dot products have the same constant as the RHS, and both those dot
  // products can be optimized if the constant has a column-major layout.
  auto builder = HloComputation::Builder(TestName());
  Shape lhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 12}, {0, 1});
  Shape rhs_shape = ShapeUtil::MakeShape(F32, {12, 24});
  Shape result_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 24}, {0, 1});
  auto dot_a_lhs = builder.AddInstruction(
      HloInstruction::CreateParameter(0, lhs_shape, "param0"));
  auto dot_b_lhs = builder.AddInstruction(
      HloInstruction::CreateParameter(1, lhs_shape, "param1"));
  auto dot_rhs = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape)));
  auto dot_a_result = builder.AddInstruction(
      CreateCanonicalDot(result_shape, dot_a_lhs, dot_rhs));
  auto dot_b_result = builder.AddInstruction(
      CreateCanonicalDot(result_shape, dot_b_lhs, dot_rhs));
  builder.AddInstruction(HloInstruction::CreateBinary(
      result_shape, HloOpcode::kAdd, dot_a_result, dot_b_result));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(LayoutUtil::GetWithDefaultLayout(lhs_shape));
  *computation_layout.mutable_result_layout() =
      ShapeLayout(LayoutUtil::GetWithDefaultLayout(result_shape));
  AssignLayouts(module.get(), &computation_layout);

  EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({0, 1}),
                                dot_rhs->shape().layout()));
  for (HloInstruction* instruction :
       {dot_a_lhs, dot_b_lhs, dot_a_result, dot_b_result}) {
    EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({1, 0}),
                                  instruction->shape().layout()));
  }
  for (const auto& instruction : computation->instructions()) {
    EXPECT_NE(instruction->opcode(), HloOpcode::kCopy);
  }
}

TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor1) {
  // Two dot products have the same constant as the RHS, but only one of the
  // two dot products can be optimized if the constant has a column-major
  // layout.  Layout assignment is therefore expected to leave every operand,
  // including the shared constant, with a row-major layout.
  auto builder = HloComputation::Builder(TestName());
  Shape lhs_a_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 12}, {0, 1});
  Shape lhs_b_shape = ShapeUtil::MakeShapeWithLayout(F32, {2, 12}, {0, 1});
  Shape rhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {12, 24}, {0, 1});
  Shape result_a_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 24}, {0, 1});
  Shape result_b_shape = ShapeUtil::MakeShapeWithLayout(F32, {2, 24}, {0, 1});
  auto dot_a_lhs = builder.AddInstruction(
      HloInstruction::CreateParameter(0, lhs_a_shape, "param0"));
  auto dot_b_lhs = builder.AddInstruction(
      HloInstruction::CreateParameter(1, lhs_b_shape, "param1"));
  auto dot_rhs = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape)));
  auto dot_a_result = builder.AddInstruction(
      CreateCanonicalDot(result_a_shape, dot_a_lhs, dot_rhs));
  auto dot_b_result = builder.AddInstruction(
      CreateCanonicalDot(result_b_shape, dot_b_lhs, dot_rhs));
  auto tuple_result = builder.AddInstruction(
      HloInstruction::CreateTuple({dot_a_result, dot_b_result}));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(LayoutUtil::GetWithDefaultLayout(lhs_a_shape));
  *computation_layout.mutable_parameter_layout(1) =
      ShapeLayout(LayoutUtil::GetWithDefaultLayout(lhs_b_shape));
  *computation_layout.mutable_result_layout() =
      ShapeLayout(LayoutUtil::GetWithDefaultLayout(tuple_result->shape()));
  AssignLayouts(module.get(), &computation_layout);

  for (HloInstruction* instruction :
       {dot_rhs, dot_a_lhs, dot_b_lhs, dot_a_result, dot_b_result}) {
    EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({1, 0}),
                                  instruction->shape().layout()));
  }
  for (const auto& instruction : computation->instructions()) {
    EXPECT_NE(instruction->opcode(), HloOpcode::kCopy);
  }
}

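// A dot with a constant LHS (rather than RHS): the column-major optimization
// does not apply, so all operands are expected to stay row-major.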
TEST_F(CpuLayoutAssignmentTest, DotWithConstantLhsTensor) {
  auto builder = HloComputation::Builder(TestName());
  Shape lhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 12}, {0, 1});
  Shape rhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {12, 24}, {0, 1});
  Shape result_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 24}, {0, 1});
  auto dot_lhs = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateFromShape(lhs_shape)));
  auto dot_rhs = builder.AddInstruction(
      HloInstruction::CreateParameter(0, rhs_shape, "param0"));
  auto dot_result = builder.AddInstruction(
      CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(LayoutUtil::GetWithDefaultLayout(rhs_shape));
  *computation_layout.mutable_result_layout() =
      ShapeLayout(LayoutUtil::GetWithDefaultLayout(result_shape));
  AssignLayouts(module.get(), &computation_layout);

  for (HloInstruction* instruction : {dot_lhs, dot_rhs, dot_result}) {
    EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({1, 0}),
                                  instruction->shape().layout()));
  }
  for (const auto& instruction : computation->instructions()) {
    EXPECT_NE(instruction->opcode(), HloOpcode::kCopy);
  }
}

TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensorThroughGTE) {
  // This is a case we could theoretically optimize at some point, but today we
  // don't.
  auto builder = HloComputation::Builder(TestName());
  Shape lhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 12}, {0, 1});
  Shape rhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {12, 24}, {0, 1});
  Shape other_shape = ShapeUtil::MakeShapeWithLayout(F32, {100, 24}, {0, 1});

  auto constant_shape = ShapeUtil::MakeTupleShape({other_shape, rhs_shape});
  auto constant = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateFromShape(constant_shape)));

  Shape result_shape = ShapeUtil::MakeShape(F32, {1, 24});

  auto dot_lhs = builder.AddInstruction(
      HloInstruction::CreateParameter(0, lhs_shape, "param0"));
  auto dot_rhs = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(rhs_shape, constant, 1));
  auto dot_result = builder.AddInstruction(
      CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(LayoutUtil::GetWithDefaultLayout(lhs_shape));
  *computation_layout.mutable_result_layout() =
      ShapeLayout(LayoutUtil::GetWithDefaultLayout(result_shape));
  AssignLayouts(module.get(), &computation_layout);

  for (HloInstruction* instruction : {dot_lhs, dot_rhs, dot_result}) {
    EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({1, 0}),
                                  instruction->shape().layout()));
  }
  for (const auto& instruction : computation->instructions()) {
    EXPECT_NE(instruction->opcode(), HloOpcode::kCopy);
  }
}

struct DotOutputFusionLayoutAssignmentResult {
  bool layout_assignment_changed_something;
  const HloInstruction* dot_lhs_fusion_param;
  const HloInstruction* dot_rhs_fusion_param;
  const HloInstruction* addend_fusion_param;
};

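// Builds a computation add(dot(param0, constant), param1) for an m x k x n
// dot (with the add operands swapped when dot_operand_idx_in_add == 1), turns
// the add and the dot into a single output fusion, runs CPU layout
// assignment, and reports the fusion parameters feeding the dot and the add.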
static StatusOr<DotOutputFusionLayoutAssignmentResult> RunDotOutputFusion(
    HloModule* module, const string& test_name, int m, int k, int n,
    const int64 dot_operand_idx_in_add) {
  DotOutputFusionLayoutAssignmentResult result;

  CHECK(dot_operand_idx_in_add == 0 || dot_operand_idx_in_add == 1);

  auto builder = HloComputation::Builder(test_name);

  Shape dot_lhs_shape = ShapeUtil::MakeShape(F32, {m, k});
  Shape dot_rhs_shape = ShapeUtil::MakeShape(F32, {k, n});
  Shape dot_shape = ShapeUtil::MakeShape(F32, {m, n});

  HloInstruction* dot_lhs = builder.AddInstruction(
      HloInstruction::CreateParameter(0, dot_lhs_shape, "param0"));
  HloInstruction* addend = builder.AddInstruction(
      HloInstruction::CreateParameter(1, dot_shape, "param1"));
  HloInstruction* dot_rhs = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateFromShape(dot_rhs_shape)));
  HloInstruction* dot_result =
      builder.AddInstruction(CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs));
  HloInstruction* add_result;
  if (dot_operand_idx_in_add == 0) {
    add_result = builder.AddInstruction(HloInstruction::CreateBinary(
        dot_shape, HloOpcode::kAdd, dot_result, addend));
  } else {
    add_result = builder.AddInstruction(HloInstruction::CreateBinary(
        dot_shape, HloOpcode::kAdd, addend, dot_result));
  }

  HloComputation* computation = module->AddEntryComputation(builder.Build());

  // Turn the add into an output fusion, then pull the dot into the same
  // fusion so that the fused computation is add(dot(...), addend).
  HloInstruction* fusion_instruction =
      module->entry_computation()->AddInstruction(HloInstruction::CreateFusion(
          dot_shape, HloInstruction::FusionKind::kOutput, add_result));
  TF_RETURN_IF_ERROR(
      computation->ReplaceInstruction(add_result, fusion_instruction));

  HloInstruction* fused_add =
      fusion_instruction->fused_instructions_computation()->root_instruction();
  HloInstruction* fused_dot = fusion_instruction->FuseInstruction(dot_result);

  TF_RETURN_IF_ERROR(
      computation->RemoveInstructionAndUnusedOperands(dot_result));

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(LayoutUtil::GetWithDefaultLayout(dot_lhs_shape));
  *computation_layout.mutable_parameter_layout(1) =
      ShapeLayout(LayoutUtil::GetWithDefaultLayout(dot_shape));
  *computation_layout.mutable_result_layout() =
      ShapeLayout(LayoutUtil::GetWithDefaultLayout(dot_shape));

  // Map each fused parameter back to the corresponding operand of the fusion
  // instruction so callers can inspect the layouts the pass assigns to them.
  result.dot_lhs_fusion_param =
      fusion_instruction->operand(fused_dot->operand(0)->parameter_number());
  result.dot_rhs_fusion_param =
      fusion_instruction->operand(fused_dot->operand(1)->parameter_number());
  result.addend_fusion_param = fusion_instruction->operand(
      fused_add->operand(1 - dot_operand_idx_in_add)->parameter_number());

  cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features(
      [](int64 shape_size) {
        return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
      });
  cpu::CpuLayoutAssignment layout_assignment(
      &computation_layout, LayoutAssignment::InstructionCanChangeLayout,
      &target_machine_features);
  TF_ASSIGN_OR_RETURN(result.layout_assignment_changed_something,
                      layout_assignment.Run(module));

  return result;
}

static void AssertCorrectLayoutForDotOutputFusion(
    const HloComputation* computation,
    const DotOutputFusionLayoutAssignmentResult& layout_assignment_result,
    bool expect_col_major_dot_rhs) {
  Layout expected_dot_rhs_layout = expect_col_major_dot_rhs
                                       ? LayoutUtil::MakeLayout({0, 1})
                                       : LayoutUtil::MakeLayout({1, 0});
  EXPECT_TRUE(LayoutUtil::Equal(
      expected_dot_rhs_layout,
      layout_assignment_result.dot_rhs_fusion_param->shape().layout()));

  EXPECT_TRUE(LayoutUtil::Equal(
      LayoutUtil::MakeLayout({1, 0}),
      layout_assignment_result.dot_lhs_fusion_param->shape().layout()));

  EXPECT_TRUE(LayoutUtil::Equal(
      LayoutUtil::MakeLayout({1, 0}),
      layout_assignment_result.addend_fusion_param->shape().layout()));
  EXPECT_THAT(computation->instructions(), Each(Not(op::Copy())));
}

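// The expectations in the cases below encode the pass's choice of RHS layout:
// only the m=1 dots (vector-matrix products) should get a column-major
// constant RHS; for the other shapes the RHS keeps the row-major default.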
TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_1x50x19_dot_idx_0) {
  std::unique_ptr<HloModule> module = CreateNewVerifiedModule();
  TF_ASSERT_OK_AND_ASSIGN(
      DotOutputFusionLayoutAssignmentResult layout_assignment_result,
      RunDotOutputFusion(module.get(), TestName(), /*m=*/1, /*k=*/50, /*n=*/19,
                         /*dot_operand_idx_in_add=*/0));
  ASSERT_TRUE(layout_assignment_result.layout_assignment_changed_something);
  AssertCorrectLayoutForDotOutputFusion(module->entry_computation(),
                                        layout_assignment_result,
                                        /*expect_col_major_dot_rhs=*/true);
}

TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_1x50x19_dot_idx_1) {
  std::unique_ptr<HloModule> module = CreateNewVerifiedModule();
  TF_ASSERT_OK_AND_ASSIGN(
      DotOutputFusionLayoutAssignmentResult layout_assignment_result,
      RunDotOutputFusion(module.get(), TestName(), /*m=*/1, /*k=*/50, /*n=*/19,
                         /*dot_operand_idx_in_add=*/1));
  ASSERT_TRUE(layout_assignment_result.layout_assignment_changed_something);
  AssertCorrectLayoutForDotOutputFusion(module->entry_computation(),
                                        layout_assignment_result,
                                        /*expect_col_major_dot_rhs=*/true);
}

TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x1_dot_idx_0) {
  std::unique_ptr<HloModule> module = CreateNewVerifiedModule();
  TF_ASSERT_OK_AND_ASSIGN(
      DotOutputFusionLayoutAssignmentResult layout_assignment_result,
      RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/1,
                         /*dot_operand_idx_in_add=*/0));
  ASSERT_TRUE(layout_assignment_result.layout_assignment_changed_something);
  AssertCorrectLayoutForDotOutputFusion(module->entry_computation(),
                                        layout_assignment_result,
                                        /*expect_col_major_dot_rhs=*/false);
}

TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x1_dot_idx_1) {
  std::unique_ptr<HloModule> module = CreateNewVerifiedModule();
  TF_ASSERT_OK_AND_ASSIGN(
      DotOutputFusionLayoutAssignmentResult layout_assignment_result,
      RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/1,
                         /*dot_operand_idx_in_add=*/1));
  ASSERT_TRUE(layout_assignment_result.layout_assignment_changed_something);
  AssertCorrectLayoutForDotOutputFusion(module->entry_computation(),
                                        layout_assignment_result,
                                        /*expect_col_major_dot_rhs=*/false);
}

TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x19_dot_idx_0) {
  std::unique_ptr<HloModule> module = CreateNewVerifiedModule();
  TF_ASSERT_OK_AND_ASSIGN(
      DotOutputFusionLayoutAssignmentResult layout_assignment_result,
      RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/19,
                         /*dot_operand_idx_in_add=*/0));
  ASSERT_TRUE(layout_assignment_result.layout_assignment_changed_something);
  AssertCorrectLayoutForDotOutputFusion(module->entry_computation(),
                                        layout_assignment_result,
                                        /*expect_col_major_dot_rhs=*/false);
}

TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x19_dot_idx_1) {
  std::unique_ptr<HloModule> module = CreateNewVerifiedModule();
  TF_ASSERT_OK_AND_ASSIGN(
      DotOutputFusionLayoutAssignmentResult layout_assignment_result,
      RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/19,
                         /*dot_operand_idx_in_add=*/1));
  ASSERT_TRUE(layout_assignment_result.layout_assignment_changed_something);
  AssertCorrectLayoutForDotOutputFusion(module->entry_computation(),
                                        layout_assignment_result,
                                        /*expect_col_major_dot_rhs=*/false);
}
}  // namespace
}  // namespace xla