1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/layout_assignment.h" 17 18 #include <initializer_list> 19 #include <memory> 20 #include <utility> 21 #include <vector> 22 23 #include "tensorflow/compiler/xla/layout_util.h" 24 #include "tensorflow/compiler/xla/literal_util.h" 25 #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" 26 #include "tensorflow/compiler/xla/service/computation_layout.h" 27 #include "tensorflow/compiler/xla/service/hlo_computation.h" 28 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 29 #include "tensorflow/compiler/xla/service/hlo_matchers.h" 30 #include "tensorflow/compiler/xla/service/hlo_module.h" 31 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 32 #include "tensorflow/compiler/xla/shape_layout.h" 33 #include "tensorflow/compiler/xla/shape_util.h" 34 #include "tensorflow/compiler/xla/test.h" 35 #include "tensorflow/compiler/xla/test_helpers.h" 36 #include "tensorflow/compiler/xla/tests/hlo_test_base.h" 37 #include "tensorflow/compiler/xla/tests/test_utils.h" 38 #include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" 39 #include "tensorflow/compiler/xla/util.h" 40 #include "tensorflow/compiler/xla/xla_data.pb.h" 41 #include "tensorflow/core/lib/core/status.h" 42 #include "tensorflow/core/lib/core/status_test_util.h" 43 #include "tensorflow/core/lib/gtl/array_slice.h" 44 45 namespace op = xla::testing::opcode_matchers; 46 47 namespace xla { 48 namespace { 49 50 using ::testing::ElementsAre; 51 52 class LayoutAssignmentTest : public HloTestBase { 53 protected: 54 void AssignLayouts(HloModule* module, 55 ComputationLayout* entry_computation_layout) { 56 LayoutAssignment layout_assignment(entry_computation_layout); 57 EXPECT_IS_OK(layout_assignment.Run(module).status()); 58 } 59 }; 60 61 TEST_F(LayoutAssignmentTest, ComputationLayout) { 62 // Verify the layouts of the root and parameter instructions of a computation 63 // match the ComputationLayout for two different layouts. 64 std::vector<std::initializer_list<int64>> minor_to_majors = {{0, 1}, {1, 0}}; 65 for (auto& minor_to_major : minor_to_majors) { 66 auto builder = HloComputation::Builder(TestName()); 67 Shape ashape = ShapeUtil::MakeShape(F32, {42, 12}); 68 auto param0 = builder.AddInstruction( 69 HloInstruction::CreateParameter(0, ashape, "param0")); 70 auto param1 = builder.AddInstruction( 71 HloInstruction::CreateParameter(1, ashape, "param1")); 72 auto add = builder.AddInstruction( 73 HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, param0, param1)); 74 auto module = CreateNewModule(); 75 HloComputation* computation = module->AddEntryComputation(builder.Build()); 76 77 Layout layout = LayoutUtil::MakeLayout(minor_to_major); 78 Shape shape(ashape); 79 *shape.mutable_layout() = layout; 80 const ShapeLayout shape_layout(shape); 81 82 ComputationLayout computation_layout(computation->ComputeProgramShape()); 83 *computation_layout.mutable_parameter_layout(0) = shape_layout; 84 *computation_layout.mutable_parameter_layout(1) = shape_layout; 85 *computation_layout.mutable_result_layout() = shape_layout; 86 AssignLayouts(module.get(), &computation_layout); 87 EXPECT_TRUE(LayoutUtil::Equal(layout, param0->shape().layout())); 88 EXPECT_TRUE(LayoutUtil::Equal(layout, param1->shape().layout())); 89 EXPECT_TRUE(LayoutUtil::Equal(layout, add->shape().layout())); 90 } 91 } 92 93 TEST_F(LayoutAssignmentTest, ComputationLayoutMixedLayout) { 94 // Verify the layouts of the root and parameter instructions of a computation 95 // match the ComputationLayout which has mixed layout. 96 auto builder = HloComputation::Builder(TestName()); 97 Shape ashape = ShapeUtil::MakeShape(F32, {42, 12}); 98 auto param0 = builder.AddInstruction( 99 HloInstruction::CreateParameter(0, ashape, "param0")); 100 auto param1 = builder.AddInstruction( 101 HloInstruction::CreateParameter(1, ashape, "param1")); 102 builder.AddInstruction( 103 HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, param0, param1)); 104 auto module = CreateNewModule(); 105 HloComputation* computation = module->AddEntryComputation(builder.Build()); 106 107 Layout col_major_layout = LayoutUtil::MakeLayout({1, 0}); 108 Shape col_major_shape(ashape); 109 *col_major_shape.mutable_layout() = col_major_layout; 110 const ShapeLayout col_major(col_major_shape); 111 112 Layout row_major_layout = LayoutUtil::MakeLayout({0, 1}); 113 Shape row_major_shape(ashape); 114 *row_major_shape.mutable_layout() = row_major_layout; 115 const ShapeLayout row_major(row_major_shape); 116 117 ComputationLayout computation_layout(computation->ComputeProgramShape()); 118 *computation_layout.mutable_parameter_layout(0) = col_major; 119 *computation_layout.mutable_parameter_layout(1) = row_major; 120 *computation_layout.mutable_result_layout() = col_major; 121 122 AssignLayouts(module.get(), &computation_layout); 123 EXPECT_TRUE(LayoutUtil::Equal(col_major_layout, param0->shape().layout())); 124 EXPECT_TRUE(LayoutUtil::Equal(row_major_layout, param1->shape().layout())); 125 EXPECT_TRUE(LayoutUtil::Equal( 126 col_major_layout, computation->root_instruction()->shape().layout())); 127 } 128 129 TEST_F(LayoutAssignmentTest, FusionInstruction) { 130 // Verify that the layout of the fused parameters in a fusion instruction 131 // match that of the fusion operands. Other fused instructions should have no 132 // layout. 133 std::vector<std::initializer_list<int64>> minor_to_majors = {{0, 1}, {1, 0}}; 134 for (auto& minor_to_major : minor_to_majors) { 135 auto builder = HloComputation::Builder(TestName()); 136 auto constant_literal1 = Literal::CreateR2WithLayout<float>( 137 {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout(minor_to_major)); 138 auto constant_literal2 = Literal::CreateR2WithLayout<float>( 139 {{5.0, 6.0}, {7.0, 8.0}}, LayoutUtil::MakeLayout(minor_to_major)); 140 Shape ashape = constant_literal1->shape(); 141 142 auto constant1 = builder.AddInstruction( 143 HloInstruction::CreateConstant(std::move(constant_literal1))); 144 auto constant2 = builder.AddInstruction( 145 HloInstruction::CreateConstant(std::move(constant_literal2))); 146 auto add = builder.AddInstruction(HloInstruction::CreateBinary( 147 ashape, HloOpcode::kAdd, constant1, constant2)); 148 auto negate1 = builder.AddInstruction( 149 HloInstruction::CreateUnary(ashape, HloOpcode::kNegate, add)); 150 auto negate2 = builder.AddInstruction( 151 HloInstruction::CreateUnary(ashape, HloOpcode::kNegate, negate1)); 152 153 auto module = CreateNewModule(); 154 HloComputation* computation = module->AddEntryComputation(builder.Build()); 155 156 auto fusion = computation->CreateFusionInstruction( 157 {negate2, negate1, add}, HloInstruction::FusionKind::kLoop); 158 159 Layout layout = LayoutUtil::MakeLayout(minor_to_major); 160 Shape shape(ashape); 161 *shape.mutable_layout() = layout; 162 const ShapeLayout shape_layout(shape); 163 164 ComputationLayout computation_layout(computation->ComputeProgramShape()); 165 *computation_layout.mutable_result_layout() = shape_layout; 166 167 AssignLayouts(module.get(), &computation_layout); 168 169 EXPECT_TRUE(LayoutUtil::Equal( 170 layout, fusion->fused_parameter(0)->shape().layout())); 171 EXPECT_TRUE(LayoutUtil::Equal( 172 layout, fusion->fused_parameter(1)->shape().layout())); 173 EXPECT_TRUE(LayoutUtil::Equal( 174 layout, fusion->fused_expression_root()->shape().layout())); 175 176 // Inner fused node should not have layout. 177 EXPECT_FALSE(LayoutUtil::HasLayout( 178 fusion->fused_expression_root()->operand(0)->shape())); 179 } 180 } 181 182 TEST_F(LayoutAssignmentTest, TupleLayout) { 183 // Verify the layouts of a tuple are assigned properly (the element layouts 184 // match their source). 185 auto builder = HloComputation::Builder(TestName()); 186 auto constant0 = builder.AddInstruction( 187 HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>( 188 {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); 189 auto constant1 = builder.AddInstruction( 190 HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>( 191 {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0})))); 192 auto tuple = builder.AddInstruction( 193 HloInstruction::CreateTuple({constant0, constant1})); 194 195 // To avoid having to construct a tuple layout in the ComputationLayout below, 196 // make the result of the instruction be an array. 197 auto get_element0 = builder.AddInstruction( 198 HloInstruction::CreateGetTupleElement(constant0->shape(), tuple, 0)); 199 auto negate = builder.AddInstruction(HloInstruction::CreateUnary( 200 constant0->shape(), HloOpcode::kNegate, get_element0)); 201 202 auto module = CreateNewModule(); 203 module->AddEntryComputation(builder.Build()); 204 205 ComputationLayout computation_layout( 206 module->entry_computation()->ComputeProgramShape()); 207 208 AssignLayouts(module.get(), &computation_layout); 209 210 EXPECT_TRUE( 211 LayoutUtil::LayoutsInShapesEqual(constant0->shape(), constant1->shape())); 212 213 EXPECT_TRUE(LayoutUtil::HasLayout(tuple->shape())); 214 EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual( 215 negate->shape(), computation_layout.result_layout().shape())); 216 EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual( 217 ShapeUtil::GetTupleElementShape(tuple->shape(), 1), constant1->shape())); 218 } 219 220 TEST_F(LayoutAssignmentTest, TupleSelect) { 221 // Verify layouts of a select with tuple operands is assigned properly. 222 auto builder = HloComputation::Builder(TestName()); 223 auto constant0 = builder.AddInstruction( 224 HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>( 225 {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); 226 auto constant1 = builder.AddInstruction( 227 HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>( 228 {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0})))); 229 auto tuple0 = builder.AddInstruction( 230 HloInstruction::CreateTuple({constant0, constant1})); 231 auto tuple1 = builder.AddInstruction( 232 HloInstruction::CreateTuple({constant0, constant1})); 233 234 auto pred = builder.AddInstruction( 235 HloInstruction::CreateConstant(Literal::CreateR0<bool>(true))); 236 237 auto select = builder.AddInstruction(HloInstruction::CreateTernary( 238 tuple0->shape(), HloOpcode::kSelect, pred, tuple0, tuple1)); 239 240 auto module = CreateNewModule(); 241 module->AddEntryComputation(builder.Build()); 242 243 ComputationLayout computation_layout( 244 module->entry_computation()->ComputeProgramShape()); 245 Shape result_shape = 246 ShapeUtil::MakeTupleShape({constant0->shape(), constant1->shape()}); 247 TF_CHECK_OK(computation_layout.mutable_result_layout()->CopyLayoutFromShape( 248 result_shape)); 249 250 AssignLayouts(module.get(), &computation_layout); 251 252 EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(result_shape, select->shape())); 253 } 254 255 TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) { 256 // Construct following computation which has conflicting layouts for two 257 // elements of a tuple which share the same source logicalb buffer: 258 // 259 // %constant = Constant(...) 260 // %inner_tuple = Tuple(%constant) 261 // %nested_tuple = Tuple(%inner_tuple, %inner_tuple) 262 // 263 // Result layout col-major for the first element and row-major for the 264 // second. This results in the conflict where the element of the inner_tuple 265 // needs to be both col and row major. This is resolved by deep-copying the 266 // tuple and assigning the layouts of the copied arrays as needed. 267 auto builder = HloComputation::Builder(TestName()); 268 auto constant = builder.AddInstruction(HloInstruction::CreateConstant( 269 Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}))); 270 auto inner_tuple = 271 builder.AddInstruction(HloInstruction::CreateTuple({constant})); 272 auto nested_tuple = builder.AddInstruction( 273 HloInstruction::CreateTuple({inner_tuple, inner_tuple})); 274 275 auto module = CreateNewModule(); 276 module->AddEntryComputation(builder.Build()); 277 278 ComputationLayout computation_layout( 279 module->entry_computation()->ComputeProgramShape()); 280 Shape result_shape = nested_tuple->shape(); 281 *ShapeUtil::GetMutableSubshape(&result_shape, /*index=*/{0, 0}) = 282 ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}); 283 *ShapeUtil::GetMutableSubshape(&result_shape, /*index=*/{1, 0}) = 284 ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1}); 285 TF_CHECK_OK(computation_layout.mutable_result_layout()->CopyLayoutFromShape( 286 result_shape)); 287 288 LayoutAssignment layout_assignment(&computation_layout); 289 AssignLayouts(module.get(), &computation_layout); 290 291 // Layout assignment should have deep copied the result of the computation to 292 // address the layout conflict. This results in several Tuple() and 293 // GetTupleElement() instructions. Running algebraic simplification should 294 // clean up the code to something like: 295 // 296 // %constant = Constant(...) layout={1,0} 297 // %tuple.0 = Tuple(%constant) layout=({1,0}) 298 // %copy = Copy(%constant) layout={0,1} # layout transposed 299 // %tuple.1 = Tuple(%copy) layout=({0,1}) 300 // %tuple.2 = Tuple(%tuple.0, %tuple.1) layout=(({1,0}), ({0,1})) 301 // 302 EXPECT_TRUE( 303 AlgebraicSimplifier(/*is_layout_sensitive=*/true, 304 [](const Shape&, const Shape&) { return false; }) 305 .Run(module.get()) 306 .ValueOrDie()); 307 HloInstruction* root = module->entry_computation()->root_instruction(); 308 // Verify layout of the root and the root's operands. 309 EXPECT_TRUE(ShapeUtil::Equal(result_shape, root->shape())); 310 EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::GetSubshape(result_shape, {0}), 311 root->operand(0)->shape())); 312 EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::GetSubshape(result_shape, {1}), 313 root->operand(1)->shape())); 314 315 // Verify the structure of the HLO graph. 316 EXPECT_THAT(root, 317 op::Tuple(op::Tuple(constant), op::Tuple(op::Copy(constant)))); 318 } 319 320 TEST_F(LayoutAssignmentTest, ElementwiseAndReshape) { 321 // param -> log -> reshape -> tanh 322 auto builder = HloComputation::Builder(TestName()); 323 Shape ashape = ShapeUtil::MakeShape(F32, {1, 2, 3, 1}); 324 Shape bshape = ShapeUtil::MakeShape(F32, {3, 1, 2}); 325 auto param = builder.AddInstruction( 326 HloInstruction::CreateParameter(0, ashape, "param")); 327 auto log = builder.AddInstruction( 328 HloInstruction::CreateUnary(ashape, HloOpcode::kLog, param)); 329 auto reshape = 330 builder.AddInstruction(HloInstruction::CreateReshape(bshape, log)); 331 auto tanh = builder.AddInstruction( 332 HloInstruction::CreateUnary(bshape, HloOpcode::kTanh, reshape)); 333 334 auto module = CreateNewModule(); 335 HloComputation* computation = 336 module->AddEntryComputation(builder.Build(tanh)); 337 338 Shape ashape_with_layout(ashape); 339 Shape bshape_with_layout(bshape); 340 *ashape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 2, 1, 3}); 341 *bshape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({2, 1, 0}); 342 343 ComputationLayout computation_layout(computation->ComputeProgramShape()); 344 *computation_layout.mutable_parameter_layout(0) = 345 ShapeLayout(ashape_with_layout); 346 *computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout); 347 AssignLayouts(module.get(), &computation_layout); 348 349 auto log_minor_to_major = 350 AsInt64Slice(log->shape().layout().minor_to_major()); 351 EXPECT_GT(PositionInContainer(log_minor_to_major, 1), 352 PositionInContainer(log_minor_to_major, 2)); 353 354 auto reshape_minor_to_major = 355 AsInt64Slice(reshape->shape().layout().minor_to_major()); 356 EXPECT_GT(PositionInContainer(reshape_minor_to_major, 0), 357 PositionInContainer(reshape_minor_to_major, 2)); 358 } 359 360 // Test whether LayoutAssignment assigns layouts to elementwise operations to 361 // keep linear indices valid across them, and to transpositions to make them 362 // bitcasts. 363 TEST_F(LayoutAssignmentTest, ElementwiseAndTranspose) { 364 // param -> log -> transpose -> tanh 365 auto builder = HloComputation::Builder(TestName()); 366 Shape ashape = ShapeUtil::MakeShape(F32, {42, 12}); 367 Shape bshape = ShapeUtil::MakeShape(F32, {12, 42}); 368 auto param = builder.AddInstruction( 369 HloInstruction::CreateParameter(0, ashape, "param")); 370 auto log = builder.AddInstruction( 371 HloInstruction::CreateUnary(ashape, HloOpcode::kLog, param)); 372 auto transpose = builder.AddInstruction( 373 HloInstruction::CreateTranspose(bshape, log, {1, 0})); 374 auto tanh = builder.AddInstruction( 375 HloInstruction::CreateUnary(bshape, HloOpcode::kTanh, transpose)); 376 auto module = CreateNewModule(); 377 auto computation = module->AddEntryComputation(builder.Build(tanh)); 378 379 Shape ashape_with_layout(ashape); 380 Shape bshape_with_layout(bshape); 381 *ashape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({1, 0}); 382 *bshape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 1}); 383 384 ComputationLayout computation_layout(computation->ComputeProgramShape()); 385 *computation_layout.mutable_parameter_layout(0) = 386 ShapeLayout(ashape_with_layout); 387 *computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout); 388 AssignLayouts(module.get(), &computation_layout); 389 390 EXPECT_TRUE( 391 LayoutUtil::Equal(ashape_with_layout.layout(), log->shape().layout())); 392 EXPECT_TRUE(LayoutUtil::Equal(bshape_with_layout.layout(), 393 transpose->shape().layout())); 394 EXPECT_TRUE( 395 LayoutUtil::Equal(bshape_with_layout.layout(), tanh->shape().layout())); 396 } 397 398 // Test whether LayoutAssignment assigns layouts to transpositions to make them 399 // bitcasts. 400 TEST_F(LayoutAssignmentTest, BroadcastAndTranspose) { 401 // param -> broadcast -> transpose 402 auto builder = HloComputation::Builder(TestName()); 403 Shape ashape = ShapeUtil::MakeShape(F32, {3, 4}); 404 Shape bshape = ShapeUtil::MakeShape(F32, {2, 3, 4}); 405 Shape cshape = ShapeUtil::MakeShape(F32, {4, 3, 2}); 406 auto param = builder.AddInstruction( 407 HloInstruction::CreateParameter(0, ashape, "param")); 408 auto broadcast = builder.AddInstruction( 409 HloInstruction::CreateBroadcast(bshape, param, {1, 2})); 410 auto transpose = builder.AddInstruction( 411 HloInstruction::CreateTranspose(cshape, broadcast, {2, 1, 0})); 412 auto module = CreateNewModule(); 413 HloComputation* computation = 414 module->AddEntryComputation(builder.Build(transpose)); 415 416 Shape input_shape_with_layout(ashape); 417 Shape output_shape_with_layout(cshape); 418 *input_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({1, 0}); 419 *output_shape_with_layout.mutable_layout() = 420 LayoutUtil::MakeLayout({2, 1, 0}); 421 422 ComputationLayout computation_layout(computation->ComputeProgramShape()); 423 *computation_layout.mutable_parameter_layout(0) = 424 ShapeLayout(input_shape_with_layout); 425 *computation_layout.mutable_result_layout() = 426 ShapeLayout(output_shape_with_layout); 427 AssignLayouts(module.get(), &computation_layout); 428 429 EXPECT_THAT(broadcast->shape().layout().minor_to_major(), 430 ElementsAre(0, 1, 2)); 431 } 432 433 TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) { 434 // param[4] -> broadcast[3x4] ------> transpose[4x3]-------- -------> tuple 435 // \ / 436 // \-> tanh[3x4] -> broadcast2[2x3x4] -/ 437 // 438 // The layout of `transpose` is set to {1,0} because it provides a buffer to 439 // the computation result which has a fixed layout.. Therefore, `broadcast` 440 // (the operand of transpose) is expected to have layout {0,1} so that the 441 // transpose is a bitcast. Furthermore, `tanh` is expected to have the same 442 // layout as `broadcast` (i.e. {0,1}) because `tanh` is elementwise. 443 Shape f32_4 = ShapeUtil::MakeShape(F32, {4}); 444 Shape f32_34 = ShapeUtil::MakeShape(F32, {3, 4}); 445 Shape f32_43 = ShapeUtil::MakeShape(F32, {4, 3}); 446 Shape f32_234 = ShapeUtil::MakeShape(F32, {2, 3, 4}); 447 448 auto builder = HloComputation::Builder(TestName()); 449 auto param = builder.AddInstruction( 450 HloInstruction::CreateParameter(0, f32_4, "param")); 451 auto broadcast = builder.AddInstruction( 452 HloInstruction::CreateBroadcast(f32_34, param, {3})); 453 auto transpose = builder.AddInstruction( 454 HloInstruction::CreateTranspose(f32_43, broadcast, {1, 0})); 455 auto tanh = builder.AddInstruction( 456 HloInstruction::CreateUnary(f32_34, HloOpcode::kTanh, broadcast)); 457 auto broadcast2 = builder.AddInstruction( 458 HloInstruction::CreateBroadcast(f32_234, tanh, {2})); 459 auto tuple = builder.AddInstruction( 460 HloInstruction::CreateTuple({transpose, broadcast2})); 461 auto module = CreateNewModule(); 462 HloComputation* computation = 463 module->AddEntryComputation(builder.Build(tuple)); 464 465 ComputationLayout computation_layout(computation->ComputeProgramShape()); 466 Shape param_shape_with_layout(f32_4); 467 Shape transpose_shape_with_layout(f32_43); 468 Shape broadcast2_shape_with_layout(f32_234); 469 *param_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0}); 470 *transpose_shape_with_layout.mutable_layout() = 471 LayoutUtil::MakeLayout({1, 0}); 472 *broadcast2_shape_with_layout.mutable_layout() = 473 LayoutUtil::MakeLayout({2, 1, 0}); 474 475 *computation_layout.mutable_parameter_layout(0) = 476 ShapeLayout(param_shape_with_layout); 477 *computation_layout.mutable_result_layout() = 478 ShapeLayout(ShapeUtil::MakeTupleShape( 479 {transpose_shape_with_layout, broadcast2_shape_with_layout})); 480 AssignLayouts(module.get(), &computation_layout); 481 482 EXPECT_THAT(broadcast->shape().layout().minor_to_major(), ElementsAre(0, 1)); 483 EXPECT_THAT(transpose->shape().layout().minor_to_major(), ElementsAre(1, 0)); 484 EXPECT_THAT(tanh->shape().layout().minor_to_major(), ElementsAre(0, 1)); 485 } 486 487 class OperandsMustBeTheSameLayoutAssignment : public LayoutAssignment { 488 public: 489 explicit OperandsMustBeTheSameLayoutAssignment( 490 ComputationLayout* entry_computation_layout) 491 : LayoutAssignment(entry_computation_layout) {} 492 493 protected: 494 Status PropagateBufferConstraint( 495 const BufferLayoutConstraint& buffer_constraint, 496 LayoutConstraints* constraints) override { 497 const LogicalBuffer& buffer = buffer_constraint.buffer(); 498 const HloInstruction* instruction = buffer.instruction(); 499 500 // Force the operands' layout to the output layout. 501 for (int64 operand_no = 0; operand_no < instruction->operand_count(); 502 ++operand_no) { 503 const HloInstruction* operand = instruction->operand(operand_no); 504 if (ShapeUtil::Rank(instruction->shape()) != 505 ShapeUtil::Rank(operand->shape())) { 506 continue; 507 } 508 TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout( 509 buffer_constraint.layout(), instruction, operand_no, 510 /*mandatory=*/true)); 511 } 512 return PropagateBufferConstraintToUses(buffer_constraint, constraints); 513 } 514 }; 515 516 TEST_F(LayoutAssignmentTest, MakeOperandsTheSame) { 517 // param0 -> concatenate -> reshape 518 // param1 -^ 519 auto builder = HloComputation::Builder(TestName()); 520 Shape ashape = ShapeUtil::MakeShape(F32, {50, 1}); 521 Shape bshape = ShapeUtil::MakeShape(F32, {50, 2}); 522 Shape cshape = ShapeUtil::MakeShape(F32, {100}); 523 auto param0 = builder.AddInstruction( 524 HloInstruction::CreateParameter(0, ashape, "param")); 525 auto param1 = builder.AddInstruction( 526 HloInstruction::CreateParameter(1, ashape, "param")); 527 auto concatenate = builder.AddInstruction( 528 HloInstruction::CreateConcatenate(bshape, {param0, param1}, 1)); 529 auto reshape = builder.AddInstruction( 530 HloInstruction::CreateReshape(cshape, concatenate)); 531 auto module = CreateNewModule(); 532 HloComputation* computation = 533 module->AddEntryComputation(builder.Build(reshape)); 534 535 Shape param0_shape_with_layout(ashape); 536 Shape param1_shape_with_layout(ashape); 537 *param0_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 1}); 538 *param1_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({1, 0}); 539 540 ComputationLayout computation_layout(computation->ComputeProgramShape()); 541 *computation_layout.mutable_parameter_layout(0) = 542 ShapeLayout(param0_shape_with_layout); 543 *computation_layout.mutable_parameter_layout(1) = 544 ShapeLayout(param1_shape_with_layout); 545 OperandsMustBeTheSameLayoutAssignment layout_assignment(&computation_layout); 546 EXPECT_IS_OK(layout_assignment.Run(module.get()).status()); 547 548 EXPECT_EQ(HloOpcode::kCopy, concatenate->operand(0)->opcode()); 549 EXPECT_THAT(concatenate->operand(0)->shape().layout().minor_to_major(), 550 ElementsAre(1, 0)); 551 EXPECT_THAT(concatenate->operand(1)->shape().layout().minor_to_major(), 552 ElementsAre(1, 0)); 553 EXPECT_THAT(concatenate->shape().layout().minor_to_major(), 554 ElementsAre(1, 0)); 555 } 556 557 // Test layout assignment of a transpose into a bitcast based on its operand. 558 TEST_F(LayoutAssignmentTest, TransposeToBitcastFromOperand) { 559 auto builder = HloComputation::Builder(TestName()); 560 Shape input_shape_with_layout = 561 ShapeUtil::MakeShapeWithLayout(F32, {3, 5, 6, 7}, {2, 0, 3, 1}); 562 auto param = builder.AddInstruction( 563 HloInstruction::CreateParameter(0, input_shape_with_layout, "param")); 564 auto transpose = builder.AddInstruction(HloInstruction::CreateTranspose( 565 ShapeUtil::MakeShape(F32, {6, 7, 3, 5}), param, {2, 3, 0, 1})); 566 auto module = CreateNewModule(); 567 HloComputation* computation = 568 module->AddEntryComputation(builder.Build(transpose)); 569 ComputationLayout computation_layout(computation->ComputeProgramShape()); 570 AssignLayouts(module.get(), &computation_layout); 571 EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(), 572 transpose->shape(), {2, 3, 0, 1})); 573 } 574 // Test layout assignment of a transpose into a bitcast based on its user. 575 TEST_F(LayoutAssignmentTest, TransposeToBitcastToUser) { 576 auto builder = HloComputation::Builder(TestName()); 577 Shape input_shape = ShapeUtil::MakeShape(F32, {3, 5, 6, 7}); 578 auto constant = builder.AddInstruction( 579 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f))); 580 auto broadcast = builder.AddInstruction( 581 HloInstruction::CreateBroadcast(input_shape, constant, {})); 582 auto transpose = builder.AddInstruction(HloInstruction::CreateTranspose( 583 ShapeUtil::MakeShape(F32, {6, 7, 3, 5}), broadcast, {2, 3, 0, 1})); 584 auto module = CreateNewModule(); 585 HloComputation* computation = 586 module->AddEntryComputation(builder.Build(transpose)); 587 ComputationLayout computation_layout(computation->ComputeProgramShape()); 588 AssignLayouts(module.get(), &computation_layout); 589 EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(), 590 transpose->shape(), {2, 3, 0, 1})); 591 } 592 593 // A GTE inside of a fusion node inherits the layout of its operand (which 594 // should, if we keep following operands, eventually be a parameter). 595 TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) { 596 const char* module_str = R"( 597 HloModule test_module 598 599 fused_computation { 600 fparam = (f32[2,2,2], (f32[2,2,2], f32[2,2,2])) parameter(0) 601 gte0 = f32[2,2,2] get-tuple-element(fparam), index=0 602 gte1 = (f32[2,2,2], f32[2,2,2]) get-tuple-element(fparam), index=1 603 gte1a = f32[2,2,2] get-tuple-element(gte1), index=0 604 gte1b = f32[2,2,2] get-tuple-element(gte1), index=1 605 add = f32[2,2,2] add(gte1a, gte1b) 606 ROOT fresult = f32[2,2,2] add(gte0, add) 607 } 608 609 ENTRY entry_computation { 610 param = (f32[2,2,2], (f32[2,2,2], f32[2,2,2])) parameter(0) 611 ROOT fusion = 612 f32[2,2,2] fusion(param), kind=kLoop, calls=fused_computation 613 } 614 )"; 615 616 auto module = tools::Parse(module_str).ValueOrDie(); 617 ComputationLayout computation_layout( 618 module->entry_computation()->ComputeProgramShape()); 619 Shape param_shape = ShapeUtil::MakeTupleShape( 620 {ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 1, 2}), 621 ShapeUtil::MakeTupleShape({ 622 ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {1, 2, 0}), 623 ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {2, 0, 1}), 624 })}); 625 TF_ASSERT_OK( 626 computation_layout.mutable_parameter_layout(0)->CopyLayoutFromShape( 627 param_shape)); 628 computation_layout.mutable_result_layout()->ResetLayout( 629 LayoutUtil::MakeLayout({2, 1, 0})); 630 AssignLayouts(module.get(), &computation_layout); 631 632 HloComputation* fused_computation = *std::find_if( 633 module->computations().begin(), module->computations().end(), 634 [](const HloComputation* c) { return c->name() == "fused_computation"; }); 635 636 auto fused_instr = [&](const string& name) { 637 auto it = std::find_if( 638 fused_computation->instructions().begin(), 639 fused_computation->instructions().end(), 640 [&](const HloInstruction* i) { return i->name() == name; }); 641 CHECK(it != fused_computation->instructions().end()); 642 return *it; 643 }; 644 645 EXPECT_THAT(fused_instr("gte0")->shape().layout().minor_to_major(), 646 ElementsAre(0, 1, 2)); 647 EXPECT_THAT( 648 fused_instr("gte1")->shape().tuple_shapes(0).layout().minor_to_major(), 649 ElementsAre(1, 2, 0)); 650 EXPECT_THAT( 651 fused_instr("gte1")->shape().tuple_shapes(1).layout().minor_to_major(), 652 ElementsAre(2, 0, 1)); 653 EXPECT_THAT(fused_instr("gte1a")->shape().layout().minor_to_major(), 654 ElementsAre(1, 2, 0)); 655 EXPECT_THAT(fused_instr("gte1b")->shape().layout().minor_to_major(), 656 ElementsAre(2, 0, 1)); 657 EXPECT_THAT(fused_instr("fresult")->shape().layout().minor_to_major(), 658 ElementsAre(2, 1, 0)); 659 } 660 661 TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) { 662 auto builder = HloComputation::Builder(TestName()); 663 auto module = CreateNewModule(); 664 Shape shape = ShapeUtil::MakeShape(F32, {128, 8}); 665 Shape tshape = ShapeUtil::MakeTupleShape({shape, shape}); 666 Shape result_tshape = ShapeUtil::MakeTupleShape({shape}); 667 668 auto param0 = builder.AddInstruction( 669 HloInstruction::CreateParameter(0, shape, "param0")); 670 auto param1 = builder.AddInstruction( 671 HloInstruction::CreateParameter(1, shape, "param1")); 672 auto pred = builder.AddInstruction(HloInstruction::CreateParameter( 673 2, ShapeUtil::MakeShape(PRED, {}), "param2")); 674 auto tuple = 675 builder.AddInstruction(HloInstruction::CreateTuple({param0, param1})); 676 677 auto true_builder = HloComputation::Builder(TestName() + "_TrueBranch"); 678 { 679 auto param = true_builder.AddInstruction( 680 HloInstruction::CreateParameter(0, tshape, "param")); 681 auto gte0 = true_builder.AddInstruction( 682 HloInstruction::CreateGetTupleElement(shape, param, 0)); 683 auto gte1 = true_builder.AddInstruction( 684 HloInstruction::CreateGetTupleElement(shape, param, 1)); 685 auto add = true_builder.AddInstruction( 686 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, gte0, gte1)); 687 true_builder.AddInstruction(HloInstruction::CreateTuple({add})); 688 } 689 HloComputation* true_computation = 690 module->AddEmbeddedComputation(true_builder.Build()); 691 692 auto false_builder = HloComputation::Builder(TestName() + "_FalseBranch"); 693 { 694 Shape xshape = ShapeUtil::MakeShapeWithLayout(F32, {128, 8}, {0, 1}); 695 false_builder.AddInstruction( 696 HloInstruction::CreateParameter(0, tshape, "param")); 697 // Using infeed as layout assignment does not mess up with it. 698 auto infeed = 699 false_builder.AddInstruction(HloInstruction::CreateInfeed(xshape, "")); 700 false_builder.AddInstruction(HloInstruction::CreateTuple({infeed})); 701 } 702 HloComputation* false_computation = 703 module->AddEmbeddedComputation(false_builder.Build()); 704 builder.AddInstruction(HloInstruction::CreateConditional( 705 result_tshape, pred, tuple, true_computation, tuple, false_computation)); 706 707 HloComputation* computation = module->AddEntryComputation(builder.Build()); 708 ComputationLayout computation_layout(computation->ComputeProgramShape()); 709 710 AssignLayouts(module.get(), &computation_layout); 711 712 const HloInstruction* true_root = true_computation->root_instruction(); 713 const HloInstruction* false_root = false_computation->root_instruction(); 714 EXPECT_THAT(true_root->opcode(), HloOpcode::kTuple); 715 EXPECT_THAT(false_root->opcode(), HloOpcode::kTuple); 716 717 const HloInstruction* true_result = true_root->operand(0); 718 const HloInstruction* false_result = false_root->operand(0); 719 EXPECT_TRUE(LayoutUtil::Equal(true_result->shape().layout(), 720 false_result->shape().layout())); 721 EXPECT_THAT(false_result->opcode(), HloOpcode::kCopy); 722 } 723 724 } // namespace 725 } // namespace xla 726