1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/bfloat16_conversion_folding.h" 17 #include "tensorflow/compiler/xla/service/bfloat16_support.h" 18 #include "tensorflow/compiler/xla/service/hlo_computation.h" 19 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 20 #include "tensorflow/compiler/xla/service/hlo_module.h" 21 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 22 #include "tensorflow/compiler/xla/shape_util.h" 23 #include "tensorflow/compiler/xla/test.h" 24 #include "tensorflow/compiler/xla/test_helpers.h" 25 #include "tensorflow/compiler/xla/tests/hlo_test_base.h" 26 #include "tensorflow/compiler/xla/xla_data.pb.h" 27 28 namespace xla { 29 30 class TestBFloat16Support : public BFloat16Support { 31 public: 32 TestBFloat16Support() {} 33 ~TestBFloat16Support() override {} 34 35 bool SupportsBF16Operand(const HloInstruction& hlo, 36 int64 operand_index) const override { 37 if (hlo.opcode() == HloOpcode::kAdd || 38 hlo.opcode() == HloOpcode::kSubtract || 39 hlo.opcode() == HloOpcode::kTuple || 40 hlo.opcode() == HloOpcode::kGetTupleElement) { 41 return true; 42 } 43 return false; 44 } 45 46 bool SupportsBF16Output(const HloInstruction& hlo) const override { 47 if (hlo.opcode() == HloOpcode::kAdd || 48 hlo.opcode() == HloOpcode::kSubtract || 49 hlo.opcode() == HloOpcode::kTuple || 50 hlo.opcode() == HloOpcode::kGetTupleElement) { 51 return true; 52 } 53 return false; 54 } 55 56 bool SupportsMixedPrecisions(const HloInstruction& hlo) const override { 57 if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kTuple || 58 hlo.opcode() == HloOpcode::kGetTupleElement) { 59 return true; 60 } 61 return false; 62 } 63 }; 64 65 class BFloat16ConversionFoldingTest : public HloTestBase { 66 protected: 67 bool FoldConversions(HloModule* module) { 68 TestBFloat16Support bfloat16_support_; 69 BFloat16ConversionFolding fold(&bfloat16_support_); 70 StatusOr<bool> result = fold.Run(module); 71 EXPECT_IS_OK(result.status()); 72 return result.ValueOrDie(); 73 } 74 }; 75 76 TEST_F(BFloat16ConversionFoldingTest, FoldIfSupported) { 77 auto builder = HloComputation::Builder(TestName()); 78 Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); 79 Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); 80 81 HloInstruction* a = builder.AddInstruction( 82 HloInstruction::CreateParameter(0, f32_shape, "a")); 83 HloInstruction* b = builder.AddInstruction( 84 HloInstruction::CreateParameter(1, f32_shape, "b")); 85 HloInstruction* c = builder.AddInstruction( 86 HloInstruction::CreateParameter(2, f32_shape, "c")); 87 88 HloInstruction* add0 = builder.AddInstruction( 89 HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, a, b)); 90 HloInstruction* convert0 = 91 builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, add0)); 92 HloInstruction* convert1 = builder.AddInstruction( 93 HloInstruction::CreateConvert(f32_shape, convert0)); 94 95 HloInstruction* add1 = builder.AddInstruction( 96 HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, convert1, c)); 97 builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, add1)); 98 99 auto module = CreateNewModule(); 100 auto computation = module->AddEntryComputation(builder.Build()); 101 102 EXPECT_TRUE(FoldConversions(module.get())); 103 104 EXPECT_EQ(computation->root_instruction(), add1); 105 EXPECT_EQ(add0->shape().element_type(), BF16); 106 EXPECT_EQ(add1->shape().element_type(), BF16); 107 EXPECT_EQ(add1->operand(0), add0); 108 } 109 110 TEST_F(BFloat16ConversionFoldingTest, DoNotFoldIfUnsupported) { 111 auto builder = HloComputation::Builder(TestName()); 112 Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); 113 Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); 114 115 HloInstruction* a = builder.AddInstruction( 116 HloInstruction::CreateParameter(0, f32_shape, "a")); 117 HloInstruction* b = builder.AddInstruction( 118 HloInstruction::CreateParameter(1, f32_shape, "b")); 119 HloInstruction* c = builder.AddInstruction( 120 HloInstruction::CreateParameter(2, f32_shape, "c")); 121 122 HloInstruction* mul0 = builder.AddInstruction( 123 HloInstruction::CreateBinary(f32_shape, HloOpcode::kMultiply, a, b)); 124 HloInstruction* convert0 = 125 builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, mul0)); 126 HloInstruction* convert1 = builder.AddInstruction( 127 HloInstruction::CreateConvert(f32_shape, convert0)); 128 129 HloInstruction* mul1 = builder.AddInstruction(HloInstruction::CreateBinary( 130 f32_shape, HloOpcode::kMultiply, convert1, c)); 131 HloInstruction* convert2 = 132 builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, mul1)); 133 134 auto module = CreateNewModule(); 135 auto computation = module->AddEntryComputation(builder.Build()); 136 137 EXPECT_FALSE(FoldConversions(module.get())); 138 139 EXPECT_EQ(computation->root_instruction(), convert2); 140 EXPECT_EQ(mul0->shape().element_type(), F32); 141 EXPECT_EQ(mul1->shape().element_type(), F32); 142 EXPECT_EQ(mul1->operand(0), convert1); 143 } 144 145 TEST_F(BFloat16ConversionFoldingTest, DoNotFoldUnsupportedMixedPrecision) { 146 auto builder = HloComputation::Builder(TestName()); 147 Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); 148 Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); 149 150 HloInstruction* a = builder.AddInstruction( 151 HloInstruction::CreateParameter(0, f32_shape, "a")); 152 HloInstruction* b = builder.AddInstruction( 153 HloInstruction::CreateParameter(1, f32_shape, "b")); 154 HloInstruction* c = builder.AddInstruction( 155 HloInstruction::CreateParameter(2, f32_shape, "c")); 156 157 HloInstruction* sub0 = builder.AddInstruction( 158 HloInstruction::CreateBinary(f32_shape, HloOpcode::kSubtract, a, b)); 159 HloInstruction* convert0 = 160 builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, sub0)); 161 HloInstruction* convert1 = builder.AddInstruction( 162 HloInstruction::CreateConvert(f32_shape, convert0)); 163 164 HloInstruction* sub1 = builder.AddInstruction(HloInstruction::CreateBinary( 165 f32_shape, HloOpcode::kSubtract, convert1, c)); 166 HloInstruction* convert2 = 167 builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, sub1)); 168 169 auto module = CreateNewModule(); 170 auto computation = module->AddEntryComputation(builder.Build()); 171 172 EXPECT_FALSE(FoldConversions(module.get())); 173 174 EXPECT_EQ(computation->root_instruction(), convert2); 175 EXPECT_EQ(sub0->shape().element_type(), F32); 176 EXPECT_EQ(sub1->shape().element_type(), F32); 177 EXPECT_EQ(sub1->operand(0), convert1); 178 } 179 180 TEST_F(BFloat16ConversionFoldingTest, DoNotFoldTuple) { 181 auto builder = HloComputation::Builder(TestName()); 182 Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); 183 Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); 184 185 HloInstruction* a = builder.AddInstruction( 186 HloInstruction::CreateParameter(0, f32_shape, "a")); 187 HloInstruction* b = builder.AddInstruction( 188 HloInstruction::CreateParameter(1, bf16_shape, "b")); 189 HloInstruction* convert0 = 190 builder.AddInstruction(HloInstruction::CreateConvert(f32_shape, b)); 191 192 HloInstruction* tuple = 193 builder.AddInstruction(HloInstruction::CreateTuple({a, convert0})); 194 HloInstruction* gte = builder.AddInstruction( 195 HloInstruction::CreateGetTupleElement(f32_shape, tuple, 0)); 196 HloInstruction* convert1 = 197 builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, gte)); 198 199 auto module = CreateNewModule(); 200 auto computation = module->AddEntryComputation(builder.Build()); 201 202 EXPECT_FALSE(FoldConversions(module.get())); 203 204 EXPECT_EQ(computation->root_instruction(), convert1); 205 EXPECT_EQ(gte->shape().element_type(), F32); 206 EXPECT_EQ(tuple->operand(1), convert0); 207 } 208 209 } // namespace xla 210