/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h"

#include <initializer_list>
#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"

namespace op = xla::testing::opcode_matchers;

namespace xla {
namespace {

class CpuLayoutAssignmentTest : public HloTestBase {
 protected:
  void AssignLayouts(HloModule* module,
                     ComputationLayout* entry_computation_layout) {
    cpu::CpuLayoutAssignment layout_assignment(entry_computation_layout);
    EXPECT_IS_OK(layout_assignment.Run(module).status());
  }
};

TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensor) {
  auto builder = HloComputation::Builder(TestName());
  Shape lhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 12}, {0, 1});
  Shape rhs_shape = ShapeUtil::MakeShape(F32, {12, 24});
  Shape result_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 24}, {0, 1});
  auto dot_lhs = builder.AddInstruction(
      HloInstruction::CreateParameter(0, lhs_shape, "param0"));
  auto dot_rhs = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape)));
  auto result = builder.AddInstruction(
      HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));

  auto module = CreateNewModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(LayoutUtil::GetWithDefaultLayout(lhs_shape));
  *computation_layout.mutable_result_layout() =
      ShapeLayout(LayoutUtil::GetWithDefaultLayout(result_shape));
  AssignLayouts(module.get(), &computation_layout);

  EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({1, 0}),
                                dot_lhs->shape().layout()));
  EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({0, 1}),
                                dot_rhs->shape().layout()));
  EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({1, 0}),
                                result->shape().layout()));
  for (const auto& instruction : computation->instructions()) {
    EXPECT_NE(instruction->opcode(), HloOpcode::kCopy);
  }
}

TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor0) {
  // Two dot products have the same constant as the RHS, and both those dot
  // products can be optimized if the constant has a column-major layout.
  auto builder = HloComputation::Builder(TestName());
  Shape lhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 12}, {0, 1});
  Shape rhs_shape = ShapeUtil::MakeShape(F32, {12, 24});
  Shape result_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 24}, {0, 1});
  auto dot_a_lhs = builder.AddInstruction(
      HloInstruction::CreateParameter(0, lhs_shape, "param0"));
  auto dot_b_lhs = builder.AddInstruction(
      HloInstruction::CreateParameter(1, lhs_shape, "param1"));
  auto dot_rhs = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape)));
  auto dot_a_result = builder.AddInstruction(
      HloInstruction::CreateCanonicalDot(result_shape, dot_a_lhs, dot_rhs));
  auto dot_b_result = builder.AddInstruction(
      HloInstruction::CreateCanonicalDot(result_shape, dot_b_lhs, dot_rhs));
  builder.AddInstruction(HloInstruction::CreateBinary(
      result_shape, HloOpcode::kAdd, dot_a_result, dot_b_result));

  auto module = CreateNewModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(LayoutUtil::GetWithDefaultLayout(lhs_shape));
  *computation_layout.mutable_result_layout() =
      ShapeLayout(LayoutUtil::GetWithDefaultLayout(result_shape));
  AssignLayouts(module.get(), &computation_layout);

  EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({0, 1}),
                                dot_rhs->shape().layout()));
  for (HloInstruction* instruction :
       {dot_a_lhs, dot_b_lhs, dot_a_result, dot_b_result}) {
    EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({1, 0}),
                                  instruction->shape().layout()));
  }
  for (const auto& instruction : computation->instructions()) {
    EXPECT_NE(instruction->opcode(), HloOpcode::kCopy);
  }
}

TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor1) {
  // Two dot products have the same constant as the RHS, but only one of the
  // two dot products can be optimized if the constant has a column-major
  // layout.
  auto builder = HloComputation::Builder(TestName());
  Shape lhs_a_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 12}, {0, 1});
  Shape lhs_b_shape = ShapeUtil::MakeShapeWithLayout(F32, {2, 12}, {0, 1});
  Shape rhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {12, 24}, {0, 1});
  Shape result_a_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 24}, {0, 1});
  Shape result_b_shape = ShapeUtil::MakeShapeWithLayout(F32, {2, 24}, {0, 1});
  auto dot_a_lhs = builder.AddInstruction(
      HloInstruction::CreateParameter(0, lhs_a_shape, "param0"));
  auto dot_b_lhs = builder.AddInstruction(
      HloInstruction::CreateParameter(1, lhs_b_shape, "param1"));
  auto dot_rhs = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape)));
  auto dot_a_result = builder.AddInstruction(
      HloInstruction::CreateCanonicalDot(result_a_shape, dot_a_lhs, dot_rhs));
  auto dot_b_result = builder.AddInstruction(
      HloInstruction::CreateCanonicalDot(result_b_shape, dot_b_lhs, dot_rhs));
  auto tuple_result = builder.AddInstruction(
      HloInstruction::CreateTuple({dot_a_result, dot_b_result}));

  auto module = CreateNewModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(LayoutUtil::GetWithDefaultLayout(lhs_a_shape));
  *computation_layout.mutable_parameter_layout(1) =
      ShapeLayout(LayoutUtil::GetWithDefaultLayout(lhs_b_shape));
  *computation_layout.mutable_result_layout() =
      ShapeLayout(LayoutUtil::GetWithDefaultLayout(tuple_result->shape()));
  AssignLayouts(module.get(), &computation_layout);

  for (HloInstruction* instruction :
       {dot_rhs, dot_a_lhs, dot_b_lhs, dot_a_result, dot_b_result}) {
    EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({1, 0}),
                                  instruction->shape().layout()));
  }
  for (const auto& instruction : computation->instructions()) {
    EXPECT_NE(instruction->opcode(), HloOpcode::kCopy);
  }
}

TEST_F(CpuLayoutAssignmentTest, DotWithConstantLhsTensor) {
  auto builder = HloComputation::Builder(TestName());
  Shape lhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 12}, {0, 1});
  Shape rhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {12, 24}, {0, 1});
  Shape result_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 24}, {0, 1});
  auto dot_lhs = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateFromShape(lhs_shape)));
  auto dot_rhs = builder.AddInstruction(
      HloInstruction::CreateParameter(0, rhs_shape, "param0"));
  auto dot_result = builder.AddInstruction(
      HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));

  auto module = CreateNewModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(LayoutUtil::GetWithDefaultLayout(rhs_shape));
  *computation_layout.mutable_result_layout() =
      ShapeLayout(LayoutUtil::GetWithDefaultLayout(result_shape));
  AssignLayouts(module.get(), &computation_layout);

  for (HloInstruction* instruction : {dot_lhs, dot_rhs, dot_result}) {
    EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({1, 0}),
                                  instruction->shape().layout()));
  }
  for (const auto& instruction : computation->instructions()) {
    EXPECT_NE(instruction->opcode(), HloOpcode::kCopy);
  }
}

TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensorThroughGTE) {
  // This is a case we could theoretically optimize at some point, but today we
  // don't.
  auto builder = HloComputation::Builder(TestName());
  Shape lhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 12}, {0, 1});
  Shape rhs_shape = ShapeUtil::MakeShapeWithLayout(F32, {12, 24}, {0, 1});
  Shape other_shape = ShapeUtil::MakeShapeWithLayout(F32, {100, 24}, {0, 1});

  auto constant_shape = ShapeUtil::MakeTupleShape({other_shape, rhs_shape});
  auto constant = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateFromShape(constant_shape)));

  Shape result_shape = ShapeUtil::MakeShape(F32, {1, 24});

  auto dot_lhs = builder.AddInstruction(
      HloInstruction::CreateParameter(0, lhs_shape, "param0"));
  auto dot_rhs = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(rhs_shape, constant, 1));
  auto dot_result = builder.AddInstruction(
      HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));

  auto module = CreateNewModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(LayoutUtil::GetWithDefaultLayout(lhs_shape));
  *computation_layout.mutable_result_layout() =
      ShapeLayout(LayoutUtil::GetWithDefaultLayout(result_shape));
  AssignLayouts(module.get(), &computation_layout);

  for (HloInstruction* instruction : {dot_lhs, dot_rhs, dot_result}) {
    EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({1, 0}),
                                  instruction->shape().layout()));
  }
  for (const auto& instruction : computation->instructions()) {
    EXPECT_NE(instruction->opcode(), HloOpcode::kCopy);
  }
}

struct DotOutputFusionLayoutAssignmentResult {
  bool layout_assignment_changed_something;
  const HloInstruction* dot_lhs_fusion_param;
  const HloInstruction* dot_rhs_fusion_param;
  const HloInstruction* addend_fusion_param;
};

// Builds add(dot(param0, constant), param1) (with the operands of the add
// ordered according to `dot_operand_idx_in_add`), rewrites the add and dot
// into a single output fusion, and then runs CPU layout assignment on the
// resulting module.
static StatusOr<DotOutputFusionLayoutAssignmentResult> RunDotOutputFusion(
    HloModule* module, const string& test_name, int m, int k, int n,
    const int64 dot_operand_idx_in_add) {
  DotOutputFusionLayoutAssignmentResult result;

  CHECK(dot_operand_idx_in_add == 0 || dot_operand_idx_in_add == 1);

  auto builder = HloComputation::Builder(test_name);

  Shape dot_lhs_shape = ShapeUtil::MakeShape(F32, {m, k});
  Shape dot_rhs_shape = ShapeUtil::MakeShape(F32, {k, n});
  Shape dot_shape = ShapeUtil::MakeShape(F32, {m, n});

  HloInstruction* dot_lhs = builder.AddInstruction(
      HloInstruction::CreateParameter(0, dot_lhs_shape, "param0"));
  HloInstruction* addend = builder.AddInstruction(
      HloInstruction::CreateParameter(1, dot_shape, "param1"));
  HloInstruction* dot_rhs = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateFromShape(dot_rhs_shape)));
  HloInstruction* dot_result = builder.AddInstruction(
      HloInstruction::CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs));
  HloInstruction* add_result;
  if (dot_operand_idx_in_add == 0) {
    add_result = builder.AddInstruction(HloInstruction::CreateBinary(
        dot_shape, HloOpcode::kAdd, dot_result, addend));
  } else {
    add_result = builder.AddInstruction(HloInstruction::CreateBinary(
        dot_shape,
        HloOpcode::kAdd, addend, dot_result));
  }

  HloComputation* computation = module->AddEntryComputation(builder.Build());

  // Create an output fusion rooted at the add, replace the add with it, and
  // then fuse the dot into it.
  HloInstruction* fusion_instruction =
      module->entry_computation()->AddInstruction(HloInstruction::CreateFusion(
          dot_shape, HloInstruction::FusionKind::kOutput, add_result));
  TF_RETURN_IF_ERROR(
      computation->ReplaceInstruction(add_result, fusion_instruction));

  HloInstruction* fused_add =
      fusion_instruction->fused_instructions_computation()->root_instruction();
  HloInstruction* fused_dot = fusion_instruction->FuseInstruction(dot_result);

  TF_RETURN_IF_ERROR(
      computation->RemoveInstructionAndUnusedOperands(dot_result));

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(LayoutUtil::GetWithDefaultLayout(dot_lhs_shape));
  *computation_layout.mutable_parameter_layout(1) =
      ShapeLayout(LayoutUtil::GetWithDefaultLayout(dot_shape));
  *computation_layout.mutable_result_layout() =
      ShapeLayout(LayoutUtil::GetWithDefaultLayout(dot_shape));

  // Record which fusion operands feed the dot LHS, dot RHS, and the addend so
  // the caller can check their layouts after layout assignment has run.
  result.dot_lhs_fusion_param =
      fusion_instruction->operand(fused_dot->operand(0)->parameter_number());
  result.dot_rhs_fusion_param =
      fusion_instruction->operand(fused_dot->operand(1)->parameter_number());
  result.addend_fusion_param = fusion_instruction->operand(
      fused_add->operand(1 - dot_operand_idx_in_add)->parameter_number());

  cpu::CpuLayoutAssignment layout_assignment(&computation_layout);
  TF_ASSIGN_OR_RETURN(result.layout_assignment_changed_something,
                      layout_assignment.Run(module));

  return result;
}

static void AssertCorrectLayoutForDotOutputFusion(
    const HloComputation* computation,
    const DotOutputFusionLayoutAssignmentResult& layout_assignment_result,
    bool expect_col_major_dot_rhs) {
  Layout expected_dot_rhs_layout = expect_col_major_dot_rhs
                                       ?
                                       LayoutUtil::MakeLayout({0, 1})
                                       : LayoutUtil::MakeLayout({1, 0});
  EXPECT_TRUE(LayoutUtil::Equal(
      expected_dot_rhs_layout,
      layout_assignment_result.dot_rhs_fusion_param->shape().layout()));

  EXPECT_TRUE(LayoutUtil::Equal(
      LayoutUtil::MakeLayout({1, 0}),
      layout_assignment_result.dot_lhs_fusion_param->shape().layout()));

  EXPECT_TRUE(LayoutUtil::Equal(
      LayoutUtil::MakeLayout({1, 0}),
      layout_assignment_result.addend_fusion_param->shape().layout()));
  EXPECT_THAT(computation->instructions(), Each(Not(op::Copy())));
}

TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_1x50x19_dot_idx_0) {
  std::unique_ptr<HloModule> module = CreateNewModule();
  TF_ASSERT_OK_AND_ASSIGN(
      DotOutputFusionLayoutAssignmentResult layout_assignment_result,
      RunDotOutputFusion(module.get(), TestName(), /*m=*/1, /*k=*/50, /*n=*/19,
                         /*dot_operand_idx_in_add=*/0));
  ASSERT_TRUE(layout_assignment_result.layout_assignment_changed_something);
  AssertCorrectLayoutForDotOutputFusion(module->entry_computation(),
                                        layout_assignment_result,
                                        /*expect_col_major_dot_rhs=*/true);
}

TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_1x50x19_dot_idx_1) {
  std::unique_ptr<HloModule> module = CreateNewModule();
  TF_ASSERT_OK_AND_ASSIGN(
      DotOutputFusionLayoutAssignmentResult layout_assignment_result,
      RunDotOutputFusion(module.get(), TestName(), /*m=*/1, /*k=*/50, /*n=*/19,
                         /*dot_operand_idx_in_add=*/1));
  ASSERT_TRUE(layout_assignment_result.layout_assignment_changed_something);
  AssertCorrectLayoutForDotOutputFusion(module->entry_computation(),
                                        layout_assignment_result,
                                        /*expect_col_major_dot_rhs=*/true);
}

TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x1_dot_idx_0) {
  std::unique_ptr<HloModule> module = CreateNewModule();
  TF_ASSERT_OK_AND_ASSIGN(
      DotOutputFusionLayoutAssignmentResult layout_assignment_result,
      RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/1,
                         /*dot_operand_idx_in_add=*/0));
  ASSERT_TRUE(layout_assignment_result.layout_assignment_changed_something);
  AssertCorrectLayoutForDotOutputFusion(module->entry_computation(),
                                        layout_assignment_result,
                                        /*expect_col_major_dot_rhs=*/false);
}

TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x1_dot_idx_1) {
  std::unique_ptr<HloModule> module = CreateNewModule();
  TF_ASSERT_OK_AND_ASSIGN(
      DotOutputFusionLayoutAssignmentResult layout_assignment_result,
      RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/1,
                         /*dot_operand_idx_in_add=*/1));
  ASSERT_TRUE(layout_assignment_result.layout_assignment_changed_something);
  AssertCorrectLayoutForDotOutputFusion(module->entry_computation(),
                                        layout_assignment_result,
                                        /*expect_col_major_dot_rhs=*/false);
}

TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x19_dot_idx_0) {
  std::unique_ptr<HloModule> module = CreateNewModule();
  TF_ASSERT_OK_AND_ASSIGN(
      DotOutputFusionLayoutAssignmentResult layout_assignment_result,
      RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/19,
                         /*dot_operand_idx_in_add=*/0));
  ASSERT_TRUE(layout_assignment_result.layout_assignment_changed_something);
  AssertCorrectLayoutForDotOutputFusion(module->entry_computation(),
                                        layout_assignment_result,
                                        /*expect_col_major_dot_rhs=*/false);
}

TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x19_dot_idx_1) {
  std::unique_ptr<HloModule>
      module = CreateNewModule();
  TF_ASSERT_OK_AND_ASSIGN(
      DotOutputFusionLayoutAssignmentResult layout_assignment_result,
      RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/19,
                         /*dot_operand_idx_in_add=*/1));
  ASSERT_TRUE(layout_assignment_result.layout_assignment_changed_something);
  AssertCorrectLayoutForDotOutputFusion(module->entry_computation(),
                                        layout_assignment_result,
                                        /*expect_col_major_dot_rhs=*/false);
}

}  // namespace
}  // namespace xla