1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/bfloat16_propagation.h" 17 #include "tensorflow/compiler/xla/service/bfloat16_support.h" 18 #include "tensorflow/compiler/xla/service/hlo_computation.h" 19 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 20 #include "tensorflow/compiler/xla/service/hlo_module.h" 21 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 22 #include "tensorflow/compiler/xla/shape_util.h" 23 #include "tensorflow/compiler/xla/test.h" 24 #include "tensorflow/compiler/xla/test_helpers.h" 25 #include "tensorflow/compiler/xla/tests/hlo_test_base.h" 26 #include "tensorflow/compiler/xla/tests/literal_test_util.h" 27 #include "tensorflow/compiler/xla/xla_data.pb.h" 28 29 namespace xla { 30 31 // A class specifying the BF16 support used to test the propagation pass. It 32 // specifies that BF16 and mixed precision are supported in all HloInstructions, 33 // and that kDot reduces its operands precision to BF16. 34 class TestBFloat16Support : public BFloat16Support { 35 public: 36 TestBFloat16Support() {} 37 ~TestBFloat16Support() override {} 38 39 bool SupportsBF16Operand(const HloInstruction& hlo, 40 int64 operand_index) const override { 41 return true; 42 } 43 44 bool SupportsBF16Output(const HloInstruction& hlo) const override { 45 return true; 46 } 47 48 bool SupportsMixedPrecisions(const HloInstruction& hlo) const override { 49 return true; 50 } 51 52 bool EffectiveOperandPrecisionIsBF16(const HloInstruction& hlo, 53 int64 operand_index) const override { 54 return hlo.opcode() == HloOpcode::kDot; 55 } 56 }; 57 58 class BFloat16PropagationTest : public HloTestBase { 59 protected: 60 BFloat16PropagationTest() 61 : HloTestBase(/*verifier_layout_sensitive=*/false, 62 /*allow_mixed_precision_in_hlo_verifier=*/true) {} 63 64 // Runs the propagation pass on the given module, and returns whether the 65 // module is changed after this pass. 66 bool PropagatePrecision(HloModule* module) { 67 TestBFloat16Support bfloat16_support; 68 BFloat16Propagation propagation(&bfloat16_support); 69 StatusOr<bool> result = propagation.Run(module); 70 EXPECT_IS_OK(result.status()); 71 return result.ValueOrDie(); 72 } 73 74 // Returns whether the given HloInstruction's output element type is BF16 or 75 // the only use of it is converting to BF16. 76 bool OutputsBF16(const HloInstruction* inst) { 77 if (inst->shape().element_type() == BF16) { 78 return true; 79 } 80 return inst->user_count() == 1 && 81 inst->users()[0]->opcode() == HloOpcode::kConvert && 82 inst->users()[0]->shape().element_type() == BF16; 83 } 84 85 std::unique_ptr<HloInstruction> CreateDot(const Shape& shape, 86 HloInstruction* lhs, 87 HloInstruction* rhs) { 88 DotDimensionNumbers dot_dnums; 89 dot_dnums.add_lhs_contracting_dimensions(1); 90 dot_dnums.add_rhs_contracting_dimensions(0); 91 return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums, 92 DefaultPrecisionConfig(2)); 93 } 94 }; 95 96 // Tests that BF16 can propagate through select over non-tuple buffers, but not 97 // through add where reducing operand precision can affect the result. 98 TEST_F(BFloat16PropagationTest, PropagateThroughSelectButNotAdd) { 99 auto builder = HloComputation::Builder(TestName()); 100 Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); 101 102 HloInstruction* a = 103 builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); 104 HloInstruction* b = 105 builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); 106 HloInstruction* c = 107 builder.AddInstruction(HloInstruction::CreateParameter(2, shape, "c")); 108 HloInstruction* add0 = builder.AddInstruction( 109 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b)); 110 HloInstruction* add1 = builder.AddInstruction( 111 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add0, b)); 112 HloInstruction* pred = builder.AddInstruction(HloInstruction::CreateCompare( 113 ShapeUtil::MakeShape(PRED, {2, 4}), a, b, ComparisonDirection::kEq)); 114 HloInstruction* sel = builder.AddInstruction( 115 HloInstruction::CreateTernary(shape, HloOpcode::kSelect, pred, c, add1)); 116 HloInstruction* xpose = 117 builder.AddInstruction(HloInstruction::CreateTranspose( 118 ShapeUtil::MakeShape(F32, {4, 2}), sel, {1, 0})); 119 HloInstruction* dot = builder.AddInstruction( 120 CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), xpose, a)); 121 HloInstruction* root = builder.AddInstruction(HloInstruction::CreateBinary( 122 ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kAdd, dot, dot)); 123 124 auto module = CreateNewVerifiedModule(); 125 auto computation = module->AddEntryComputation(builder.Build()); 126 127 EXPECT_TRUE(PropagatePrecision(module.get())); 128 129 EXPECT_EQ(computation->root_instruction(), root); 130 EXPECT_TRUE(OutputsBF16(xpose)); 131 EXPECT_TRUE(OutputsBF16(sel)); 132 EXPECT_TRUE(OutputsBF16(add1)); 133 EXPECT_FALSE(OutputsBF16(add0)); 134 EXPECT_FALSE(OutputsBF16(a)); 135 EXPECT_FALSE(OutputsBF16(b)); 136 EXPECT_FALSE(OutputsBF16(c)); 137 } 138 139 TEST_F(BFloat16PropagationTest, PropagateThroughMaxPoolReduceWindow) { 140 auto module = CreateNewVerifiedModule(); 141 142 auto sub_builder = HloComputation::Builder("max"); 143 HloInstruction* p0 = sub_builder.AddInstruction( 144 HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "a")); 145 HloInstruction* p1 = sub_builder.AddInstruction( 146 HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {}), "b")); 147 sub_builder.AddInstruction(HloInstruction::CreateBinary( 148 ShapeUtil::MakeShape(F32, {}), HloOpcode::kMaximum, p0, p1)); 149 auto max_computation = module->AddEmbeddedComputation(sub_builder.Build()); 150 151 auto builder = HloComputation::Builder(TestName()); 152 Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); 153 154 HloInstruction* a = 155 builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); 156 HloInstruction* b = 157 builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); 158 HloInstruction* c = 159 builder.AddInstruction(HloInstruction::CreateParameter(2, shape, "c")); 160 HloInstruction* add = builder.AddInstruction( 161 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b)); 162 Window window; 163 WindowDimension dim; 164 dim.set_size(2); 165 dim.set_stride(1); 166 dim.set_padding_high(1); 167 dim.set_window_dilation(1); 168 dim.set_base_dilation(1); 169 *window.add_dimensions() = dim; 170 *window.add_dimensions() = dim; 171 HloInstruction* rw = 172 builder.AddInstruction(HloInstruction::CreateReduceWindow( 173 shape, add, 174 builder.AddInstruction( 175 HloInstruction::CreateConstant(LiteralUtil::Zero(F32))), 176 window, max_computation)); 177 HloInstruction* xpose = 178 builder.AddInstruction(HloInstruction::CreateTranspose( 179 ShapeUtil::MakeShape(F32, {4, 2}), c, {1, 0})); 180 HloInstruction* dot = builder.AddInstruction( 181 CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), xpose, rw)); 182 HloInstruction* root = builder.AddInstruction(HloInstruction::CreateBinary( 183 ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kAdd, dot, dot)); 184 185 auto computation = module->AddEntryComputation(builder.Build()); 186 187 EXPECT_TRUE(PropagatePrecision(module.get())); 188 189 EXPECT_EQ(computation->root_instruction(), root); 190 EXPECT_TRUE(OutputsBF16(add)); 191 EXPECT_TRUE(OutputsBF16(xpose)); 192 EXPECT_TRUE(OutputsBF16(rw)); 193 } 194 195 // Tests that side-effecting all-reduce should not be changed. 196 TEST_F(BFloat16PropagationTest, DoNotChangeAllReduce) { 197 auto module = CreateNewVerifiedModule(); 198 199 auto builder = HloComputation::Builder(TestName()); 200 Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); 201 HloInstruction* a = 202 builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); 203 HloInstruction* b = 204 builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); 205 auto rb = HloComputation::Builder(TestName()); 206 rb.AddInstruction(HloInstruction::CreateBinary( 207 shape, HloOpcode::kAdd, 208 rb.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0")), 209 rb.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1")))); 210 auto reduction = module->AddEmbeddedComputation(rb.Build()); 211 HloInstruction* all_reduce = 212 builder.AddInstruction(HloInstruction::CreateAllReduce( 213 ShapeUtil::MakeTupleShape({shape, shape}), {a, b}, reduction, 214 /*replica_groups=*/{}, /*barrier=*/"", /*all_reduce_id=*/1)); 215 HloInstruction* gte0 = builder.AddInstruction( 216 HloInstruction::CreateGetTupleElement(shape, all_reduce, 0)); 217 HloInstruction* gte1 = builder.AddInstruction( 218 HloInstruction::CreateGetTupleElement(shape, all_reduce, 1)); 219 HloInstruction* dot = builder.AddInstruction(CreateDot(shape, gte0, gte1)); 220 HloInstruction* root = builder.AddInstruction( 221 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot)); 222 223 auto computation = module->AddEntryComputation(builder.Build()); 224 225 EXPECT_FALSE(PropagatePrecision(module.get())); 226 EXPECT_EQ(computation->root_instruction(), root); 227 } 228 229 // Tests that if a constant is converted to BF16 then its literal must also be 230 // converted. 231 TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) { 232 auto builder = HloComputation::Builder(TestName()); 233 Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); 234 Array2D<float> array_a(4, 4); 235 array_a.FillUnique(1.0f); 236 Array2D<float> array_b(4, 4); 237 array_b.FillUnique(10.0f); 238 239 HloInstruction* a = builder.AddInstruction( 240 HloInstruction::CreateConstant(LiteralUtil::CreateFromArray(array_a))); 241 HloInstruction* b = builder.AddInstruction( 242 HloInstruction::CreateConstant(LiteralUtil::CreateFromArray(array_b))); 243 HloInstruction* dot = builder.AddInstruction(CreateDot(shape, a, b)); 244 245 auto module = CreateNewVerifiedModule(); 246 auto computation = module->AddEntryComputation(builder.Build()); 247 248 EXPECT_TRUE(PropagatePrecision(module.get())); 249 250 EXPECT_EQ(computation->root_instruction(), dot); 251 EXPECT_TRUE(OutputsBF16(dot->operand(0))); 252 EXPECT_TRUE(OutputsBF16(dot->operand(1))); 253 EXPECT_EQ(dot->operand(0)->opcode(), HloOpcode::kConstant); 254 EXPECT_EQ(dot->operand(1)->opcode(), HloOpcode::kConstant); 255 EXPECT_TRUE(LiteralTestUtil::Equal( 256 LiteralUtil::ConvertF32ToBF16(LiteralUtil::CreateFromArray(array_a)), 257 dot->operand(0)->literal())); 258 EXPECT_TRUE(LiteralTestUtil::Equal( 259 LiteralUtil::ConvertF32ToBF16(LiteralUtil::CreateFromArray(array_b)), 260 dot->operand(1)->literal())); 261 } 262 263 // Tests that BF16 can be propagated through nested tuples. 264 TEST_F(BFloat16PropagationTest, PropagateThroughTuples) { 265 auto builder = HloComputation::Builder(TestName()); 266 Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); 267 268 HloInstruction* a = 269 builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); 270 HloInstruction* b = 271 builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); 272 HloInstruction* add0 = builder.AddInstruction( 273 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b)); 274 HloInstruction* add1 = builder.AddInstruction( 275 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, a)); 276 HloInstruction* add2 = builder.AddInstruction( 277 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, b, b)); 278 HloInstruction* xpose = 279 builder.AddInstruction(HloInstruction::CreateTranspose( 280 ShapeUtil::MakeShape(F32, {4, 2}), add1, {1, 0})); 281 282 HloInstruction* tuple0 = 283 builder.AddInstruction(HloInstruction::CreateTuple({add0, add1, add2})); 284 HloInstruction* tuple1 = 285 builder.AddInstruction(HloInstruction::CreateTuple({tuple0, xpose})); 286 287 HloInstruction* lhs = builder.AddInstruction( 288 HloInstruction::CreateGetTupleElement(xpose->shape(), tuple1, 1)); 289 HloInstruction* rhs = 290 builder.AddInstruction(HloInstruction::CreateGetTupleElement( 291 add0->shape(), 292 builder.AddInstruction(HloInstruction::CreateGetTupleElement( 293 tuple0->shape(), tuple1, 0)), 294 0)); 295 HloInstruction* dot = builder.AddInstruction( 296 CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), lhs, rhs)); 297 298 HloInstruction* output_tuple = 299 builder.AddInstruction(HloInstruction::CreateTuple({dot, add2})); 300 301 auto module = CreateNewVerifiedModule(); 302 auto computation = module->AddEntryComputation(builder.Build()); 303 304 EXPECT_TRUE(PropagatePrecision(module.get())); 305 306 EXPECT_EQ(computation->root_instruction(), output_tuple); 307 EXPECT_TRUE(OutputsBF16(xpose)); 308 EXPECT_TRUE(OutputsBF16(add0)); 309 EXPECT_TRUE(OutputsBF16(add1)); 310 EXPECT_FALSE(OutputsBF16(add2)); 311 } 312 313 // Tests that even if an instruction does not define a buffer in its output, its 314 // shape must match the defining instruction. 315 TEST_F(BFloat16PropagationTest, SameValueReferencedTwice) { 316 auto builder = HloComputation::Builder(TestName()); 317 Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); 318 319 HloInstruction* a = 320 builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); 321 HloInstruction* b = 322 builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); 323 HloInstruction* add0 = builder.AddInstruction( 324 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b)); 325 HloInstruction* add1 = builder.AddInstruction( 326 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, a)); 327 328 HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateTranspose( 329 ShapeUtil::MakeShape(F32, {4, 2}), add1, {1, 0})); 330 331 HloInstruction* tuple = 332 builder.AddInstruction(HloInstruction::CreateTuple({add0, add1})); 333 HloInstruction* rhs = builder.AddInstruction( 334 HloInstruction::CreateGetTupleElement(add1->shape(), tuple, 1)); 335 336 // lhs is the transpose of add1, and rhs is a get-tuple-element aliasing add1. 337 HloInstruction* dot = builder.AddInstruction( 338 CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), lhs, rhs)); 339 340 auto module = CreateNewVerifiedModule(); 341 auto computation = module->AddEntryComputation(builder.Build()); 342 343 EXPECT_TRUE(PropagatePrecision(module.get())); 344 345 EXPECT_EQ(computation->root_instruction(), dot); 346 EXPECT_TRUE(OutputsBF16(add1)); 347 EXPECT_TRUE(OutputsBF16(lhs)); 348 349 // add0 and rhs have been eliminated by simplification and DCE. 350 } 351 352 // Tests that a non-fusion computation's root should not be changed. 353 TEST_F(BFloat16PropagationTest, DoNotChangeComputationRoot) { 354 auto builder = HloComputation::Builder(TestName()); 355 Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); 356 357 HloInstruction* a = 358 builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); 359 HloInstruction* b = 360 builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); 361 HloInstruction* add = builder.AddInstruction( 362 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b)); 363 364 HloInstruction* dot = builder.AddInstruction(CreateDot(shape, add, add)); 365 366 HloInstruction* tuple = 367 builder.AddInstruction(HloInstruction::CreateTuple({add, dot})); 368 369 auto module = CreateNewVerifiedModule(); 370 auto computation = module->AddEntryComputation(builder.Build()); 371 372 EXPECT_FALSE(PropagatePrecision(module.get())); 373 374 EXPECT_EQ(computation->root_instruction(), tuple); 375 EXPECT_FALSE(OutputsBF16(add)); 376 } 377 378 // Tests that BF16 is propagated properly through fused computations. 379 TEST_F(BFloat16PropagationTest, PropagateThroughFusion) { 380 auto module = CreateNewVerifiedModule(); 381 auto builder = HloComputation::Builder(TestName()); 382 Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); 383 384 HloInstruction* param = builder.AddInstruction( 385 HloInstruction::CreateParameter(0, shape, "param")); 386 HloInstruction* add = builder.AddInstruction( 387 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, param)); 388 389 auto builder_f0 = HloComputation::Builder("fusion0"); 390 HloInstruction* a_f0 = 391 builder_f0.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); 392 HloInstruction* b_f0 = 393 builder_f0.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); 394 HloInstruction* tuple_f0 = 395 builder_f0.AddInstruction(HloInstruction::CreateTuple({a_f0, b_f0})); 396 auto comp_f0 = module->AddEmbeddedComputation(builder_f0.Build()); 397 auto fusion0 = builder.AddInstruction(HloInstruction::CreateFusion( 398 tuple_f0->shape(), HloInstruction::FusionKind::kCustom, {add, add}, 399 comp_f0)); 400 401 auto builder_f1 = HloComputation::Builder("fusion1"); 402 HloInstruction* p_f1 = builder_f1.AddInstruction( 403 HloInstruction::CreateParameter(0, tuple_f0->shape(), "param")); 404 HloInstruction* a_f1 = builder_f1.AddInstruction( 405 HloInstruction::CreateGetTupleElement(shape, p_f1, 0)); 406 HloInstruction* b_f1 = builder_f1.AddInstruction( 407 HloInstruction::CreateGetTupleElement(shape, p_f1, 1)); 408 HloInstruction* dot = builder_f1.AddInstruction(CreateDot(shape, a_f1, b_f1)); 409 auto comp_f1 = module->AddEmbeddedComputation(builder_f1.Build()); 410 auto fusion1 = builder.AddInstruction(HloInstruction::CreateFusion( 411 dot->shape(), HloInstruction::FusionKind::kCustom, {fusion0}, comp_f1)); 412 413 auto computation = module->AddEntryComputation(builder.Build()); 414 415 EXPECT_TRUE(PropagatePrecision(module.get())); 416 417 EXPECT_EQ(computation->root_instruction(), fusion1); 418 EXPECT_TRUE(OutputsBF16(add)); 419 EXPECT_TRUE(OutputsBF16(a_f0)); 420 EXPECT_TRUE(OutputsBF16(b_f0)); 421 EXPECT_TRUE(OutputsBF16(a_f1)); 422 EXPECT_TRUE(OutputsBF16(b_f1)); 423 } 424 425 // Tests that changes to BF16 that cannot be propagated outside a fusion are 426 // discarded. 427 TEST_F(BFloat16PropagationTest, DiscardFusionInternalBF16Changes) { 428 auto module = CreateNewVerifiedModule(); 429 auto builder = HloComputation::Builder(TestName()); 430 Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); 431 432 HloInstruction* param = builder.AddInstruction( 433 HloInstruction::CreateParameter(0, shape, "param")); 434 HloInstruction* add = builder.AddInstruction( 435 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, param)); 436 437 auto builder_f = HloComputation::Builder("fusion"); 438 HloInstruction* a_f = 439 builder_f.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); 440 HloInstruction* b_f = 441 builder_f.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); 442 HloInstruction* add_f = builder_f.AddInstruction( 443 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a_f, b_f)); 444 HloInstruction* dot_f = 445 builder_f.AddInstruction(CreateDot(shape, add_f, add_f)); 446 auto comp_f = module->AddEmbeddedComputation(builder_f.Build()); 447 auto fusion = builder.AddInstruction(HloInstruction::CreateFusion( 448 dot_f->shape(), HloInstruction::FusionKind::kCustom, {add, add}, comp_f)); 449 450 auto computation = module->AddEntryComputation(builder.Build()); 451 452 EXPECT_FALSE(PropagatePrecision(module.get())); 453 EXPECT_EQ(computation->root_instruction(), fusion); 454 } 455 456 // Tests that if 1) the root instruction of a fusion is a tuple, 2) the fusion 457 // outputs are only used by a dot, and 3) one element of the tuple is used by 458 // an add in the fusion computation, then the propagation pass should create a 459 // convert in the fusion computation to keep the add's operand in F32 but change 460 // the fusion output to BF16. E.g., the following fusion computation 461 // (F32, F32) fusion_computation(F32 a, F32 b) 462 // = tuple(F32 a, F32 add(F32 a, F32 b)) 463 // will be changed to 464 // (BF16, BF16) fusion_computation(F32 a, F32 b) 465 // = tuple(BF16 convert(a), BF16 add(F32 a, F32 b)) 466 TEST_F(BFloat16PropagationTest, ConvertTupleFusionElementIfUsedByAdd) { 467 auto module = CreateNewVerifiedModule(); 468 auto builder = HloComputation::Builder(TestName()); 469 Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); 470 471 HloInstruction* param = builder.AddInstruction( 472 HloInstruction::CreateParameter(0, shape, "param")); 473 HloInstruction* add = builder.AddInstruction( 474 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, param)); 475 476 auto builder_f = HloComputation::Builder("fusion0"); 477 HloInstruction* a_f = 478 builder_f.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); 479 HloInstruction* b_f = 480 builder_f.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); 481 HloInstruction* add_f = builder_f.AddInstruction( 482 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a_f, b_f)); 483 HloInstruction* tuple_f = 484 builder_f.AddInstruction(HloInstruction::CreateTuple({a_f, add_f})); 485 auto comp_f = module->AddEmbeddedComputation(builder_f.Build()); 486 auto fusion = builder.AddInstruction(HloInstruction::CreateFusion( 487 tuple_f->shape(), HloInstruction::FusionKind::kCustom, {add, add}, 488 comp_f)); 489 490 HloInstruction* gte0 = builder.AddInstruction( 491 HloInstruction::CreateGetTupleElement(shape, fusion, 0)); 492 HloInstruction* gte1 = builder.AddInstruction( 493 HloInstruction::CreateGetTupleElement(shape, fusion, 1)); 494 HloInstruction* dot = builder.AddInstruction(CreateDot(shape, gte0, gte1)); 495 496 auto computation = module->AddEntryComputation(builder.Build()); 497 498 EXPECT_TRUE(PropagatePrecision(module.get())); 499 500 EXPECT_EQ(computation->root_instruction(), dot); 501 EXPECT_TRUE(OutputsBF16(gte0)); 502 EXPECT_TRUE(OutputsBF16(gte1)); 503 EXPECT_FALSE(OutputsBF16(a_f)); 504 EXPECT_FALSE(OutputsBF16(b_f)); 505 EXPECT_TRUE(OutputsBF16(add_f)); 506 auto new_fusion_root = comp_f->root_instruction(); 507 EXPECT_EQ(new_fusion_root->opcode(), HloOpcode::kTuple); 508 EXPECT_EQ(new_fusion_root->operand(1), add_f); 509 EXPECT_EQ(new_fusion_root->operand(0)->opcode(), HloOpcode::kConvert); 510 EXPECT_TRUE(OutputsBF16(new_fusion_root->operand(0))); 511 } 512 513 // A select over tuples does not define the leaf buffers, so the types in 514 // on_true and on_false must match, so that as long as one of them is F32, the 515 // other must be F32 as well. 516 TEST_F(BFloat16PropagationTest, SelectOverTuples) { 517 auto module = CreateNewVerifiedModule(); 518 auto builder = HloComputation::Builder(TestName()); 519 Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); 520 521 HloInstruction* param = builder.AddInstruction( 522 HloInstruction::CreateParameter(0, shape, "param")); 523 HloInstruction* pred = builder.AddInstruction(HloInstruction::CreateParameter( 524 1, ShapeUtil::MakeShape(PRED, {}), "pred")); 525 526 HloInstruction* add0 = builder.AddInstruction( 527 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, param)); 528 HloInstruction* add1 = builder.AddInstruction( 529 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add0, param)); 530 HloInstruction* tuple0 = 531 builder.AddInstruction(HloInstruction::CreateTuple({param, add0})); 532 HloInstruction* tuple1 = 533 builder.AddInstruction(HloInstruction::CreateTuple({param, add1})); 534 HloInstruction* sel = builder.AddInstruction(HloInstruction::CreateTernary( 535 tuple0->shape(), HloOpcode::kTupleSelect, pred, tuple0, tuple1)); 536 HloInstruction* gte0 = builder.AddInstruction( 537 HloInstruction::CreateGetTupleElement(shape, sel, 0)); 538 HloInstruction* gte1 = builder.AddInstruction( 539 HloInstruction::CreateGetTupleElement(shape, sel, 1)); 540 HloInstruction* xpose = 541 builder.AddInstruction(HloInstruction::CreateTranspose( 542 ShapeUtil::MakeShape(F32, {4, 2}), gte0, {1, 0})); 543 HloInstruction* dot = builder.AddInstruction( 544 CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), xpose, gte1)); 545 546 auto computation = module->AddEntryComputation(builder.Build()); 547 548 EXPECT_TRUE(PropagatePrecision(module.get())); 549 550 EXPECT_EQ(computation->root_instruction(), dot); 551 EXPECT_FALSE(OutputsBF16(add0)); 552 EXPECT_FALSE(OutputsBF16(add1)); 553 EXPECT_FALSE(OutputsBF16(gte0)); 554 EXPECT_FALSE(OutputsBF16(gte1)); 555 EXPECT_TRUE(OutputsBF16(xpose)); 556 } 557 558 // Tests that BF16 is propagated properly through a while computation with 559 // non-tuple input/output. 560 TEST_F(BFloat16PropagationTest, PropagateThroughSimpleWhile) { 561 auto module = CreateNewVerifiedModule(); 562 auto builder = HloComputation::Builder(TestName()); 563 Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); 564 565 HloInstruction* param0 = builder.AddInstruction( 566 HloInstruction::CreateParameter(0, shape, "param0")); 567 HloInstruction* param1 = builder.AddInstruction( 568 HloInstruction::CreateParameter(1, shape, "param1")); 569 HloInstruction* add = builder.AddInstruction( 570 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); 571 572 auto builder_cond = HloComputation::Builder("cond"); 573 auto cond_param = builder_cond.AddInstruction( 574 HloInstruction::CreateParameter(0, shape, "cond_param")); 575 auto cond_dot = 576 builder_cond.AddInstruction(CreateDot(shape, cond_param, cond_param)); 577 auto cond_root = builder_cond.AddInstruction(HloInstruction::CreateCompare( 578 ShapeUtil::MakeShape(PRED, {}), 579 builder_cond.AddInstruction(HloInstruction::CreateReshape( 580 ShapeUtil::MakeShape(F32, {}), 581 builder_cond.AddInstruction( 582 HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}), 583 cond_dot, {0, 0}, {1, 1}, {1, 1})))), 584 builder_cond.AddInstruction(HloInstruction::CreateReshape( 585 ShapeUtil::MakeShape(F32, {}), 586 builder_cond.AddInstruction( 587 HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}), 588 cond_dot, {1, 1}, {2, 2}, {1, 1})))), 589 ComparisonDirection::kGt)); 590 auto cond = module->AddEmbeddedComputation(builder_cond.Build()); 591 592 auto builder_body = HloComputation::Builder("body"); 593 auto body_param = builder_body.AddInstruction( 594 HloInstruction::CreateParameter(0, shape, "body_param")); 595 auto body_dot = 596 builder_body.AddInstruction(CreateDot(shape, body_param, body_param)); 597 auto body = module->AddEmbeddedComputation(builder_body.Build()); 598 599 auto while_hlo = builder.AddInstruction( 600 HloInstruction::CreateWhile(shape, cond, body, add)); 601 602 auto dot = builder.AddInstruction(CreateDot(shape, while_hlo, while_hlo)); 603 auto computation = module->AddEntryComputation(builder.Build()); 604 605 EXPECT_TRUE(PropagatePrecision(module.get())); 606 607 EXPECT_EQ(computation->root_instruction(), dot); 608 EXPECT_TRUE( 609 ShapeUtil::Equal(cond_root->shape(), ShapeUtil::MakeShape(PRED, {}))); 610 EXPECT_TRUE(OutputsBF16(add)); 611 EXPECT_TRUE(OutputsBF16(body_dot)); 612 EXPECT_TRUE(OutputsBF16(body_param)); 613 EXPECT_TRUE(OutputsBF16(cond_param)); 614 EXPECT_FALSE(OutputsBF16(dot)); 615 } 616 617 // Tests that if the while condition prevents using BF16, no changes should be 618 // made to the while body and thus the fusion node inside it. 619 TEST_F(BFloat16PropagationTest, 620 ConditionPreventsPropagationForFusionInsideWhile) { 621 auto module = CreateNewVerifiedModule(); 622 auto builder = HloComputation::Builder(TestName()); 623 Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); 624 625 HloInstruction* param0 = builder.AddInstruction( 626 HloInstruction::CreateParameter(0, shape, "param0")); 627 HloInstruction* param1 = builder.AddInstruction( 628 HloInstruction::CreateParameter(1, shape, "param1")); 629 HloInstruction* add = builder.AddInstruction( 630 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); 631 632 auto builder_cond = HloComputation::Builder("cond"); 633 auto cond_param = builder_cond.AddInstruction( 634 HloInstruction::CreateParameter(0, shape, "cond_param")); 635 builder_cond.AddInstruction(HloInstruction::CreateCompare( 636 ShapeUtil::MakeShape(PRED, {}), 637 builder_cond.AddInstruction(HloInstruction::CreateReshape( 638 ShapeUtil::MakeShape(F32, {}), 639 builder_cond.AddInstruction(HloInstruction::CreateSlice( 640 ShapeUtil::MakeShape(F32, {1, 1}), cond_param, {0, 0}, {1, 1}, 641 {1, 1})))), 642 builder_cond.AddInstruction(HloInstruction::CreateReshape( 643 ShapeUtil::MakeShape(F32, {}), 644 builder_cond.AddInstruction(HloInstruction::CreateSlice( 645 ShapeUtil::MakeShape(F32, {1, 1}), cond_param, {1, 1}, {2, 2}, 646 {1, 1})))), 647 ComparisonDirection::kGt)); 648 auto cond = module->AddEmbeddedComputation(builder_cond.Build()); 649 650 auto builder_body = HloComputation::Builder("body"); 651 auto body_param = builder_body.AddInstruction( 652 HloInstruction::CreateParameter(0, shape, "body_param")); 653 auto body_transpose = builder_body.AddInstruction( 654 HloInstruction::CreateTranspose(shape, body_param, {0, 1})); 655 656 auto builder_f = HloComputation::Builder("fusion"); 657 HloInstruction* a_f = 658 builder_f.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); 659 builder_f.AddInstruction(HloInstruction::CreateTranspose(shape, a_f, {0, 1})); 660 auto comp_f = module->AddEmbeddedComputation(builder_f.Build()); 661 auto body_fusion = builder_body.AddInstruction(HloInstruction::CreateFusion( 662 shape, HloInstruction::FusionKind::kCustom, {body_transpose}, comp_f)); 663 auto body = module->AddEmbeddedComputation(builder_body.Build()); 664 665 auto while_hlo = builder.AddInstruction( 666 HloInstruction::CreateWhile(shape, cond, body, add)); 667 668 auto dot = builder.AddInstruction(CreateDot(shape, while_hlo, while_hlo)); 669 auto computation = module->AddEntryComputation(builder.Build()); 670 671 EXPECT_FALSE(PropagatePrecision(module.get())); 672 EXPECT_EQ(computation->root_instruction(), dot); 673 EXPECT_FALSE(OutputsBF16(add)); 674 EXPECT_FALSE(OutputsBF16(body_fusion)); 675 EXPECT_FALSE(OutputsBF16(body_param)); 676 EXPECT_FALSE(OutputsBF16(body_transpose)); 677 EXPECT_FALSE(OutputsBF16(a_f)); 678 } 679 680 // Tests that BF16 is propagated properly through while computations with 681 // tuple-shaped input/output. 682 TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) { 683 auto module = CreateNewVerifiedModule(); 684 auto builder = HloComputation::Builder(TestName()); 685 Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); 686 687 HloInstruction* param0 = builder.AddInstruction( 688 HloInstruction::CreateParameter(0, shape, "param0")); 689 HloInstruction* param1 = builder.AddInstruction( 690 HloInstruction::CreateParameter(1, shape, "param1")); 691 HloInstruction* add0 = builder.AddInstruction( 692 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); 693 HloInstruction* add1 = builder.AddInstruction( 694 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); 695 HloInstruction* tuple = 696 builder.AddInstruction(HloInstruction::CreateTuple({add0, add1})); 697 698 auto builder_cond = HloComputation::Builder("cond"); 699 auto cond_param = builder_cond.AddInstruction( 700 HloInstruction::CreateParameter(0, tuple->shape(), "cond_param")); 701 auto cond_lhs = builder_cond.AddInstruction( 702 HloInstruction::CreateGetTupleElement(shape, cond_param, 0)); 703 auto cond_rhs = builder_cond.AddInstruction( 704 HloInstruction::CreateGetTupleElement(shape, cond_param, 1)); 705 // This add should prevent RHS from using BF16 706 auto cond_add_rhs = builder_cond.AddInstruction( 707 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, cond_rhs, cond_rhs)); 708 auto cond_dot = 709 builder_cond.AddInstruction(CreateDot(shape, cond_lhs, cond_add_rhs)); 710 builder_cond.AddInstruction(HloInstruction::CreateCompare( 711 ShapeUtil::MakeShape(PRED, {}), 712 builder_cond.AddInstruction(HloInstruction::CreateReshape( 713 ShapeUtil::MakeShape(F32, {}), 714 builder_cond.AddInstruction( 715 HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}), 716 cond_dot, {0, 0}, {1, 1}, {1, 1})))), 717 builder_cond.AddInstruction(HloInstruction::CreateReshape( 718 ShapeUtil::MakeShape(F32, {}), 719 builder_cond.AddInstruction( 720 HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}), 721 cond_dot, {1, 1}, {2, 2}, {1, 1})))), 722 ComparisonDirection::kGt)); 723 auto cond = module->AddEmbeddedComputation(builder_cond.Build()); 724 725 auto builder_body = HloComputation::Builder("body"); 726 auto body_param = builder_body.AddInstruction( 727 HloInstruction::CreateParameter(0, tuple->shape(), "body_param")); 728 auto body_lhs = builder_body.AddInstruction( 729 HloInstruction::CreateGetTupleElement(shape, body_param, 0)); 730 auto body_rhs = builder_body.AddInstruction( 731 HloInstruction::CreateGetTupleElement(shape, body_param, 1)); 732 auto body_dot1 = 733 builder_body.AddInstruction(CreateDot(shape, body_lhs, body_rhs)); 734 auto body_dot2 = 735 builder_body.AddInstruction(CreateDot(shape, body_rhs, body_lhs)); 736 auto body_transpose = builder_body.AddInstruction( 737 HloInstruction::CreateTranspose(shape, body_dot2, {0, 1})); 738 builder_body.AddInstruction( 739 HloInstruction::CreateTuple({body_dot1, body_transpose})); 740 auto body = module->AddEmbeddedComputation(builder_body.Build()); 741 742 auto while_hlo = builder.AddInstruction( 743 HloInstruction::CreateWhile(tuple->shape(), cond, body, tuple)); 744 745 auto lhs = builder.AddInstruction( 746 HloInstruction::CreateGetTupleElement(shape, while_hlo, 0)); 747 auto rhs = builder.AddInstruction( 748 HloInstruction::CreateGetTupleElement(shape, while_hlo, 1)); 749 auto dot = builder.AddInstruction(CreateDot(shape, lhs, rhs)); 750 auto computation = module->AddEntryComputation(builder.Build()); 751 752 EXPECT_TRUE(PropagatePrecision(module.get())); 753 754 EXPECT_EQ(computation->root_instruction(), dot); 755 EXPECT_TRUE(OutputsBF16(lhs)); 756 EXPECT_FALSE(OutputsBF16(rhs)); 757 EXPECT_TRUE(OutputsBF16(body_dot1)); 758 EXPECT_TRUE(OutputsBF16(body_lhs)); 759 EXPECT_FALSE(OutputsBF16(body_rhs)); 760 EXPECT_FALSE(OutputsBF16(body_dot2)); 761 EXPECT_FALSE(OutputsBF16(body_transpose)); 762 EXPECT_TRUE(OutputsBF16(cond_lhs)); 763 EXPECT_FALSE(OutputsBF16(cond_rhs)); 764 EXPECT_TRUE(OutputsBF16(add0)); 765 EXPECT_FALSE(OutputsBF16(add1)); 766 } 767 768 // Tests that BF16 is not propagated through multiple whiles that invoke the 769 // same computation as long as one while prevents the propagation. 770 TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) { 771 auto module = CreateNewVerifiedModule(); 772 auto builder = HloComputation::Builder(TestName()); 773 Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); 774 775 HloInstruction* param0 = builder.AddInstruction( 776 HloInstruction::CreateParameter(0, shape, "param0")); 777 HloInstruction* param1 = builder.AddInstruction( 778 HloInstruction::CreateParameter(1, shape, "param1")); 779 HloInstruction* add0 = builder.AddInstruction( 780 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); 781 HloInstruction* add1 = builder.AddInstruction( 782 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); 783 HloInstruction* add2 = builder.AddInstruction( 784 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); 785 HloInstruction* add3 = builder.AddInstruction( 786 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); 787 HloInstruction* tuple0 = 788 builder.AddInstruction(HloInstruction::CreateTuple({add0, add1})); 789 HloInstruction* tuple1 = 790 builder.AddInstruction(HloInstruction::CreateTuple({add2, add3})); 791 792 // Condition computation for the first while. 793 auto builder_cond0 = HloComputation::Builder("cond0"); 794 auto cond0_param = builder_cond0.AddInstruction( 795 HloInstruction::CreateParameter(0, tuple0->shape(), "cond0_param")); 796 auto cond0_lhs = builder_cond0.AddInstruction( 797 HloInstruction::CreateGetTupleElement(shape, cond0_param, 0)); 798 auto cond0_rhs = builder_cond0.AddInstruction( 799 HloInstruction::CreateGetTupleElement(shape, cond0_param, 1)); 800 // This add should prevent RHS from using BF16 801 auto cond0_add_rhs = 802 builder_cond0.AddInstruction(HloInstruction::CreateBinary( 803 shape, HloOpcode::kAdd, cond0_rhs, cond0_rhs)); 804 auto cond0_dot = 805 builder_cond0.AddInstruction(CreateDot(shape, cond0_lhs, cond0_add_rhs)); 806 builder_cond0.AddInstruction(HloInstruction::CreateCompare( 807 ShapeUtil::MakeShape(PRED, {}), 808 builder_cond0.AddInstruction(HloInstruction::CreateReshape( 809 ShapeUtil::MakeShape(F32, {}), 810 builder_cond0.AddInstruction( 811 HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}), 812 cond0_dot, {0, 0}, {1, 1}, {1, 1})))), 813 builder_cond0.AddInstruction(HloInstruction::CreateReshape( 814 ShapeUtil::MakeShape(F32, {}), 815 builder_cond0.AddInstruction( 816 HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}), 817 cond0_dot, {1, 1}, {2, 2}, {1, 1})))), 818 ComparisonDirection::kGt)); 819 auto cond0 = module->AddEmbeddedComputation(builder_cond0.Build()); 820 821 // Condition computation for the second while. 822 auto builder_cond1 = HloComputation::Builder("cond1"); 823 auto cond1_param = builder_cond1.AddInstruction( 824 HloInstruction::CreateParameter(0, tuple1->shape(), "cond1_param")); 825 auto cond1_lhs = builder_cond1.AddInstruction( 826 HloInstruction::CreateGetTupleElement(shape, cond1_param, 0)); 827 auto cond1_rhs = builder_cond1.AddInstruction( 828 HloInstruction::CreateGetTupleElement(shape, cond1_param, 1)); 829 // This add should prevent LHS from using BF16 830 auto cond1_add_lhs = 831 builder_cond1.AddInstruction(HloInstruction::CreateBinary( 832 shape, HloOpcode::kAdd, cond1_lhs, cond1_lhs)); 833 auto cond1_dot = 834 builder_cond1.AddInstruction(CreateDot(shape, cond1_add_lhs, cond1_rhs)); 835 builder_cond1.AddInstruction(HloInstruction::CreateCompare( 836 ShapeUtil::MakeShape(PRED, {}), 837 builder_cond1.AddInstruction(HloInstruction::CreateReshape( 838 ShapeUtil::MakeShape(F32, {}), 839 builder_cond1.AddInstruction( 840 HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}), 841 cond1_dot, {0, 0}, {1, 1}, {1, 1})))), 842 builder_cond1.AddInstruction(HloInstruction::CreateReshape( 843 ShapeUtil::MakeShape(F32, {}), 844 builder_cond1.AddInstruction( 845 HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}), 846 cond1_dot, {1, 1}, {2, 2}, {1, 1})))), 847 ComparisonDirection::kGt)); 848 auto cond1 = module->AddEmbeddedComputation(builder_cond1.Build()); 849 850 // Body computation shared by both whiles. 851 auto builder_body = HloComputation::Builder("body"); 852 auto body_param = builder_body.AddInstruction( 853 HloInstruction::CreateParameter(0, tuple0->shape(), "body_param")); 854 auto body_lhs = builder_body.AddInstruction( 855 HloInstruction::CreateGetTupleElement(shape, body_param, 0)); 856 auto body_rhs = builder_body.AddInstruction( 857 HloInstruction::CreateGetTupleElement(shape, body_param, 1)); 858 auto body_dot = 859 builder_body.AddInstruction(CreateDot(shape, body_lhs, body_rhs)); 860 builder_body.AddInstruction( 861 HloInstruction::CreateTuple({body_dot, body_rhs})); 862 auto body = module->AddEmbeddedComputation(builder_body.Build()); 863 864 auto while0 = builder.AddInstruction( 865 HloInstruction::CreateWhile(tuple0->shape(), cond0, body, tuple0)); 866 auto while1 = builder.AddInstruction( 867 HloInstruction::CreateWhile(tuple1->shape(), cond1, body, tuple1)); 868 869 auto lhs = builder.AddInstruction( 870 CreateDot(shape, 871 builder.AddInstruction( 872 HloInstruction::CreateGetTupleElement(shape, while0, 0)), 873 builder.AddInstruction( 874 HloInstruction::CreateGetTupleElement(shape, while0, 1)))); 875 auto rhs = builder.AddInstruction( 876 CreateDot(shape, 877 builder.AddInstruction( 878 HloInstruction::CreateGetTupleElement(shape, while1, 0)), 879 builder.AddInstruction( 880 HloInstruction::CreateGetTupleElement(shape, while1, 1)))); 881 auto dot = builder.AddInstruction(CreateDot(shape, lhs, rhs)); 882 auto computation = module->AddEntryComputation(builder.Build()); 883 884 EXPECT_TRUE(PropagatePrecision(module.get())); 885 EXPECT_FALSE(OutputsBF16(body_dot)); 886 EXPECT_FALSE(OutputsBF16(body_rhs)); 887 EXPECT_FALSE(OutputsBF16(body_lhs)); 888 EXPECT_FALSE(OutputsBF16(cond0_lhs)); 889 EXPECT_FALSE(OutputsBF16(cond0_rhs)); 890 EXPECT_FALSE(OutputsBF16(cond1_lhs)); 891 EXPECT_FALSE(OutputsBF16(cond1_rhs)); 892 EXPECT_TRUE(OutputsBF16(cond0_add_rhs)); 893 EXPECT_TRUE(OutputsBF16(cond1_add_lhs)); 894 EXPECT_EQ(computation->root_instruction(), dot); 895 } 896 897 // Tests that if this pass turns an F32 -> BF16 conversion into a no-op (BF16 -> 898 // BF16 conversion), then it will remove that conversion. 899 TEST_F(BFloat16PropagationTest, NoopConversionRemoved) { 900 auto builder = HloComputation::Builder(TestName()); 901 Shape f32_shape = ShapeUtil::MakeShape(F32, {4, 4}); 902 Shape bf16_shape = ShapeUtil::MakeShape(BF16, {4, 4}); 903 904 HloInstruction* param = builder.AddInstruction( 905 HloInstruction::CreateParameter(0, f32_shape, "param")); 906 HloInstruction* add0 = builder.AddInstruction( 907 HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, param, param)); 908 HloInstruction* add1 = builder.AddInstruction( 909 HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, param, param)); 910 HloInstruction* tuple = 911 builder.AddInstruction(HloInstruction::CreateTuple({add0, add1})); 912 HloInstruction* gte0 = builder.AddInstruction( 913 HloInstruction::CreateGetTupleElement(f32_shape, tuple, 0)); 914 HloInstruction* gte1 = builder.AddInstruction( 915 HloInstruction::CreateGetTupleElement(f32_shape, tuple, 1)); 916 HloInstruction* convert0 = 917 builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, gte0)); 918 HloInstruction* convert1 = 919 builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, gte1)); 920 HloInstruction* add2 = builder.AddInstruction(HloInstruction::CreateBinary( 921 bf16_shape, HloOpcode::kAdd, convert0, convert1)); 922 923 auto module = CreateNewVerifiedModule(); 924 auto computation = module->AddEntryComputation(builder.Build()); 925 926 EXPECT_TRUE(PropagatePrecision(module.get())); 927 928 EXPECT_EQ(computation->root_instruction(), add2); 929 EXPECT_EQ(add2->operand(0), add0); 930 EXPECT_EQ(add2->operand(1), add1); 931 EXPECT_EQ(add0->shape().element_type(), BF16); 932 EXPECT_EQ(add1->shape().element_type(), BF16); 933 } 934 935 TEST_F(BFloat16PropagationTest, TupleDomain) { 936 auto builder = HloComputation::Builder(TestName()); 937 Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); 938 939 HloInstruction* a = 940 builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); 941 HloInstruction* b = 942 builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); 943 HloInstruction* a_trans = 944 builder.AddInstruction(HloInstruction::CreateTranspose(shape, a, {0, 1})); 945 HloInstruction* b_trans = 946 builder.AddInstruction(HloInstruction::CreateTranspose(shape, b, {0, 1})); 947 HloInstruction* tuple = 948 builder.AddInstruction(HloInstruction::CreateTuple({a_trans, b_trans})); 949 HloInstruction* domain = builder.AddInstruction( 950 HloInstruction::CreateDomain(tuple->shape(), tuple, nullptr, nullptr)); 951 HloInstruction* a_gte = builder.AddInstruction( 952 HloInstruction::CreateGetTupleElement(shape, domain, 0)); 953 HloInstruction* b_gte = builder.AddInstruction( 954 HloInstruction::CreateGetTupleElement(shape, domain, 1)); 955 HloInstruction* dot = builder.AddInstruction(CreateDot(shape, a_gte, b_gte)); 956 HloInstruction* root = builder.AddInstruction( 957 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot)); 958 959 auto module = CreateNewVerifiedModule(); 960 auto computation = module->AddEntryComputation(builder.Build()); 961 962 EXPECT_TRUE(PropagatePrecision(module.get())); 963 EXPECT_EQ(computation->root_instruction(), root); 964 965 // test BF16 propagated through domain 966 EXPECT_EQ(ShapeUtil::GetTupleElementShape(domain->shape(), 0).element_type(), 967 BF16); 968 EXPECT_EQ(ShapeUtil::GetTupleElementShape(domain->shape(), 1).element_type(), 969 BF16); 970 971 EXPECT_TRUE(OutputsBF16(a_trans)); 972 EXPECT_TRUE(OutputsBF16(b_trans)); 973 EXPECT_TRUE(OutputsBF16(a_gte)); 974 EXPECT_TRUE(OutputsBF16(b_gte)); 975 EXPECT_FALSE(OutputsBF16(a)); 976 EXPECT_FALSE(OutputsBF16(b)); 977 } 978 979 // Tests that bf16 is not propagated through a domain in case its input cannot 980 // be propagated. In the case below the input of the domain is the parameter 981 // tuple which cannot be propagated, so the domain instruction is not propagated 982 // either. 983 TEST_F(BFloat16PropagationTest, TupleDomainNoPropagation) { 984 auto builder = HloComputation::Builder(TestName()); 985 Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); 986 Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape}); 987 988 HloInstruction* param = builder.AddInstruction( 989 HloInstruction::CreateParameter(0, tuple_shape, "param")); 990 HloInstruction* domain = builder.AddInstruction( 991 HloInstruction::CreateDomain(param->shape(), param, nullptr, nullptr)); 992 HloInstruction* a_gte = builder.AddInstruction( 993 HloInstruction::CreateGetTupleElement(shape, domain, 0)); 994 HloInstruction* b_gte = builder.AddInstruction( 995 HloInstruction::CreateGetTupleElement(shape, domain, 1)); 996 HloInstruction* a_trans = builder.AddInstruction( 997 HloInstruction::CreateTranspose(shape, a_gte, {0, 1})); 998 HloInstruction* b_trans = builder.AddInstruction( 999 HloInstruction::CreateTranspose(shape, b_gte, {0, 1})); 1000 HloInstruction* dot = 1001 builder.AddInstruction(CreateDot(shape, a_trans, b_trans)); 1002 HloInstruction* root = builder.AddInstruction( 1003 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot)); 1004 1005 auto module = CreateNewVerifiedModule(); 1006 auto computation = module->AddEntryComputation(builder.Build()); 1007 1008 EXPECT_TRUE(PropagatePrecision(module.get())); 1009 1010 EXPECT_EQ(computation->root_instruction(), root); 1011 EXPECT_TRUE(OutputsBF16(a_trans)); 1012 EXPECT_TRUE(OutputsBF16(b_trans)); 1013 EXPECT_FALSE(OutputsBF16(a_gte)); 1014 EXPECT_FALSE(OutputsBF16(b_gte)); 1015 EXPECT_FALSE(OutputsBF16(domain)); 1016 EXPECT_FALSE(OutputsBF16(param)); 1017 } 1018 1019 } // namespace xla 1020