/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/bfloat16_normalization.h"
#include "tensorflow/compiler/xla/service/bfloat16_support.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {

// A test-only BFloat16Support that reports BF16 support for a fixed set of
// opcodes, so the tests below can exercise both the pass's no-op path and its
// convert-inserting path.
class TestBFloat16Support : public BFloat16Support {
 public:
  TestBFloat16Support() {}
  ~TestBFloat16Support() override {}

  bool SupportsBF16Operand(const HloInstruction& hlo,
                           int64 operand_index) const override {
    if (hlo.opcode() == HloOpcode::kAdd ||
        hlo.opcode() == HloOpcode::kSubtract ||
        hlo.opcode() == HloOpcode::kReduce ||
        hlo.opcode() == HloOpcode::kTuple ||
        hlo.opcode() == HloOpcode::kGetTupleElement) {
      return true;
    }
    return false;
  }

  bool SupportsBF16Output(const HloInstruction& hlo) const override {
    if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kReduce ||
        hlo.opcode() == HloOpcode::kSubtract ||
        hlo.opcode() == HloOpcode::kTuple ||
        hlo.opcode() == HloOpcode::kGetTupleElement) {
      return true;
    }
    return false;
  }

  bool SupportsMixedPrecisions(const HloInstruction& hlo) const override {
    if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kTuple ||
        hlo.opcode() == HloOpcode::kGetTupleElement) {
      return true;
    }
    return false;
  }
};

class BFloat16NormalizationTest : public HloTestBase {
 protected:
  // Runs BFloat16Normalization on the module and returns whether the pass
  // changed anything.
  bool Normalize(HloModule* module) {
    TestBFloat16Support bfloat16_support_;
    BFloat16Normalization normalization(&bfloat16_support_);
    StatusOr<bool> result = normalization.Run(module);
    EXPECT_IS_OK(result.status());
    return result.ValueOrDie();
  }
};

// The pass should leave the module unchanged when every BF16 operand and
// output is supported.
TEST_F(BFloat16NormalizationTest, NoopIfSupported) {
  auto builder = HloComputation::Builder(TestName());
  Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
  Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});

  HloInstruction* a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_shape, "a"));
  HloInstruction* b = builder.AddInstruction(
      HloInstruction::CreateParameter(1, bf16_shape, "b"));
  HloInstruction* c = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32_shape, "c"));

  HloInstruction* add0 = builder.AddInstruction(
      HloInstruction::CreateBinary(bf16_shape, HloOpcode::kAdd, a, b));

  HloInstruction* add1 = builder.AddInstruction(
      HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, add0, c));

  auto module = CreateNewModule();
  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_FALSE(Normalize(module.get()));

  EXPECT_EQ(computation->root_instruction(), add1);
  EXPECT_EQ(add0->shape().element_type(), BF16);
  EXPECT_EQ(add1->shape().element_type(), F32);
}

// Multiply has no BF16 support in TestBFloat16Support, so the pass must
// compute it in F32 and insert converts around it.
TEST_F(BFloat16NormalizationTest, ResolveIfUnsupportedBF16) {
  auto builder = HloComputation::Builder(TestName());
  Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
  Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});

  HloInstruction* a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_shape, "a"));
  HloInstruction* b = builder.AddInstruction(
      HloInstruction::CreateParameter(1, bf16_shape, "b"));
  HloInstruction* c = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32_shape, "c"));

  HloInstruction* mul0 = builder.AddInstruction(
      HloInstruction::CreateBinary(bf16_shape, HloOpcode::kMultiply, a, b));

  HloInstruction* mul1 = builder.AddInstruction(
      HloInstruction::CreateBinary(bf16_shape, HloOpcode::kMultiply, mul0, c));

  auto module = CreateNewModule();
  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_TRUE(Normalize(module.get()));

  EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert);
  EXPECT_EQ(computation->root_instruction()->operand(0), mul1);
  EXPECT_EQ(mul0->shape().element_type(), F32);
  EXPECT_EQ(mul1->shape().element_type(), F32);
  EXPECT_EQ(mul1->operand(0)->opcode(), HloOpcode::kConvert);
}

// Subtract supports BF16 but not mixed precision, so mixed BF16/F32 operands
// must be normalized to a single precision.
TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionSubtraction) {
  auto builder = HloComputation::Builder(TestName());
  Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
  Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});

  HloInstruction* a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_shape, "a"));
  HloInstruction* b = builder.AddInstruction(
      HloInstruction::CreateParameter(1, bf16_shape, "b"));
  HloInstruction* c = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32_shape, "c"));

  HloInstruction* sub0 = builder.AddInstruction(
      HloInstruction::CreateBinary(bf16_shape, HloOpcode::kSubtract, a, b));

  HloInstruction* sub1 = builder.AddInstruction(
      HloInstruction::CreateBinary(bf16_shape, HloOpcode::kSubtract, sub0, c));

  auto module = CreateNewModule();
  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_TRUE(Normalize(module.get()));

  EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert);
  EXPECT_EQ(computation->root_instruction()->operand(0), sub1);
  EXPECT_EQ(sub0->shape().element_type(), F32);
  EXPECT_EQ(sub1->shape().element_type(), F32);
  EXPECT_EQ(sub1->operand(0)->opcode(), HloOpcode::kConvert);
}

// Reduce supports BF16 but not mixed precision, so a BF16 init value (and the
// BF16 reduce computation) mixed with an F32 input must be converted to F32.
TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) {
  Shape f32_input_shape = ShapeUtil::MakeShape(F32, {2, 4});
  Shape f32_output_shape = ShapeUtil::MakeShape(F32, {4});

  // Reduce requires a scalar init value and a scalar reduce computation.
  Shape bf16_scalar_shape = ShapeUtil::MakeShape(BF16, {});

  auto reduce_comp_builder = HloComputation::Builder("reduce_comp");
  auto reduce_comp_param0 = reduce_comp_builder.AddInstruction(
      HloInstruction::CreateParameter(0, bf16_scalar_shape, "param0"));
  auto reduce_comp_param1 = reduce_comp_builder.AddInstruction(
      HloInstruction::CreateParameter(1, bf16_scalar_shape, "param1"));
  reduce_comp_builder.AddInstruction(
      HloInstruction::CreateBinary(bf16_scalar_shape, HloOpcode::kAdd,
                                   reduce_comp_param0, reduce_comp_param1));

  auto module = CreateNewModule();
  auto reduce_computation =
      module->AddEmbeddedComputation(reduce_comp_builder.Build());

  auto builder = HloComputation::Builder(TestName());
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_input_shape, "a"));
  HloInstruction* init = builder.AddInstruction(
      HloInstruction::CreateParameter(1, bf16_scalar_shape, "init"));
  HloInstruction* reduce = builder.AddInstruction(HloInstruction::CreateReduce(
      f32_output_shape, input, init, {0}, reduce_computation));

  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_TRUE(Normalize(module.get()));

  EXPECT_EQ(computation->root_instruction(), reduce);
  EXPECT_EQ(reduce->called_computations().size(), 1);
  EXPECT_EQ(reduce->called_computations()[0]->num_parameters(), 2);
  EXPECT_EQ(reduce->called_computations()[0]
                ->parameter_instruction(0)
                ->shape()
                .element_type(),
            F32);
  EXPECT_EQ(reduce->called_computations()[0]
                ->parameter_instruction(1)
                ->shape()
                .element_type(),
            F32);
  EXPECT_EQ(reduce->called_computations()[0]
                ->root_instruction()
                ->shape()
                .element_type(),
            F32);
  EXPECT_EQ(reduce->shape().element_type(), F32);
  EXPECT_EQ(reduce->operand(0), input);
  EXPECT_EQ(input->shape().element_type(), F32);
  EXPECT_EQ(reduce->operand(1)->opcode(), HloOpcode::kConvert);
  EXPECT_EQ(reduce->operand(1)->shape().element_type(), F32);
}

// CrossReplicaSum has no BF16 support in TestBFloat16Support, so its BF16
// operand and tuple element are converted to F32; the BF16 result is then
// recovered through the mixed-precision GetTupleElement.
TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) {
  auto builder = HloComputation::Builder(TestName());
  Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
  Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});

  HloInstruction* a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_shape, "a"));
  HloInstruction* b = builder.AddInstruction(
      HloInstruction::CreateParameter(1, bf16_shape, "b"));

  HloInstruction* crs =
      builder.AddInstruction(HloInstruction::CreateCrossReplicaSum(
          ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b}));
  HloInstruction* gte = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(bf16_shape, crs, 1));

  auto module = CreateNewModule();
  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_TRUE(Normalize(module.get()));

  EXPECT_EQ(computation->root_instruction(), gte);
  EXPECT_EQ(gte->shape().element_type(), BF16);
  EXPECT_EQ(crs->operand(1)->shape().element_type(), F32);
  EXPECT_EQ(ShapeUtil::GetSubshape(crs->shape(), {1}).element_type(), F32);
}

}  // namespace xla