/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/bfloat16_conversion_folding.h"
#include "tensorflow/compiler/xla/service/bfloat16_support.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {

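// A fake BFloat16Support that reports BF16 operand and output support for
// add, subtract, tuple, and get-tuple-element, and mixed-precision support
// for add, tuple, and get-tuple-element (but not subtract). This lets the
// tests below exercise both the supported and unsupported paths of the pass.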
class TestBFloat16Support : public BFloat16Support {
 public:
  TestBFloat16Support() {}
  ~TestBFloat16Support() override {}

  bool SupportsBF16Operand(const HloInstruction& hlo,
                           int64 operand_index) const override {
    return hlo.opcode() == HloOpcode::kAdd ||
           hlo.opcode() == HloOpcode::kSubtract ||
           hlo.opcode() == HloOpcode::kTuple ||
           hlo.opcode() == HloOpcode::kGetTupleElement;
  }

  bool SupportsBF16Output(const HloInstruction& hlo) const override {
    return hlo.opcode() == HloOpcode::kAdd ||
           hlo.opcode() == HloOpcode::kSubtract ||
           hlo.opcode() == HloOpcode::kTuple ||
           hlo.opcode() == HloOpcode::kGetTupleElement;
  }

  bool SupportsMixedPrecisions(const HloInstruction& hlo) const override {
    return hlo.opcode() == HloOpcode::kAdd ||
           hlo.opcode() == HloOpcode::kTuple ||
           hlo.opcode() == HloOpcode::kGetTupleElement;
  }
};

class BFloat16ConversionFoldingTest : public HloTestBase {
 protected:
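  // Runs the BFloat16ConversionFolding pass on |module| using the fake
  // support policy above, and returns whether the pass changed the module.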
  bool FoldConversions(HloModule* module) {
    TestBFloat16Support bfloat16_support;
    BFloat16ConversionFolding fold(&bfloat16_support);
    StatusOr<bool> result = fold.Run(module);
    EXPECT_IS_OK(result.status());
    return result.ValueOrDie();
  }
};

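// An f32 add whose output round-trips through BF16 before feeding another add
// should be folded: add supports BF16 operands, BF16 outputs, and mixed
// precision, so the pass can drop the converts and let the adds produce BF16
// directly.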
TEST_F(BFloat16ConversionFoldingTest, FoldIfSupported) {
  auto builder = HloComputation::Builder(TestName());
  Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
  Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});

  HloInstruction* a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_shape, "a"));
  HloInstruction* b = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32_shape, "b"));
  HloInstruction* c = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32_shape, "c"));

  HloInstruction* add0 = builder.AddInstruction(
      HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, a, b));
  HloInstruction* convert0 =
      builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, add0));
  HloInstruction* convert1 = builder.AddInstruction(
      HloInstruction::CreateConvert(f32_shape, convert0));

  HloInstruction* add1 = builder.AddInstruction(
      HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, convert1, c));
  builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, add1));

  auto module = CreateNewModule();
  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_TRUE(FoldConversions(module.get()));

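  // The converts around add0 and the convert after add1 should be folded:
  // add0 now produces BF16 and feeds add1 directly, and add1 (now BF16)
  // becomes the root in place of the final convert.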
  EXPECT_EQ(computation->root_instruction(), add1);
  EXPECT_EQ(add0->shape().element_type(), BF16);
  EXPECT_EQ(add1->shape().element_type(), BF16);
  EXPECT_EQ(add1->operand(0), add0);
}

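// Multiply has no BF16 support in TestBFloat16Support, so the converts around
// the multiplies must be left in place.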
TEST_F(BFloat16ConversionFoldingTest, DoNotFoldIfUnsupported) {
  auto builder = HloComputation::Builder(TestName());
  Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
  Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});

  HloInstruction* a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_shape, "a"));
  HloInstruction* b = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32_shape, "b"));
  HloInstruction* c = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32_shape, "c"));

  HloInstruction* mul0 = builder.AddInstruction(
      HloInstruction::CreateBinary(f32_shape, HloOpcode::kMultiply, a, b));
  HloInstruction* convert0 =
      builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, mul0));
  HloInstruction* convert1 = builder.AddInstruction(
      HloInstruction::CreateConvert(f32_shape, convert0));

  HloInstruction* mul1 = builder.AddInstruction(HloInstruction::CreateBinary(
      f32_shape, HloOpcode::kMultiply, convert1, c));
  HloInstruction* convert2 =
      builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, mul1));

  auto module = CreateNewModule();
  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_FALSE(FoldConversions(module.get()));

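  // The graph is unchanged: both multiplies stay F32 and all three converts
  // remain, with convert2 still the root.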
  EXPECT_EQ(computation->root_instruction(), convert2);
  EXPECT_EQ(mul0->shape().element_type(), F32);
  EXPECT_EQ(mul1->shape().element_type(), F32);
  EXPECT_EQ(mul1->operand(0), convert1);
}

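// Subtract supports BF16 operands and outputs but not mixed precision, so
// folding a single convert would create an unsupported mixed-precision
// subtract; the pass must leave the graph unchanged.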
TEST_F(BFloat16ConversionFoldingTest, DoNotFoldUnsupportedMixedPrecision) {
  auto builder = HloComputation::Builder(TestName());
  Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
  Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});

  HloInstruction* a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_shape, "a"));
  HloInstruction* b = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32_shape, "b"));
  HloInstruction* c = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32_shape, "c"));

  HloInstruction* sub0 = builder.AddInstruction(
      HloInstruction::CreateBinary(f32_shape, HloOpcode::kSubtract, a, b));
  HloInstruction* convert0 =
      builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, sub0));
  HloInstruction* convert1 = builder.AddInstruction(
      HloInstruction::CreateConvert(f32_shape, convert0));

  HloInstruction* sub1 = builder.AddInstruction(HloInstruction::CreateBinary(
      f32_shape, HloOpcode::kSubtract, convert1, c));
  HloInstruction* convert2 =
      builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, sub1));

  auto module = CreateNewModule();
  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_FALSE(FoldConversions(module.get()));

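  // As in the unsupported case above, nothing is folded: the subtracts stay
  // F32 and convert2 remains the root.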
  EXPECT_EQ(computation->root_instruction(), convert2);
  EXPECT_EQ(sub0->shape().element_type(), F32);
  EXPECT_EQ(sub1->shape().element_type(), F32);
  EXPECT_EQ(sub1->operand(0), convert1);
}

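// Tuple and get-tuple-element nominally support BF16 in TestBFloat16Support,
// but the pass is still expected to leave converts that flow through them in
// place rather than changing tuple element types.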
TEST_F(BFloat16ConversionFoldingTest, DoNotFoldTuple) {
  auto builder = HloComputation::Builder(TestName());
  Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
  Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});

  HloInstruction* a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_shape, "a"));
  HloInstruction* b = builder.AddInstruction(
      HloInstruction::CreateParameter(1, bf16_shape, "b"));
  HloInstruction* convert0 =
      builder.AddInstruction(HloInstruction::CreateConvert(f32_shape, b));

  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({a, convert0}));
  HloInstruction* gte = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32_shape, tuple, 0));
  HloInstruction* convert1 =
      builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, gte));

  auto module = CreateNewModule();
  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_FALSE(FoldConversions(module.get()));

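  // The convert of b stays as tuple operand 1, gte keeps its F32 shape, and
  // convert1 remains the root.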
  EXPECT_EQ(computation->root_instruction(), convert1);
  EXPECT_EQ(gte->shape().element_type(), F32);
  EXPECT_EQ(tuple->operand(1), convert0);
}

}  // namespace xla