/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/bfloat16_normalization.h"
#include "tensorflow/compiler/xla/service/bfloat16_support.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
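// A fake BFloat16Support used by the tests below: BF16 operands and outputs
// are supported only for add, subtract, reduce, tuple, and get-tuple-element,
// and mixed precision is allowed only for add, tuple, and get-tuple-element.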
class TestBFloat16Support : public BFloat16Support {
 public:
  TestBFloat16Support() {}
  ~TestBFloat16Support() override {}

  bool SupportsBF16Operand(const HloInstruction& hlo,
                           int64 operand_index) const override {
    if (hlo.opcode() == HloOpcode::kAdd ||
        hlo.opcode() == HloOpcode::kSubtract ||
        hlo.opcode() == HloOpcode::kReduce ||
        hlo.opcode() == HloOpcode::kTuple ||
        hlo.opcode() == HloOpcode::kGetTupleElement) {
      return true;
    }
    return false;
  }

  bool SupportsBF16Output(const HloInstruction& hlo) const override {
    if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kReduce ||
        hlo.opcode() == HloOpcode::kSubtract ||
        hlo.opcode() == HloOpcode::kTuple ||
        hlo.opcode() == HloOpcode::kGetTupleElement) {
      return true;
    }
    return false;
  }

  bool SupportsMixedPrecisions(const HloInstruction& hlo) const override {
    if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kTuple ||
        hlo.opcode() == HloOpcode::kGetTupleElement) {
      return true;
    }
    return false;
  }
};
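// Test fixture whose Normalize() helper runs BFloat16Normalization with the
// fake support above and returns whether the pass changed the module.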
class BFloat16NormalizationTest : public HloTestBase {
 protected:
  bool Normalize(HloModule* module) {
    TestBFloat16Support bfloat16_support_;
    BFloat16Normalization normalization(&bfloat16_support_);
    StatusOr<bool> result = normalization.Run(module);
    EXPECT_IS_OK(result.status());
    return result.ValueOrDie();
  }
};
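// An add with a BF16 output and mixed F32/BF16 operands is fully supported, so
// the pass should not change the module.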
TEST_F(BFloat16NormalizationTest, NoopIfSupported) {
  auto builder = HloComputation::Builder(TestName());
  Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
  Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});

  HloInstruction* a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_shape, "a"));
  HloInstruction* b = builder.AddInstruction(
      HloInstruction::CreateParameter(1, bf16_shape, "b"));
  HloInstruction* c = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32_shape, "c"));

  HloInstruction* add0 = builder.AddInstruction(
      HloInstruction::CreateBinary(bf16_shape, HloOpcode::kAdd, a, b));

  HloInstruction* add1 = builder.AddInstruction(
      HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, add0, c));

  auto module = CreateNewModule();
  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_FALSE(Normalize(module.get()));

  EXPECT_EQ(computation->root_instruction(), add1);
  EXPECT_EQ(add0->shape().element_type(), BF16);
  EXPECT_EQ(add1->shape().element_type(), F32);
}
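// Multiply has no BF16 support at all, so both multiplies must be rewritten in
// F32, with a convert back to BF16 inserted at the root.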
TEST_F(BFloat16NormalizationTest, ResolveIfUnsupportedBF16) {
  auto builder = HloComputation::Builder(TestName());
  Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
  Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});

  HloInstruction* a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_shape, "a"));
  HloInstruction* b = builder.AddInstruction(
      HloInstruction::CreateParameter(1, bf16_shape, "b"));
  HloInstruction* c = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32_shape, "c"));

  HloInstruction* mul0 = builder.AddInstruction(
      HloInstruction::CreateBinary(bf16_shape, HloOpcode::kMultiply, a, b));

  HloInstruction* mul1 = builder.AddInstruction(
      HloInstruction::CreateBinary(bf16_shape, HloOpcode::kMultiply, mul0, c));

  auto module = CreateNewModule();
  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_TRUE(Normalize(module.get()));

  EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert);
  EXPECT_EQ(computation->root_instruction()->operand(0), mul1);
  EXPECT_EQ(mul0->shape().element_type(), F32);
  EXPECT_EQ(mul1->shape().element_type(), F32);
  EXPECT_EQ(mul1->operand(0)->opcode(), HloOpcode::kConvert);
}
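// Subtract supports pure BF16 but not mixed precision, so the mixed-precision
// subtracts must be normalized to F32 with a convert back to BF16 at the root.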
TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionSubtraction) {
  auto builder = HloComputation::Builder(TestName());
  Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
  Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});

  HloInstruction* a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_shape, "a"));
  HloInstruction* b = builder.AddInstruction(
      HloInstruction::CreateParameter(1, bf16_shape, "b"));
  HloInstruction* c = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32_shape, "c"));

  HloInstruction* sub0 = builder.AddInstruction(
      HloInstruction::CreateBinary(bf16_shape, HloOpcode::kSubtract, a, b));

  HloInstruction* sub1 = builder.AddInstruction(
      HloInstruction::CreateBinary(bf16_shape, HloOpcode::kSubtract, sub0, c));

  auto module = CreateNewModule();
  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_TRUE(Normalize(module.get()));

  EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert);
  EXPECT_EQ(computation->root_instruction()->operand(0), sub1);
  EXPECT_EQ(sub0->shape().element_type(), F32);
  EXPECT_EQ(sub1->shape().element_type(), F32);
  EXPECT_EQ(sub1->operand(0)->opcode(), HloOpcode::kConvert);
}
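// Reduce supports BF16 but not mixed precision, so the BF16 init value and the
// BF16 reduce computation must be converted to F32 to match the F32 input.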
TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) {
  Shape f32_input_shape = ShapeUtil::MakeShape(F32, {2, 4});
  Shape f32_output_shape = ShapeUtil::MakeShape(F32, {4});

  Shape bf16_scalar_shape = ShapeUtil::MakeShape(BF16, {});

  auto reduce_comp_builder = HloComputation::Builder("reduce_comp");
  auto reduce_comp_param0 = reduce_comp_builder.AddInstruction(
      HloInstruction::CreateParameter(0, bf16_scalar_shape, "param0"));
  auto reduce_comp_param1 = reduce_comp_builder.AddInstruction(
      HloInstruction::CreateParameter(1, bf16_scalar_shape, "param1"));
  reduce_comp_builder.AddInstruction(
      HloInstruction::CreateBinary(bf16_scalar_shape, HloOpcode::kAdd,
                                   reduce_comp_param0, reduce_comp_param1));

  auto module = CreateNewModule();
  auto reduce_computation =
      module->AddEmbeddedComputation(reduce_comp_builder.Build());

  auto builder = HloComputation::Builder(TestName());
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_input_shape, "a"));
  HloInstruction* init = builder.AddInstruction(
      HloInstruction::CreateParameter(1, bf16_scalar_shape, "init"));
  HloInstruction* reduce = builder.AddInstruction(HloInstruction::CreateReduce(
      f32_output_shape, input, init, {0}, reduce_computation));

  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_TRUE(Normalize(module.get()));

  EXPECT_EQ(computation->root_instruction(), reduce);
  EXPECT_EQ(reduce->called_computations().size(), 1);
  EXPECT_EQ(reduce->called_computations()[0]->num_parameters(), 2);
  EXPECT_EQ(reduce->called_computations()[0]
                ->parameter_instruction(0)
                ->shape()
                .element_type(),
            F32);
  EXPECT_EQ(reduce->called_computations()[0]
                ->parameter_instruction(1)
                ->shape()
                .element_type(),
            F32);
  EXPECT_EQ(reduce->called_computations()[0]
                ->root_instruction()
                ->shape()
                .element_type(),
            F32);
  EXPECT_EQ(reduce->shape().element_type(), F32);
  EXPECT_EQ(reduce->operand(0), input);
  EXPECT_EQ(input->shape().element_type(), F32);
  EXPECT_EQ(reduce->operand(1)->opcode(), HloOpcode::kConvert);
  EXPECT_EQ(reduce->operand(1)->shape().element_type(), F32);
}
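// Cross-replica-sum has no BF16 support here, so its BF16 tuple element must
// be normalized to F32 on both the operand and the result; get-tuple-element
// supports mixed precision, so it may still produce BF16 from the F32 element.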
TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) {
  auto builder = HloComputation::Builder(TestName());
  Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
  Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});

  HloInstruction* a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_shape, "a"));
  HloInstruction* b = builder.AddInstruction(
      HloInstruction::CreateParameter(1, bf16_shape, "b"));

  HloInstruction* crs =
      builder.AddInstruction(HloInstruction::CreateCrossReplicaSum(
          ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b}));
  HloInstruction* gte = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(bf16_shape, crs, 1));

  auto module = CreateNewModule();
  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_TRUE(Normalize(module.get()));

  EXPECT_EQ(computation->root_instruction(), gte);
  EXPECT_EQ(gte->shape().element_type(), BF16);
  EXPECT_EQ(crs->operand(1)->shape().element_type(), F32);
  EXPECT_EQ(ShapeUtil::GetSubshape(crs->shape(), {1}).element_type(), F32);
}
}  // namespace xla