Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
     17 
     18 #include <memory>
     19 #include <utility>
     20 
     21 #include "tensorflow/compiler/xla/client/client.h"
     22 #include "tensorflow/compiler/xla/client/client_library.h"
     23 #include "tensorflow/compiler/xla/client/computation.h"
     24 #include "tensorflow/compiler/xla/client/computation_builder.h"
     25 #include "tensorflow/compiler/xla/client/local_client.h"
     26 #include "tensorflow/compiler/xla/client/padding.h"
     27 #include "tensorflow/compiler/xla/service/computation_tracker.h"
     28 #include "tensorflow/compiler/xla/service/hlo_module.h"
     29 #include "tensorflow/compiler/xla/service/local_service.h"
     30 #include "tensorflow/compiler/xla/service/service.h"
     31 #include "tensorflow/compiler/xla/service/user_computation.h"
     32 #include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
     33 #include "tensorflow/compiler/xla/shape_util.h"
     34 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
     35 #include "tensorflow/core/platform/logging.h"
     36 
     37 #include "tensorflow/compiler/xla/statusor.h"
     38 #include "tensorflow/compiler/xla/test_helpers.h"
     39 
     40 namespace xla {
     41 namespace {
     42 
     43 constexpr int64 kPointerSize = 8;
     44 
     45 int64 ShapeSize(const Shape& shape) {
     46   return ShapeUtil::ByteSizeOf(shape, kPointerSize);
     47 }
     48 
     49 // This test suite tests the HLO cost analysis by first building a computation
     50 // using the client computation builder and running the HloCostAnalysis that
     51 // returns the number of floating point and transcendental operations in the
     52 // graph. We test both individual HLO operations as well as a mixed graph.
     53 class HloCostAnalysisTest : public ::testing::Test {
     54  protected:
     55   HloCostAnalysisTest()
     56       : client_(ClientLibrary::LocalClientOrDie()),
     57         // Accessing service instance is required for the unit tests to enable
     58         // whitebox accesses to the user computation built from the client,
     59         // as shown in the BuildHloGraph functions below.
     60         service_(static_cast<Service*>(ClientLibrary::GetXlaService(
     61             static_cast<LocalClient*>(client_)->platform()))),
     62         computation_tracker_(service_->computation_tracker()) {
     63     // Create a computation for a unary user function: x => exp(x + 0.5)
     64     {
     65       ComputationBuilder builder(client_, "add_and_exp");
     66       auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
     67       auto half = builder.ConstantR0<float>(0.5);
     68       builder.Exp(builder.Add(x, half));
     69       auto computation_status = builder.Build();
     70       TF_CHECK_OK(computation_status.status());
     71       add_and_exp_ = computation_status.ConsumeValueOrDie();
     72     }
     73 
     74     // Create a computation for a binary user function: (x, y) => x + y
     75     {
     76       ComputationBuilder builder(client_, "add");
     77       auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
     78       auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
     79       builder.Add(x, y);
     80       auto computation_status = builder.Build();
     81       TF_CHECK_OK(computation_status.status());
     82       add_ = computation_status.ConsumeValueOrDie();
     83     }
     84 
     85     // Create a computation for a sigmoid function: x => 1 / (1 + exp(-x))
     86     {
     87       ComputationBuilder builder(client_, "sigmoid");
     88       auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
     89       auto one = builder.ConstantR0<float>(1.0);
     90       builder.Div(one, builder.Add(one, builder.Exp(builder.Neg(x))));
     91       auto computation_status = builder.Build();
     92       TF_CHECK_OK(computation_status.status());
     93       sigmoid_ = computation_status.ConsumeValueOrDie();
     94     }
     95 
     96     // Create a computation for a binary max function: (x, y) => max (x, y)
     97     {
     98       ComputationBuilder builder(client_, "max");
     99       auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
    100       auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
    101       builder.Max(x, y);
    102       auto computation_status = builder.Build();
    103       TF_CHECK_OK(computation_status.status());
    104       max_ = computation_status.ConsumeValueOrDie();
    105     }
    106 
    107     // Create a computation for a binary GT function: (x, y) => x > y
    108     {
    109       ComputationBuilder builder(client_, "gt");
    110       auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
    111       auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
    112       builder.Gt(x, y);
    113       auto computation_status = builder.Build();
    114       TF_CHECK_OK(computation_status.status());
    115       gt_ = computation_status.ConsumeValueOrDie();
    116     }
    117   }
    118 
    119   // Build HLO graph from the given builder and return the HLO module.
    120   std::unique_ptr<HloModule> BuildHloGraph(ComputationBuilder* builder) {
    121     auto computation_status = builder->Build();
    122     TF_CHECK_OK(computation_status.status());
    123     auto computation = computation_status.ConsumeValueOrDie();
    124     auto user_computation_status =
    125         computation_tracker_.Resolve(computation.handle());
    126     TF_CHECK_OK(user_computation_status.status());
    127     auto user_computation = user_computation_status.ConsumeValueOrDie();
    128     VersionedComputationHandle versioned_handle =
    129         user_computation->GetVersionedHandle();
    130     return std::move(
    131         computation_tracker_.BuildHloModule(versioned_handle, HloModuleConfig())
    132             .ValueOrDie());
    133   }
    134 
    135   Client* client_;
    136   Service* service_;
    137   const ComputationTracker& computation_tracker_;
    138 
    139   // User computations used for higher order operations (e.g., Map, Reduce).
    140   Computation add_;
    141   Computation add_and_exp_;
    142   Computation sigmoid_;
    143   Computation max_;
    144   Computation gt_;
    145 };
    146 
    147 TEST_F(HloCostAnalysisTest, MatrixMultiply) {
    148   ComputationBuilder builder(client_, "matrix_multiply");
    149   auto lhs = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 5}), "lhs");
    150   auto rhs = builder.Parameter(1, ShapeUtil::MakeShape(F32, {5, 30}), "rhs");
    151   auto result = builder.Dot(lhs, rhs);
    152 
    153   // Run HLO cost analysis.
    154   auto hlo_module = BuildHloGraph(&builder);
    155   HloCostAnalysis analysis(ShapeSize);
    156   ASSERT_IS_OK(
    157       hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
    158 
    159   // Check the number of computations returned from the analysis (1500 FMAs).
    160   EXPECT_EQ(analysis.flop_count(), 2 * 10 * 30 * 5);
    161 
    162   EXPECT_EQ(analysis.transcendental_count(), 0);
    163 
    164   // Bytes accessed is sum of inputs and output.
    165   EXPECT_EQ(analysis.bytes_accessed(),
    166             sizeof(float) * (10 * 5 + 5 * 30 + 10 * 30));
    167 }
    168 
    169 TEST_F(HloCostAnalysisTest, Map) {
    170   ComputationBuilder builder(client_, "map");
    171   auto input = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10}), "in");
    172   auto result = builder.Map({input}, add_and_exp_, {0});
    173 
    174   // Run HLO cost analysis.
    175   auto hlo_module = BuildHloGraph(&builder);
    176   HloCostAnalysis analysis(ShapeSize);
    177   ASSERT_IS_OK(
    178       hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
    179 
    180   // add contributes to 10 flops and exp contributes to 10 transcendental ops.
    181   EXPECT_EQ(analysis.flop_count(), 10);
    182   EXPECT_EQ(analysis.transcendental_count(), 10);
    183   EXPECT_EQ(analysis.bytes_accessed(), 80);
    184 }
    185 
    186 TEST_F(HloCostAnalysisTest, Convolution) {
    187   ComputationBuilder builder(client_, "convolution");
    188   auto input = builder.Parameter(
    189       0, ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/10,
    190                                     /*x_dim=*/20}),
    191       "input");
    192   auto kernel = builder.Parameter(
    193       1, ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/3,
    194                                     /*x_dim=*/3}),
    195       "kernel");
    196   auto result = builder.Conv(input, kernel, {1, 1}, Padding::kValid);
    197 
    198   // Run HLO cost analysis.
    199   auto hlo_module = BuildHloGraph(&builder);
    200   HloCostAnalysis analysis(ShapeSize);
    201   ASSERT_IS_OK(
    202       hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
    203 
    204   // Output shape is [1x1x8x18] and each output element requires (3x3)
    205   // FMAs and one FMA is 2 flops.
    206   EXPECT_EQ(analysis.flop_count(), 8 * 18 * 2 * 3 * 3);
    207 
    208   // Bytes accessed is sum of inputs and output.
    209   EXPECT_EQ(analysis.bytes_accessed(),
    210             sizeof(float) * (10 * 20 + 3 * 3 + 8 * 18));
    211 }
    212 
    213 TEST_F(HloCostAnalysisTest, Reduce) {
    214   ComputationBuilder builder(client_, "reduce");
    215   auto input =
    216       builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input");
    217   auto result =
    218       builder.Reduce(input, builder.ConstantR0<float>(0.0f), add_, {1});
    219 
    220   // Run HLO cost analysis.
    221   auto hlo_module = BuildHloGraph(&builder);
    222   HloCostAnalysis analysis(ShapeSize);
    223   ASSERT_IS_OK(
    224       hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
    225 
    226   // Subtracting the output size from the input size gives the number of
    227   // reduction operations performed.
    228   EXPECT_EQ(analysis.flop_count(), 10 * 20 - 10);
    229 }
    230 
    231 TEST_F(HloCostAnalysisTest, ReduceWindow) {
    232   ComputationBuilder builder(client_, "reduce_window");
    233   auto input =
    234       builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input");
    235   auto result = builder.ReduceWindow(input, builder.ConstantR0<float>(0), add_,
    236                                      {4, 5}, {4, 5}, Padding::kValid);
    237 
    238   // Run HLO cost analysis.
    239   auto hlo_module = BuildHloGraph(&builder);
    240   HloCostAnalysis analysis(ShapeSize);
    241   ASSERT_IS_OK(
    242       hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
    243 
    244   // Each of [2x4] output elements are generated from reducing [4x5] elements.
    245   EXPECT_EQ(analysis.flop_count(), 2 * 4 * (4 * 5 - 1));
    246 }
    247 
    248 TEST_F(HloCostAnalysisTest, SelectAndScatter) {
    249   ComputationBuilder builder(client_, "select_and_scatter");
    250   auto operand =
    251       builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input");
    252   auto source =
    253       builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 4}), "source");
    254   auto result =
    255       builder.SelectAndScatter(operand, gt_, {4, 5}, {4, 5}, Padding::kValid,
    256                                source, builder.ConstantR0<float>(0), add_);
    257 
    258   // Run HLO cost analysis.
    259   auto hlo_module = BuildHloGraph(&builder);
    260   HloCostAnalysis analysis(ShapeSize);
    261   ASSERT_IS_OK(
    262       hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
    263 
    264   // Each of [2x4] source elements computes its destination from reducing [4x5]
    265   // elements followed by the scatter computation.
    266   EXPECT_EQ(analysis.flop_count(), 2 * 4 * (4 * 5 - 1 + 1));
    267 }
    268 
    269 TEST_F(HloCostAnalysisTest, Broadcast) {
    270   ComputationBuilder b(client_, "broadcast");
    271   b.Broadcast(b.ConstantR0<float>(42), {10, 7});
    272   auto hlo_module = BuildHloGraph(&b);
    273   HloCostAnalysis analysis(ShapeSize);
    274   ASSERT_IS_OK(
    275       hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
    276   EXPECT_EQ(analysis.flop_count(), 0);
    277 }
    278 
    279 // Calculates the computation cost of a graph with more than one HLO node.
    280 TEST_F(HloCostAnalysisTest, FullyConnectedForward) {
    281   ComputationBuilder builder(client_, "fully_connected_forward");
    282   auto input =
    283       builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 5}), "input");
    284   auto weight =
    285       builder.Parameter(1, ShapeUtil::MakeShape(F32, {5, 20}), "weight");
    286   auto bias = builder.Parameter(2, ShapeUtil::MakeShape(F32, {20}), "bias");
    287   // sigmoid(input * weight + bias)
    288   auto result = builder.Map(
    289       {builder.Add(builder.Dot(input, weight), bias, {1})}, sigmoid_, {0, 1});
    290 
    291   // Run HLO cost analysis.
    292   auto hlo_module = BuildHloGraph(&builder);
    293   HloCostAnalysis analysis(ShapeSize);
    294   ASSERT_IS_OK(
    295       hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
    296 
    297   // 1000 FMAs from matrix multiplication, 200 flops from bias addition,
    298   // 600 flops from sigmoid, and 200 transcendental ops from sigmoid.
    299   EXPECT_EQ(analysis.flop_count(), 2 * 1000 + 200 + 3 * 200);
    300   EXPECT_EQ(analysis.transcendental_count(), 200);
    301 }
    302 
    303 TEST_F(HloCostAnalysisTest, MatmulAndConvolutionCanBeTheSameComputation) {
    304   HloCostAnalysis conv_analysis(ShapeSize);
    305   {
    306     ComputationBuilder builder(client_, "conv_looking_matmul");
    307     auto lhs = builder.Parameter(0, ShapeUtil::MakeShape(F32, {64, 64, 1, 1}),
    308                                  "input");
    309     auto rhs = builder.Parameter(1, ShapeUtil::MakeShape(F32, {64, 64, 1, 1}),
    310                                  "weights");
    311     builder.Conv(lhs, rhs, {1, 1}, Padding::kSame);
    312     auto hlo_module = BuildHloGraph(&builder);
    313     ASSERT_IS_OK(hlo_module->entry_computation()->root_instruction()->Accept(
    314         &conv_analysis));
    315   }
    316 
    317   HloCostAnalysis matmul_analysis(ShapeSize);
    318   {
    319     ComputationBuilder builder(client_, "matmul");
    320     auto lhs =
    321         builder.Parameter(0, ShapeUtil::MakeShape(F32, {64, 64}), "input");
    322     auto rhs =
    323         builder.Parameter(1, ShapeUtil::MakeShape(F32, {64, 64}), "weights");
    324     builder.Dot(lhs, rhs);
    325     auto hlo_module = BuildHloGraph(&builder);
    326     ASSERT_IS_OK(hlo_module->entry_computation()->root_instruction()->Accept(
    327         &matmul_analysis));
    328   }
    329 
    330   EXPECT_EQ(conv_analysis.flop_count(), matmul_analysis.flop_count());
    331 }
    332 
// Fixture for the fusion tests below. They build HloComputations and
// HloModules directly through HloTestBase instead of going through the
// client/service path that HloCostAnalysisTest uses.
using FusionCostAnalysis = HloTestBase;
    334 
// Checks the flop, transcendental and byte counts of a kLoop fusion node. Also
// checks that optimal_seconds() reports the time of the slowest (bottleneck)
// resource for the configured per-second rates.
TEST_F(FusionCostAnalysis, LoopFusion) {
  // Do this 4 times with different per-second rates to test the computation of
  // bottleneck time on fusion nodes.
  for (int i = 0; i < 4; ++i) {
    Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});

    // Fuse all instructions in complicated expression:
    //
    //   add = Add(C1, C2)
    //   clamp = Clamp(C2, add, add)
    //   exp = Exp(add)
    //   mul = Mul(exp, C3)
    //   sub = Sub(mul, clamp)
    //   tuple = Tuple({sub, sub, mul, C1})
    HloComputation::Builder builder(TestName());
    auto c1 = builder.AddInstruction(
        HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
            /*from=*/0.0f, /*to=*/1.0f, /*rows=*/2, /*cols=*/2)));
    auto c2 = builder.AddInstruction(
        HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
            /*from=*/1.0f, /*to=*/2.0f, /*rows=*/2, /*cols=*/2)));
    auto c3 = builder.AddInstruction(
        HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
            /*from=*/2.0f, /*to=*/3.0f, /*rows=*/2, /*cols=*/2)));
    auto add = builder.AddInstruction(
        HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, c1, c2));
    auto clamp = builder.AddInstruction(
        HloInstruction::CreateTernary(r2f32, HloOpcode::kClamp, c2, add, add));
    auto exp = builder.AddInstruction(
        HloInstruction::CreateUnary(r2f32, HloOpcode::kExp, add));
    auto mul = builder.AddInstruction(
        HloInstruction::CreateBinary(r2f32, HloOpcode::kMultiply, exp, c3));
    auto sub = builder.AddInstruction(
        HloInstruction::CreateBinary(r2f32, HloOpcode::kSubtract, mul, clamp));
    // NOTE(review): `tuple` is never passed to builder.AddInstruction, so it is
    // not part of the computation or of the fusion below. The returned object
    // (presumably a std::unique_ptr) only lives until the end of this loop
    // iteration. If CreateTuple registers it as a user of its operands, `mul`
    // keeps a user outside the fusion. Confirm this is intentional.
    auto tuple = HloInstruction::CreateTuple({sub, sub, mul, c1});

    HloModule module(TestName());
    auto* computation = module.AddEntryComputation(builder.Build());
    // `sub` is listed first; presumably it becomes the fusion root.
    auto* fusion = computation->CreateFusionInstruction(
        {sub, mul, exp, clamp, add}, HloInstruction::FusionKind::kLoop);

    // The time given these rates at i == 0 is exactly even among the properties
    // at 1.0 seconds. For other values, one of the rates is slower so that it
    // becomes the bottleneck:
    //   i == 1: 16 flops at 8/s             = 2 s
    //   i == 2: 4 transcendentals at 1/s    = 4 s
    //   i == 3: 64 bytes at 8/s             = 8 s
    // i.e. the expected optimal time is 1 << i seconds.
    HloCostAnalysis fusion_analysis(ShapeSize);
    fusion_analysis.set_flops_per_second(16 * (i == 1 ? 1 / 2.0 : 1.0));
    fusion_analysis.set_transcendentals_per_second(4 *
                                                   (i == 2 ? 1 / 4.0 : 1.0));
    fusion_analysis.set_bytes_per_second(64 * (i == 3 ? 1 / 8.0 : 1.0));
    ASSERT_IS_OK(fusion->Accept(&fusion_analysis));

    EXPECT_EQ(fusion_analysis.flop_count(), 16);
    // Only `exp` is transcendental: one per element of the 2x2 result.
    EXPECT_EQ(fusion_analysis.transcendental_count(), 4);
    // Four 2x2 F32 arrays; presumably the three constant operands of the
    // fusion plus its output. TODO confirm.
    constexpr int64 bytes_accessed = sizeof(float) * 4 * 2 * 2;
    static_assert(bytes_accessed == 64, "");
    EXPECT_EQ(fusion_analysis.bytes_accessed(), bytes_accessed);

    EXPECT_EQ(fusion_analysis.optimal_seconds(), 1 << i);
  }
}
    395 
    396 TEST_F(FusionCostAnalysis, NoLayout) {
    397   Shape shape_with_layout = ShapeUtil::MakeShape(F32, {2, 3, 4, 5});
    398   // Instructions within a fused op may have no layout.
    399   Shape shape_without_layout = shape_with_layout;
    400   shape_without_layout.clear_layout();
    401 
    402   HloComputation::Builder builder(TestName());
    403   auto c1 = builder.AddInstruction(HloInstruction::CreateConstant(
    404       Literal::CreateR4FromArray4D(Array4D<float>(2, 3, 4, 5))));
    405   auto c2 = builder.AddInstruction(
    406       HloInstruction::CreateConstant(Literal::CreateR1<float>({1, 2, 3})));
    407 
    408   auto broadcast = builder.AddInstruction(
    409       HloInstruction::CreateBroadcast(shape_without_layout, c2, {1}));
    410   auto add = builder.AddInstruction(HloInstruction::CreateBinary(
    411       shape_with_layout, HloOpcode::kAdd, c1, broadcast));
    412 
    413   HloModule module(TestName());
    414   auto* computation = module.AddEntryComputation(builder.Build());
    415   auto* fusion = computation->CreateFusionInstruction(
    416       {add, broadcast}, HloInstruction::FusionKind::kLoop);
    417 
    418   HloCostAnalysis fusion_analysis(ShapeSize);
    419   ASSERT_IS_OK(fusion->Accept(&fusion_analysis));
    420 
    421   EXPECT_EQ(fusion_analysis.flop_count(), 120);
    422   EXPECT_EQ(fusion_analysis.transcendental_count(), 0);
    423 }
    424 
    425 TEST_F(HloCostAnalysisTest, TupleCost) {
    426   HloCostAnalysis analysis(ShapeSize);
    427   {
    428     ComputationBuilder builder(client_, "matmul");
    429     auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {123}), "x");
    430     auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {42}), "y");
    431     auto tuple = builder.Tuple({x, y});
    432     auto hlo_module = BuildHloGraph(&builder);
    433 
    434     ASSERT_IS_OK(
    435         hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
    436   }
    437 
    438   EXPECT_EQ(analysis.flop_count(), 0);
    439   EXPECT_EQ(analysis.transcendental_count(), 0);
    440   EXPECT_EQ(analysis.bytes_accessed(), kPointerSize * 2);
    441 }
    442 
    443 }  // namespace
    444 }  // namespace xla
    445