Home | History | Annotate | Download | only in cpu
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h"
     17 
     18 #include <initializer_list>
     19 #include <memory>
     20 #include <utility>
     21 #include <vector>
     22 
     23 #include "tensorflow/compiler/xla/layout_util.h"
     24 #include "tensorflow/compiler/xla/literal_util.h"
     25 #include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
     26 #include "tensorflow/compiler/xla/service/computation_layout.h"
     27 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     28 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     29 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
     30 #include "tensorflow/compiler/xla/service/hlo_module.h"
     31 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
     32 #include "tensorflow/compiler/xla/shape_layout.h"
     33 #include "tensorflow/compiler/xla/shape_util.h"
     34 #include "tensorflow/compiler/xla/test.h"
     35 #include "tensorflow/compiler/xla/test_helpers.h"
     36 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
     37 #include "tensorflow/compiler/xla/tests/test_utils.h"
     38 #include "tensorflow/compiler/xla/util.h"
     39 #include "tensorflow/compiler/xla/xla_data.pb.h"
     40 #include "tensorflow/core/lib/core/status.h"
     41 #include "tensorflow/core/lib/gtl/array_slice.h"
     42 
     43 namespace op = xla::testing::opcode_matchers;
     44 
     45 namespace xla {
     46 namespace {
     47 
     48 class CpuLayoutAssignmentTest : public HloTestBase {
     49  protected:
     50   void AssignLayouts(HloModule* module,
     51                      ComputationLayout* entry_computation_layout) {
     52     cpu::CpuLayoutAssignment layout_assignment(entry_computation_layout);
     53     EXPECT_IS_OK(layout_assignment.Run(module).status());
     54   }
     55 };
     56 
     57 TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensor) {
     58   auto builder = HloComputation::Builder(TestName());
     59   Shape lhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 12}, {0, 1});
     60   Shape rhs_shape = ShapeUtil::MakeShape(F32, {12, 24});
     61   Shape result_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 24}, {0, 1});
     62   auto dot_lhs = builder.AddInstruction(
     63       HloInstruction::CreateParameter(0, lhs_shape, "param0"));
     64   auto dot_rhs = builder.AddInstruction(
     65       HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape)));
     66   auto result = builder.AddInstruction(
     67       HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));
     68 
     69   auto module = CreateNewModule();
     70   HloComputation* computation = module->AddEntryComputation(builder.Build());
     71 
     72   ComputationLayout computation_layout(computation->ComputeProgramShape());
     73   *computation_layout.mutable_parameter_layout(0) =
     74       ShapeLayout(LayoutUtil::GetWithDefaultLayout(lhs_shape));
     75   *computation_layout.mutable_result_layout() =
     76       ShapeLayout(LayoutUtil::GetWithDefaultLayout(result_shape));
     77   AssignLayouts(module.get(), &computation_layout);
     78 
     79   EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({1, 0}),
     80                                 dot_lhs->shape().layout()));
     81   EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({0, 1}),
     82                                 dot_rhs->shape().layout()));
     83   EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({1, 0}),
     84                                 result->shape().layout()));
     85   for (const auto& instruction : computation->instructions()) {
     86     EXPECT_NE(instruction->opcode(), HloOpcode::kCopy);
     87   }
     88 }
     89 
     90 TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor0) {
     91   // Two dot products have the same constant as the RHS, and both those dot
     92   // products can be optimized if the constant has a column-major layout.
     93   auto builder = HloComputation::Builder(TestName());
     94   Shape lhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 12}, {0, 1});
     95   Shape rhs_shape = ShapeUtil::MakeShape(F32, {12, 24});
     96   Shape result_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 24}, {0, 1});
     97   auto dot_a_lhs = builder.AddInstruction(
     98       HloInstruction::CreateParameter(0, lhs_shape, "param0"));
     99   auto dot_b_lhs = builder.AddInstruction(
    100       HloInstruction::CreateParameter(1, lhs_shape, "param1"));
    101   auto dot_rhs = builder.AddInstruction(
    102       HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape)));
    103   auto dot_a_result = builder.AddInstruction(
    104       HloInstruction::CreateCanonicalDot(result_shape, dot_a_lhs, dot_rhs));
    105   auto dot_b_result = builder.AddInstruction(
    106       HloInstruction::CreateCanonicalDot(result_shape, dot_b_lhs, dot_rhs));
    107   builder.AddInstruction(HloInstruction::CreateBinary(
    108       result_shape, HloOpcode::kAdd, dot_a_result, dot_b_result));
    109 
    110   auto module = CreateNewModule();
    111   HloComputation* computation = module->AddEntryComputation(builder.Build());
    112 
    113   ComputationLayout computation_layout(computation->ComputeProgramShape());
    114   *computation_layout.mutable_parameter_layout(0) =
    115       ShapeLayout(LayoutUtil::GetWithDefaultLayout(lhs_shape));
    116   *computation_layout.mutable_result_layout() =
    117       ShapeLayout(LayoutUtil::GetWithDefaultLayout(result_shape));
    118   AssignLayouts(module.get(), &computation_layout);
    119 
    120   EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({0, 1}),
    121                                 dot_rhs->shape().layout()));
    122   for (HloInstruction* instruction :
    123        {dot_a_lhs, dot_b_lhs, dot_a_result, dot_b_result}) {
    124     EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({1, 0}),
    125                                   instruction->shape().layout()));
    126   }
    127   for (const auto& instruction : computation->instructions()) {
    128     EXPECT_NE(instruction->opcode(), HloOpcode::kCopy);
    129   }
    130 }
    131 
    132 TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor1) {
    133   // Two dot products have the same constant as the RHS, but only one of the two
    134   // dot products can be optimized if the constant has a column-major layout.
    135   auto builder = HloComputation::Builder(TestName());
    136   Shape lhs_a_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 12}, {0, 1});
    137   Shape lhs_b_shape = ShapeUtil::MakeShapeWithLayout(F32, {2, 12}, {0, 1});
    138   Shape rhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {12, 24}, {0, 1});
    139   Shape result_a_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 24}, {0, 1});
    140   Shape result_b_shape = ShapeUtil::MakeShapeWithLayout(F32, {2, 24}, {0, 1});
    141   auto dot_a_lhs = builder.AddInstruction(
    142       HloInstruction::CreateParameter(0, lhs_a_shape, "param0"));
    143   auto dot_b_lhs = builder.AddInstruction(
    144       HloInstruction::CreateParameter(1, lhs_b_shape, "param1"));
    145   auto dot_rhs = builder.AddInstruction(
    146       HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape)));
    147   auto dot_a_result = builder.AddInstruction(
    148       HloInstruction::CreateCanonicalDot(result_a_shape, dot_a_lhs, dot_rhs));
    149   auto dot_b_result = builder.AddInstruction(
    150       HloInstruction::CreateCanonicalDot(result_b_shape, dot_b_lhs, dot_rhs));
    151   auto tuple_result = builder.AddInstruction(
    152       HloInstruction::CreateTuple({dot_a_result, dot_b_result}));
    153 
    154   auto module = CreateNewModule();
    155   HloComputation* computation = module->AddEntryComputation(builder.Build());
    156 
    157   ComputationLayout computation_layout(computation->ComputeProgramShape());
    158   *computation_layout.mutable_parameter_layout(0) =
    159       ShapeLayout(LayoutUtil::GetWithDefaultLayout(lhs_a_shape));
    160   *computation_layout.mutable_parameter_layout(1) =
    161       ShapeLayout(LayoutUtil::GetWithDefaultLayout(lhs_b_shape));
    162   *computation_layout.mutable_result_layout() =
    163       ShapeLayout(LayoutUtil::GetWithDefaultLayout(tuple_result->shape()));
    164   AssignLayouts(module.get(), &computation_layout);
    165 
    166   for (HloInstruction* instruction :
    167        {dot_rhs, dot_a_lhs, dot_b_lhs, dot_a_result, dot_b_result}) {
    168     EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({1, 0}),
    169                                   instruction->shape().layout()));
    170   }
    171   for (const auto& instruction : computation->instructions()) {
    172     EXPECT_NE(instruction->opcode(), HloOpcode::kCopy);
    173   }
    174 }
    175 
    176 TEST_F(CpuLayoutAssignmentTest, DotWithConstantLhsTensor) {
    177   auto builder = HloComputation::Builder(TestName());
    178   Shape lhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 12}, {0, 1});
    179   Shape rhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {12, 24}, {0, 1});
    180   Shape result_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 24}, {0, 1});
    181   auto dot_lhs = builder.AddInstruction(
    182       HloInstruction::CreateConstant(Literal::CreateFromShape(lhs_shape)));
    183   auto dot_rhs = builder.AddInstruction(
    184       HloInstruction::CreateParameter(0, rhs_shape, "param0"));
    185   auto dot_result = builder.AddInstruction(
    186       HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));
    187 
    188   auto module = CreateNewModule();
    189   HloComputation* computation = module->AddEntryComputation(builder.Build());
    190 
    191   ComputationLayout computation_layout(computation->ComputeProgramShape());
    192   *computation_layout.mutable_parameter_layout(0) =
    193       ShapeLayout(LayoutUtil::GetWithDefaultLayout(rhs_shape));
    194   *computation_layout.mutable_result_layout() =
    195       ShapeLayout(LayoutUtil::GetWithDefaultLayout(result_shape));
    196   AssignLayouts(module.get(), &computation_layout);
    197 
    198   for (HloInstruction* instruction : {dot_lhs, dot_rhs, dot_result}) {
    199     EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({1, 0}),
    200                                   instruction->shape().layout()));
    201   }
    202   for (const auto& instruction : computation->instructions()) {
    203     EXPECT_NE(instruction->opcode(), HloOpcode::kCopy);
    204   }
    205 }
    206 
    207 TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensorThroughGTE) {
    208   // This is a case we could theoretically optimize at some point, but today we
    209   // don't.
    210   auto builder = HloComputation::Builder(TestName());
    211   Shape lhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 12}, {0, 1});
    212   Shape rhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {12, 24}, {0, 1});
    213   Shape other_shape = ShapeUtil::MakeShapeWithLayout(F32, {100, 24}, {0, 1});
    214 
    215   auto constant_shape = ShapeUtil::MakeTupleShape({other_shape, rhs_shape});
    216   auto constant = builder.AddInstruction(
    217       HloInstruction::CreateConstant(Literal::CreateFromShape(constant_shape)));
    218 
    219   Shape result_shape = ShapeUtil::MakeShape(F32, {1, 24});
    220 
    221   auto dot_lhs = builder.AddInstruction(
    222       HloInstruction::CreateParameter(0, lhs_shape, "param0"));
    223   auto dot_rhs = builder.AddInstruction(
    224       HloInstruction::CreateGetTupleElement(rhs_shape, constant, 1));
    225   auto dot_result = builder.AddInstruction(
    226       HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));
    227 
    228   auto module = CreateNewModule();
    229   HloComputation* computation = module->AddEntryComputation(builder.Build());
    230 
    231   ComputationLayout computation_layout(computation->ComputeProgramShape());
    232   *computation_layout.mutable_parameter_layout(0) =
    233       ShapeLayout(LayoutUtil::GetWithDefaultLayout(lhs_shape));
    234   *computation_layout.mutable_result_layout() =
    235       ShapeLayout(LayoutUtil::GetWithDefaultLayout(result_shape));
    236   AssignLayouts(module.get(), &computation_layout);
    237 
    238   for (HloInstruction* instruction : {dot_lhs, dot_rhs, dot_result}) {
    239     EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({1, 0}),
    240                                   instruction->shape().layout()));
    241   }
    242   for (const auto& instruction : computation->instructions()) {
    243     EXPECT_NE(instruction->opcode(), HloOpcode::kCopy);
    244   }
    245 }
    246 
    247 struct DotOutputFusionLayoutAssignmentResult {
    248   bool layout_assignment_changed_something;
    249   const HloInstruction* dot_lhs_fusion_param;
    250   const HloInstruction* dot_rhs_fusion_param;
    251   const HloInstruction* addend_fusion_param;
    252 };
    253 
// Builds an entry computation computing `dot(param0, constant) + param1`
// (with the dot appearing as operand `dot_operand_idx_in_add` of the add),
// manually wraps the add+dot into an output fusion, runs CPU layout
// assignment on the result, and reports which fusion operands feed the
// fused dot/add along with whether the pass changed anything.
// `m`, `k`, `n` are the dot dimensions: lhs is {m, k}, rhs is {k, n}.
static StatusOr<DotOutputFusionLayoutAssignmentResult> RunDotOutputFusion(
    HloModule* module, const string& test_name, int m, int k, int n,
    const int64 dot_operand_idx_in_add) {
  DotOutputFusionLayoutAssignmentResult result;

  // The dot must be one of the add's two operands.
  CHECK(dot_operand_idx_in_add == 0 || dot_operand_idx_in_add == 1);

  auto builder = HloComputation::Builder(test_name);

  Shape dot_lhs_shape = ShapeUtil::MakeShape(F32, {m, k});
  Shape dot_rhs_shape = ShapeUtil::MakeShape(F32, {k, n});
  Shape dot_shape = ShapeUtil::MakeShape(F32, {m, n});

  HloInstruction* dot_lhs = builder.AddInstruction(
      HloInstruction::CreateParameter(0, dot_lhs_shape, "param0"));
  HloInstruction* addend = builder.AddInstruction(
      HloInstruction::CreateParameter(1, dot_shape, "param1"));
  // The RHS is a constant so the layout assignment pass is free to choose
  // its layout (unlike a parameter, whose layout is externally pinned).
  HloInstruction* dot_rhs = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateFromShape(dot_rhs_shape)));
  HloInstruction* dot_result = builder.AddInstruction(
      HloInstruction::CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs));
  HloInstruction* add_result;
  // Place the dot at the requested operand index of the add.
  if (dot_operand_idx_in_add == 0) {
    add_result = builder.AddInstruction(HloInstruction::CreateBinary(
        dot_shape, HloOpcode::kAdd, dot_result, addend));
  } else {
    add_result = builder.AddInstruction(HloInstruction::CreateBinary(
        dot_shape, HloOpcode::kAdd, addend, dot_result));
  }

  HloComputation* computation = module->AddEntryComputation(builder.Build());

  // Manually construct the output fusion: start a fusion node from the add,
  // splice it in place of the add, then pull the dot into the fusion too.
  // Order matters here — the dot can only be fused after the fusion node has
  // replaced the add as its user.
  HloInstruction* fusion_instruction =
      module->entry_computation()->AddInstruction(HloInstruction::CreateFusion(
          dot_shape, HloInstruction::FusionKind::kOutput, add_result));
  TF_RETURN_IF_ERROR(
      computation->ReplaceInstruction(add_result, fusion_instruction));

  HloInstruction* fused_add =
      fusion_instruction->fused_instructions_computation()->root_instruction();
  HloInstruction* fused_dot = fusion_instruction->FuseInstruction(dot_result);

  // The original (now-fused) dot is dead; remove it and anything it alone
  // kept alive.
  TF_RETURN_IF_ERROR(
      computation->RemoveInstructionAndUnusedOperands(dot_result));

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(LayoutUtil::GetWithDefaultLayout(dot_lhs_shape));
  *computation_layout.mutable_parameter_layout(1) =
      ShapeLayout(LayoutUtil::GetWithDefaultLayout(dot_shape));
  *computation_layout.mutable_result_layout() =
      ShapeLayout(LayoutUtil::GetWithDefaultLayout(dot_shape));

  // Map the fused instructions' parameter operands back to the fusion
  // node's operands, so callers can inspect the assigned layouts.
  result.dot_lhs_fusion_param =
      fusion_instruction->operand(fused_dot->operand(0)->parameter_number());
  result.dot_rhs_fusion_param =
      fusion_instruction->operand(fused_dot->operand(1)->parameter_number());
  // The addend is the add's operand at the index the dot does NOT occupy.
  result.addend_fusion_param = fusion_instruction->operand(
      fused_add->operand(1 - dot_operand_idx_in_add)->parameter_number());

  cpu::CpuLayoutAssignment layout_assignment(&computation_layout);
  TF_ASSIGN_OR_RETURN(result.layout_assignment_changed_something,
                      layout_assignment.Run(module));

  return result;
}
    320 
    321 static void AssertCorrectLayoutForDotOutputFusion(
    322     const HloComputation* computation,
    323     const DotOutputFusionLayoutAssignmentResult& layout_assignment_result,
    324     bool expect_col_major_dot_rhs) {
    325   Layout expected_dot_rhs_layout = expect_col_major_dot_rhs
    326                                        ? LayoutUtil::MakeLayout({0, 1})
    327                                        : LayoutUtil::MakeLayout({1, 0});
    328   EXPECT_TRUE(LayoutUtil::Equal(
    329       expected_dot_rhs_layout,
    330       layout_assignment_result.dot_rhs_fusion_param->shape().layout()));
    331 
    332   EXPECT_TRUE(LayoutUtil::Equal(
    333       LayoutUtil::MakeLayout({1, 0}),
    334       layout_assignment_result.dot_lhs_fusion_param->shape().layout()));
    335 
    336   EXPECT_TRUE(LayoutUtil::Equal(
    337       LayoutUtil::MakeLayout({1, 0}),
    338       layout_assignment_result.addend_fusion_param->shape().layout()));
    339   EXPECT_THAT(computation->instructions(), Each(Not(op::Copy())));
    340 }
    341 
    342 TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_1x50x19_dot_idx_0) {
    343   std::unique_ptr<HloModule> module = CreateNewModule();
    344   TF_ASSERT_OK_AND_ASSIGN(
    345       DotOutputFusionLayoutAssignmentResult layout_assignment_result,
    346       RunDotOutputFusion(module.get(), TestName(), /*m=*/1, /*k=*/50, /*n=*/19,
    347                          /*dot_operand_idx_in_add=*/0));
    348   ASSERT_TRUE(layout_assignment_result.layout_assignment_changed_something);
    349   AssertCorrectLayoutForDotOutputFusion(module->entry_computation(),
    350                                         layout_assignment_result,
    351                                         /*expect_col_major_dot_rhs=*/true);
    352 }
    353 
    354 TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_1x50x19_dot_idx_1) {
    355   std::unique_ptr<HloModule> module = CreateNewModule();
    356   TF_ASSERT_OK_AND_ASSIGN(
    357       DotOutputFusionLayoutAssignmentResult layout_assignment_result,
    358       RunDotOutputFusion(module.get(), TestName(), /*m=*/1, /*k=*/50, /*n=*/19,
    359                          /*dot_operand_idx_in_add=*/1));
    360   ASSERT_TRUE(layout_assignment_result.layout_assignment_changed_something);
    361   AssertCorrectLayoutForDotOutputFusion(module->entry_computation(),
    362                                         layout_assignment_result,
    363                                         /*expect_col_major_dot_rhs=*/true);
    364 }
    365 
    366 TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x1_dot_idx_0) {
    367   std::unique_ptr<HloModule> module = CreateNewModule();
    368   TF_ASSERT_OK_AND_ASSIGN(
    369       DotOutputFusionLayoutAssignmentResult layout_assignment_result,
    370       RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/1,
    371                          /*dot_operand_idx_in_add=*/0));
    372   ASSERT_TRUE(layout_assignment_result.layout_assignment_changed_something);
    373   AssertCorrectLayoutForDotOutputFusion(module->entry_computation(),
    374                                         layout_assignment_result,
    375                                         /*expect_col_major_dot_rhs=*/false);
    376 }
    377 
    378 TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x1_dot_idx_1) {
    379   std::unique_ptr<HloModule> module = CreateNewModule();
    380   TF_ASSERT_OK_AND_ASSIGN(
    381       DotOutputFusionLayoutAssignmentResult layout_assignment_result,
    382       RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/1,
    383                          /*dot_operand_idx_in_add=*/1));
    384   ASSERT_TRUE(layout_assignment_result.layout_assignment_changed_something);
    385   AssertCorrectLayoutForDotOutputFusion(module->entry_computation(),
    386                                         layout_assignment_result,
    387                                         /*expect_col_major_dot_rhs=*/false);
    388 }
    389 
    390 TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x19_dot_idx_0) {
    391   std::unique_ptr<HloModule> module = CreateNewModule();
    392   TF_ASSERT_OK_AND_ASSIGN(
    393       DotOutputFusionLayoutAssignmentResult layout_assignment_result,
    394       RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/19,
    395                          /*dot_operand_idx_in_add=*/0));
    396   ASSERT_TRUE(layout_assignment_result.layout_assignment_changed_something);
    397   AssertCorrectLayoutForDotOutputFusion(module->entry_computation(),
    398                                         layout_assignment_result,
    399                                         /*expect_col_major_dot_rhs=*/false);
    400 }
    401 
    402 TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x19_dot_idx_1) {
    403   std::unique_ptr<HloModule> module = CreateNewModule();
    404   TF_ASSERT_OK_AND_ASSIGN(
    405       DotOutputFusionLayoutAssignmentResult layout_assignment_result,
    406       RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/19,
    407                          /*dot_operand_idx_in_add=*/1));
    408   ASSERT_TRUE(layout_assignment_result.layout_assignment_changed_something);
    409   AssertCorrectLayoutForDotOutputFusion(module->entry_computation(),
    410                                         layout_assignment_result,
    411                                         /*expect_col_major_dot_rhs=*/false);
    412 }
    413 }  // namespace
    414 }  // namespace xla
    415