/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h"

#include <initializer_list>
#include <memory>
#include <utility>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"

namespace op = xla::testing::opcode_matchers;

namespace xla {
namespace {

class CpuLayoutAssignmentTest : public HloTestBase {
 protected:
  void AssignLayouts(HloModule* module,
                     ComputationLayout* entry_computation_layout) {
    // Fake target machine: report Eigen's expected tensor alignment for
    // every buffer size, so alignment never constrains the layout choice.
    cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features(
        [](int64 shape_size) {
          return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
        });
    cpu::CpuLayoutAssignment layout_assignment(
        entry_computation_layout, LayoutAssignment::InstructionCanChangeLayout,
        &target_machine_features);
    EXPECT_IS_OK(layout_assignment.Run(module).status());
  }
};

TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensor) {
  auto builder = HloComputation::Builder(TestName());
  Shape lhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 12}, {0, 1});
  Shape rhs_shape = ShapeUtil::MakeShape(F32, {12, 24});
  Shape result_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 24}, {0, 1});
  auto dot_lhs = builder.AddInstruction(
      HloInstruction::CreateParameter(0, lhs_shape, "param0"));
  auto dot_rhs = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape)));
  auto result = builder.AddInstruction(
      CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

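  // Pin the entry parameter and the result to the default (row-major)
  // layout; layout assignment is then free to choose a layout for the
  // constant operand.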
  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(LayoutUtil::GetWithDefaultLayout(lhs_shape));
  *computation_layout.mutable_result_layout() =
      ShapeLayout(LayoutUtil::GetWithDefaultLayout(result_shape));
  AssignLayouts(module.get(), &computation_layout);

  EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({1, 0}),
                                dot_lhs->shape().layout()));
  EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({0, 1}),
                                dot_rhs->shape().layout()));
  EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({1, 0}),
                                result->shape().layout()));
  for (const auto& instruction : computation->instructions()) {
    EXPECT_NE(instruction->opcode(), HloOpcode::kCopy);
  }
}

TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor0) {
  // Two dot products have the same constant as the RHS, and both those dot
  // products can be optimized if the constant has a column-major layout.
  auto builder = HloComputation::Builder(TestName());
  Shape lhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 12}, {0, 1});
  Shape rhs_shape = ShapeUtil::MakeShape(F32, {12, 24});
  Shape result_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 24}, {0, 1});
  auto dot_a_lhs = builder.AddInstruction(
      HloInstruction::CreateParameter(0, lhs_shape, "param0"));
  auto dot_b_lhs = builder.AddInstruction(
      HloInstruction::CreateParameter(1, lhs_shape, "param1"));
  auto dot_rhs = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape)));
  auto dot_a_result = builder.AddInstruction(
      CreateCanonicalDot(result_shape, dot_a_lhs, dot_rhs));
  auto dot_b_result = builder.AddInstruction(
      CreateCanonicalDot(result_shape, dot_b_lhs, dot_rhs));
  builder.AddInstruction(HloInstruction::CreateBinary(
      result_shape, HloOpcode::kAdd, dot_a_result, dot_b_result));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(LayoutUtil::GetWithDefaultLayout(lhs_shape));
  *computation_layout.mutable_result_layout() =
      ShapeLayout(LayoutUtil::GetWithDefaultLayout(result_shape));
  AssignLayouts(module.get(), &computation_layout);

  EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({0, 1}),
                                dot_rhs->shape().layout()));
  for (HloInstruction* instruction :
       {dot_a_lhs, dot_b_lhs, dot_a_result, dot_b_result}) {
    EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({1, 0}),
                                  instruction->shape().layout()));
  }
  for (const auto& instruction : computation->instructions()) {
    EXPECT_NE(instruction->opcode(), HloOpcode::kCopy);
  }
}

TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor1) {
  // Two dot products have the same constant as the RHS, but only one of the
  // two dot products can be optimized if the constant has a column-major
  // layout.
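  // Here dot_a has a 1x12 LHS while dot_b has a 2x12 LHS; as the assertions
  // below check, the shared constant then simply keeps the default row-major
  // layout.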
  auto builder = HloComputation::Builder(TestName());
  Shape lhs_a_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 12}, {0, 1});
  Shape lhs_b_shape = ShapeUtil::MakeShapeWithLayout(F32, {2, 12}, {0, 1});
  Shape rhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {12, 24}, {0, 1});
  Shape result_a_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 24}, {0, 1});
  Shape result_b_shape = ShapeUtil::MakeShapeWithLayout(F32, {2, 24}, {0, 1});
  auto dot_a_lhs = builder.AddInstruction(
      HloInstruction::CreateParameter(0, lhs_a_shape, "param0"));
  auto dot_b_lhs = builder.AddInstruction(
      HloInstruction::CreateParameter(1, lhs_b_shape, "param1"));
  auto dot_rhs = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape)));
  auto dot_a_result = builder.AddInstruction(
      CreateCanonicalDot(result_a_shape, dot_a_lhs, dot_rhs));
  auto dot_b_result = builder.AddInstruction(
      CreateCanonicalDot(result_b_shape, dot_b_lhs, dot_rhs));
  auto tuple_result = builder.AddInstruction(
      HloInstruction::CreateTuple({dot_a_result, dot_b_result}));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(LayoutUtil::GetWithDefaultLayout(lhs_a_shape));
  *computation_layout.mutable_parameter_layout(1) =
      ShapeLayout(LayoutUtil::GetWithDefaultLayout(lhs_b_shape));
  *computation_layout.mutable_result_layout() =
      ShapeLayout(LayoutUtil::GetWithDefaultLayout(tuple_result->shape()));
  AssignLayouts(module.get(), &computation_layout);

  for (HloInstruction* instruction :
       {dot_rhs, dot_a_lhs, dot_b_lhs, dot_a_result, dot_b_result}) {
    EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({1, 0}),
                                  instruction->shape().layout()));
  }
  for (const auto& instruction : computation->instructions()) {
    EXPECT_NE(instruction->opcode(), HloOpcode::kCopy);
  }
}

TEST_F(CpuLayoutAssignmentTest, DotWithConstantLhsTensor) {
  auto builder = HloComputation::Builder(TestName());
  Shape lhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 12}, {0, 1});
  Shape rhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {12, 24}, {0, 1});
  Shape result_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 24}, {0, 1});
  auto dot_lhs = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateFromShape(lhs_shape)));
  auto dot_rhs = builder.AddInstruction(
      HloInstruction::CreateParameter(0, rhs_shape, "param0"));
  auto dot_result = builder.AddInstruction(
      CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(LayoutUtil::GetWithDefaultLayout(rhs_shape));
  *computation_layout.mutable_result_layout() =
      ShapeLayout(LayoutUtil::GetWithDefaultLayout(result_shape));
  AssignLayouts(module.get(), &computation_layout);

  for (HloInstruction* instruction : {dot_lhs, dot_rhs, dot_result}) {
    EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({1, 0}),
                                  instruction->shape().layout()));
  }
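  // Layout assignment should have satisfied every constraint without
  // inserting any copies.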
  for (const auto& instruction : computation->instructions()) {
    EXPECT_NE(instruction->opcode(), HloOpcode::kCopy);
  }
}

TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensorThroughGTE) {
  // This is a case we could theoretically optimize at some point, but today
  // we don't.
  auto builder = HloComputation::Builder(TestName());
  Shape lhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 12}, {0, 1});
  Shape rhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {12, 24}, {0, 1});
  Shape other_shape = ShapeUtil::MakeShapeWithLayout(F32, {100, 24}, {0, 1});

  auto constant_shape = ShapeUtil::MakeTupleShape({other_shape, rhs_shape});
  auto constant = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateFromShape(constant_shape)));

  Shape result_shape = ShapeUtil::MakeShape(F32, {1, 24});

  auto dot_lhs = builder.AddInstruction(
      HloInstruction::CreateParameter(0, lhs_shape, "param0"));
  auto dot_rhs = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(rhs_shape, constant, 1));
  auto dot_result = builder.AddInstruction(
      CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(LayoutUtil::GetWithDefaultLayout(lhs_shape));
  *computation_layout.mutable_result_layout() =
      ShapeLayout(LayoutUtil::GetWithDefaultLayout(result_shape));
  AssignLayouts(module.get(), &computation_layout);

  for (HloInstruction* instruction : {dot_lhs, dot_rhs, dot_result}) {
    EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({1, 0}),
                                  instruction->shape().layout()));
  }
  for (const auto& instruction : computation->instructions()) {
    EXPECT_NE(instruction->opcode(), HloOpcode::kCopy);
  }
}

struct DotOutputFusionLayoutAssignmentResult {
  bool layout_assignment_changed_something;
  const HloInstruction* dot_lhs_fusion_param;
  const HloInstruction* dot_rhs_fusion_param;
  const HloInstruction* addend_fusion_param;
};

// Builds add(dot(param0, constant), param1), swapping the operands of the
// add if dot_operand_idx_in_add is 1; fuses the add and the dot into an
// output fusion; runs layout assignment; and reports the layouts assigned to
// the fusion parameters.
static StatusOr<DotOutputFusionLayoutAssignmentResult> RunDotOutputFusion(
    HloModule* module, const string& test_name, int m, int k, int n,
    const int64 dot_operand_idx_in_add) {
  DotOutputFusionLayoutAssignmentResult result;

  CHECK(dot_operand_idx_in_add == 0 || dot_operand_idx_in_add == 1);

  auto builder = HloComputation::Builder(test_name);

  Shape dot_lhs_shape = ShapeUtil::MakeShape(F32, {m, k});
  Shape dot_rhs_shape = ShapeUtil::MakeShape(F32, {k, n});
  Shape dot_shape = ShapeUtil::MakeShape(F32, {m, n});

  HloInstruction* dot_lhs = builder.AddInstruction(
      HloInstruction::CreateParameter(0, dot_lhs_shape, "param0"));
  HloInstruction* addend = builder.AddInstruction(
      HloInstruction::CreateParameter(1, dot_shape, "param1"));
  HloInstruction* dot_rhs = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateFromShape(dot_rhs_shape)));
  HloInstruction* dot_result =
      builder.AddInstruction(CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs));
  HloInstruction* add_result;
  if (dot_operand_idx_in_add == 0) {
    add_result = builder.AddInstruction(HloInstruction::CreateBinary(
        dot_shape, HloOpcode::kAdd, dot_result, addend));
  } else {
    add_result = builder.AddInstruction(HloInstruction::CreateBinary(
        dot_shape, HloOpcode::kAdd, addend, dot_result));
  }

  HloComputation* computation = module->AddEntryComputation(builder.Build());

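  // Replace the add with an equivalent output fusion and then pull the dot
  // into it, so that add(dot(...), addend) becomes a single fusion
  // instruction.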
  HloInstruction* fusion_instruction =
      module->entry_computation()->AddInstruction(HloInstruction::CreateFusion(
          dot_shape, HloInstruction::FusionKind::kOutput, add_result));
  TF_RETURN_IF_ERROR(
      computation->ReplaceInstruction(add_result, fusion_instruction));

  HloInstruction* fused_add =
      fusion_instruction->fused_instructions_computation()->root_instruction();
  HloInstruction* fused_dot = fusion_instruction->FuseInstruction(dot_result);

  TF_RETURN_IF_ERROR(
      computation->RemoveInstructionAndUnusedOperands(dot_result));

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(LayoutUtil::GetWithDefaultLayout(dot_lhs_shape));
  *computation_layout.mutable_parameter_layout(1) =
      ShapeLayout(LayoutUtil::GetWithDefaultLayout(dot_shape));
  *computation_layout.mutable_result_layout() =
      ShapeLayout(LayoutUtil::GetWithDefaultLayout(dot_shape));

  // The fused dot and add read fusion parameters; map each parameter back to
  // the corresponding operand of the fusion instruction.
  result.dot_lhs_fusion_param =
      fusion_instruction->operand(fused_dot->operand(0)->parameter_number());
  result.dot_rhs_fusion_param =
      fusion_instruction->operand(fused_dot->operand(1)->parameter_number());
  result.addend_fusion_param = fusion_instruction->operand(
      fused_add->operand(1 - dot_operand_idx_in_add)->parameter_number());

  cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features(
      [](int64 shape_size) {
        return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
      });
  cpu::CpuLayoutAssignment layout_assignment(
      &computation_layout, LayoutAssignment::InstructionCanChangeLayout,
      &target_machine_features);
  TF_ASSIGN_OR_RETURN(result.layout_assignment_changed_something,
                      layout_assignment.Run(module));

  return result;
}

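// Asserts that the dot LHS and the addend fusion parameters were assigned
// row-major layouts, that the dot RHS parameter is column-major exactly when
// expect_col_major_dot_rhs is set, and that no copies were inserted.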
static void AssertCorrectLayoutForDotOutputFusion(
    const HloComputation* computation,
    const DotOutputFusionLayoutAssignmentResult& layout_assignment_result,
    bool expect_col_major_dot_rhs) {
  Layout expected_dot_rhs_layout = expect_col_major_dot_rhs
                                       ? LayoutUtil::MakeLayout({0, 1})
                                       : LayoutUtil::MakeLayout({1, 0});
  EXPECT_TRUE(LayoutUtil::Equal(
      expected_dot_rhs_layout,
      layout_assignment_result.dot_rhs_fusion_param->shape().layout()));

  EXPECT_TRUE(LayoutUtil::Equal(
      LayoutUtil::MakeLayout({1, 0}),
      layout_assignment_result.dot_lhs_fusion_param->shape().layout()));

  EXPECT_TRUE(LayoutUtil::Equal(
      LayoutUtil::MakeLayout({1, 0}),
      layout_assignment_result.addend_fusion_param->shape().layout()));
  EXPECT_THAT(computation->instructions(), Each(Not(op::Copy())));
}

TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_1x50x19_dot_idx_0) {
  std::unique_ptr<HloModule> module = CreateNewVerifiedModule();
  TF_ASSERT_OK_AND_ASSIGN(
      DotOutputFusionLayoutAssignmentResult layout_assignment_result,
      RunDotOutputFusion(module.get(), TestName(), /*m=*/1, /*k=*/50, /*n=*/19,
                         /*dot_operand_idx_in_add=*/0));
  ASSERT_TRUE(layout_assignment_result.layout_assignment_changed_something);
  AssertCorrectLayoutForDotOutputFusion(module->entry_computation(),
                                        layout_assignment_result,
                                        /*expect_col_major_dot_rhs=*/true);
}

TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_1x50x19_dot_idx_1) {
  std::unique_ptr<HloModule> module = CreateNewVerifiedModule();
  TF_ASSERT_OK_AND_ASSIGN(
      DotOutputFusionLayoutAssignmentResult layout_assignment_result,
      RunDotOutputFusion(module.get(), TestName(), /*m=*/1, /*k=*/50, /*n=*/19,
                         /*dot_operand_idx_in_add=*/1));
  ASSERT_TRUE(layout_assignment_result.layout_assignment_changed_something);
  AssertCorrectLayoutForDotOutputFusion(module->entry_computation(),
                                        layout_assignment_result,
                                        /*expect_col_major_dot_rhs=*/true);
}

TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x1_dot_idx_0) {
  std::unique_ptr<HloModule> module = CreateNewVerifiedModule();
  TF_ASSERT_OK_AND_ASSIGN(
      DotOutputFusionLayoutAssignmentResult layout_assignment_result,
      RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/1,
                         /*dot_operand_idx_in_add=*/0));
  ASSERT_TRUE(layout_assignment_result.layout_assignment_changed_something);
  AssertCorrectLayoutForDotOutputFusion(module->entry_computation(),
                                        layout_assignment_result,
                                        /*expect_col_major_dot_rhs=*/false);
}

TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x1_dot_idx_1) {
  std::unique_ptr<HloModule> module = CreateNewVerifiedModule();
  TF_ASSERT_OK_AND_ASSIGN(
      DotOutputFusionLayoutAssignmentResult layout_assignment_result,
      RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/1,
                         /*dot_operand_idx_in_add=*/1));
  ASSERT_TRUE(layout_assignment_result.layout_assignment_changed_something);
  AssertCorrectLayoutForDotOutputFusion(module->entry_computation(),
                                        layout_assignment_result,
                                        /*expect_col_major_dot_rhs=*/false);
}

TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x19_dot_idx_0) {
  std::unique_ptr<HloModule> module = CreateNewVerifiedModule();
  TF_ASSERT_OK_AND_ASSIGN(
      DotOutputFusionLayoutAssignmentResult layout_assignment_result,
      RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50,
                         /*n=*/19,
                         /*dot_operand_idx_in_add=*/0));
  ASSERT_TRUE(layout_assignment_result.layout_assignment_changed_something);
  AssertCorrectLayoutForDotOutputFusion(module->entry_computation(),
                                        layout_assignment_result,
                                        /*expect_col_major_dot_rhs=*/false);
}

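// As in the other m == 19 cases above, the dot is a full matrix-matrix
// product, so a column-major RHS is not expected to be profitable.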
TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x19_dot_idx_1) {
  std::unique_ptr<HloModule> module = CreateNewVerifiedModule();
  TF_ASSERT_OK_AND_ASSIGN(
      DotOutputFusionLayoutAssignmentResult layout_assignment_result,
      RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50,
                         /*n=*/19,
                         /*dot_operand_idx_in_add=*/1));
  ASSERT_TRUE(layout_assignment_result.layout_assignment_changed_something);
  AssertCorrectLayoutForDotOutputFusion(module->entry_computation(),
                                        layout_assignment_result,
                                        /*expect_col_major_dot_rhs=*/false);
}

}  // namespace
}  // namespace xla