Home | History | Annotate | Download | only in tests
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include <math.h>
     17 #include <algorithm>
     18 #include <memory>
     19 #include <new>
     20 #include <random>
     21 #include <utility>
     22 
     23 #define EIGEN_USE_THREADS
     24 
     25 #include "absl/memory/memory.h"
     26 #include "absl/types/span.h"
     27 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     28 #include "tensorflow/compiler/xla/array2d.h"
     29 #include "tensorflow/compiler/xla/client/client_library.h"
     30 #include "tensorflow/compiler/xla/client/xla_builder.h"
     31 #include "tensorflow/compiler/xla/literal.h"
     32 #include "tensorflow/compiler/xla/primitive_util.h"
     33 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     34 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     35 #include "tensorflow/compiler/xla/service/hlo_module.h"
     36 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
     37 #include "tensorflow/compiler/xla/service/hlo_parser.h"
     38 #include "tensorflow/compiler/xla/service/platform_util.h"
     39 #include "tensorflow/compiler/xla/shape_util.h"
     40 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
     41 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
     42 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
     43 #include "tensorflow/compiler/xla/tests/test_macros.h"
     44 #include "tensorflow/compiler/xla/xla_data.pb.h"
     45 #include "tensorflow/core/common_runtime/eigen_thread_pool.h"
     46 #include "tensorflow/core/platform/logging.h"
     47 #include "tensorflow/core/platform/protobuf.h"
     48 #include "tensorflow/core/platform/test_benchmark.h"
     49 #include "tensorflow/core/platform/types.h"
     50 
     51 namespace xla {
     52 namespace {
     53 
     54 const int test_width = 2, test_height = 3;
     55 
     56 const float test_float_vals[3][test_width][test_height] = {
     57     {{-1.0, -1.0, 1.0}, {-3.0, 0.0, -1.0}},
     58     {{-3.0, 2.0, 1.0}, {0.0, -3.0, 1.0}},
     59     {{-3.0, 0.0, -3.0}, {-1.0, -2.0, 1.0}}};
     60 
     61 // Test whether fusion operations are emitted with no errors and compute
     62 // accurate outputs.
     63 class FusionTest : public HloTestBase {
     64  protected:
     65   template <typename T, int Arity>
     66   void TestElementwise2D(
     67       HloOpcode opcode,
     68       absl::optional<ComparisonDirection> direction = absl::nullopt) {
     69     // Create a variable for comparisons since they require the direction.
     70     bool is_compare = std::is_same<T, bool>::value;
     71     Array2D<float> operand_data[Arity];
     72     for (int i = 0; i < Arity; ++i) {
     73       new (&operand_data[i]) Array2D<float>(test_width, test_height);
     74     }
     75     Array2D<T> answer_data(test_width, test_height);
     76     for (int i = 0; i < test_width; ++i) {
     77       for (int j = 0; j < test_height; ++j) {
     78         float xs[Arity];
     79         for (int k = 0; k < Arity; ++k) {
     80           xs[k] = test_float_vals[k][i][j];
     81           operand_data[k](i, j) = xs[k];
     82         }
     83         if (is_compare) {
     84           answer_data(i, j) = ComputeElementwiseAnswerCompare(*direction, xs);
     85         } else {
     86           answer_data(i, j) = ComputeElementwiseAnswerFloat(opcode, xs);
     87         }
     88       }
     89     }
     90 
     91     auto builder = HloComputation::Builder(TestName());
     92     auto hlo_module = CreateNewVerifiedModule();
     93 
     94     auto prim_type = primitive_util::NativeToPrimitiveType<T>();
     95 
     96     HloInstruction* hlos[4];
     97     for (int i = 0; i < Arity; ++i) {
     98       hlos[i + 1] = builder.AddInstruction(HloInstruction::CreateConstant(
     99           LiteralUtil::CreateR2FromArray2D(operand_data[i])));
    100     }
    101     auto answer_shape =
    102         ShapeUtil::MakeShape(prim_type, {test_width, test_height});
    103     std::unique_ptr<HloInstruction> root_hlo;
    104     switch (Arity) {
    105       case 1:
    106         root_hlo = HloInstruction::CreateUnary(answer_shape, opcode, hlos[1]);
    107         break;
    108       case 2:
    109         if (is_compare) {
    110           root_hlo = HloInstruction::CreateCompare(answer_shape, hlos[1],
    111                                                    hlos[2], *direction);
    112         } else {
    113           root_hlo = HloInstruction::CreateBinary(answer_shape, opcode, hlos[1],
    114                                                   hlos[2]);
    115         }
    116         break;
    117       case 3:
    118         root_hlo = HloInstruction::CreateTernary(answer_shape, opcode, hlos[1],
    119                                                  hlos[2], hlos[3]);
    120         break;
    121       default:
    122         LOG(FATAL) << "Bad arity: " << Arity;
    123     }
    124     hlos[0] = builder.AddInstruction(std::move(root_hlo));
    125     hlo_module->AddEntryComputation(builder.Build())
    126         ->CreateFusionInstruction(
    127             absl::Span<HloInstruction* const>(hlos).subspan(0, Arity + 1),
    128             HloInstruction::FusionKind::kLoop);
    129 
    130     auto expected = LiteralUtil::CreateR2FromArray2D(answer_data);
    131     auto actual = ExecuteAndTransfer(std::move(hlo_module), {});
    132     if (primitive_util::IsFloatingPointType(prim_type)) {
    133       EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, ErrorSpec(1e-4)));
    134     } else {
    135       EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
    136     }
    137   }
    138 
    139  private:
    140   float ComputeElementwiseAnswerFloat(HloOpcode opcode,
    141                                       absl::Span<const float> xs);
    142   bool ComputeElementwiseAnswerCompare(ComparisonDirection direction,
    143                                        absl::Span<const float> xs);
    144   DebugOptions GetDebugOptionsForTest() override {
    145     DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest();
    146     debug_options.add_xla_disable_hlo_passes("layout-assignment");
    147     return debug_options;
    148   }
    149 };
    150 
    151 float FusionTest::ComputeElementwiseAnswerFloat(HloOpcode opcode,
    152                                                 absl::Span<const float> xs) {
    153   switch (opcode) {
    154     case HloOpcode::kAdd:
    155       return xs[0] + xs[1];
    156     case HloOpcode::kSubtract:
    157       return xs[0] - xs[1];
    158     case HloOpcode::kMultiply:
    159       return xs[0] * xs[1];
    160     case HloOpcode::kDivide:
    161       return xs[0] / xs[1];
    162     case HloOpcode::kPower:
    163       return powf(xs[0], xs[1]);
    164     case HloOpcode::kMinimum:
    165       return std::min(xs[0], xs[1]);
    166     case HloOpcode::kMaximum:
    167       return std::max(xs[0], xs[1]);
    168     case HloOpcode::kClamp:
    169       return std::min(xs[2], std::max(xs[1], xs[0]));
    170     default:
    171       LOG(FATAL) << "No elementwise opcode: " << opcode;
    172   }
    173 }
    174 
    175 bool FusionTest::ComputeElementwiseAnswerCompare(ComparisonDirection direction,
    176                                                  absl::Span<const float> xs) {
    177   switch (direction) {
    178     case ComparisonDirection::kEq:
    179       return xs[0] == xs[1];
    180     case ComparisonDirection::kNe:
    181       return xs[0] != xs[1];
    182     case ComparisonDirection::kGt:
    183       return xs[0] > xs[1];
    184     case ComparisonDirection::kLt:
    185       return xs[0] < xs[1];
    186     case ComparisonDirection::kGe:
    187       return xs[0] >= xs[1];
    188     case ComparisonDirection::kLe:
    189       return xs[0] <= xs[1];
    190   }
    191 }
    192 
    193 XLA_TEST_F(FusionTest, Test) {
    194   // test expression:
    195   // slice(select({{T, F, T}, {F, T, F}},
    196   //              concat(transpose({{1.0}, {2.0}, {3.0}} +
    197   //                               {{-1.0}, {-1.0}, {-1.0}}),
    198   //                     {{1.62, 2.72, 3.14}}) +
    199   //                     (-{{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}}),
    200   //              {{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}})) = {{0.5}, {2.72}}
    201   auto builder = HloComputation::Builder(TestName());
    202   auto hlo_module = CreateNewVerifiedModule();
    203   auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
    204       LiteralUtil::CreateR2<float>({{1.0}, {2.0}, {3.0}})));
    205   auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
    206       LiteralUtil::CreateR2<float>({{-1.0}, {-1.0}, {-1.0}})));
    207   auto add2 = builder.AddInstruction(HloInstruction::CreateBinary(
    208       ShapeUtil::MakeShape(F32, {3, 1}), HloOpcode::kAdd, const0, const1));
    209   auto reshape3 = builder.AddInstruction(HloInstruction::CreateTranspose(
    210       ShapeUtil::MakeShape(F32, {1, 3}), add2, {1, 0}));
    211   auto const4 = builder.AddInstruction(HloInstruction::CreateConstant(
    212       LiteralUtil::CreateR2<float>({{1.62, 2.72, 3.14}})));
    213   auto concat5 = builder.AddInstruction(HloInstruction::CreateConcatenate(
    214       ShapeUtil::MakeShape(F32, {2, 3}), {reshape3, const4}, 0));
    215   auto const6 = builder.AddInstruction(HloInstruction::CreateConstant(
    216       LiteralUtil::CreateR2<float>({{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}})));
    217   auto negate7 = builder.AddInstruction(HloInstruction::CreateUnary(
    218       ShapeUtil::MakeShape(F32, {2, 3}), HloOpcode::kNegate, const6));
    219   auto add8 = builder.AddInstruction(HloInstruction::CreateBinary(
    220       ShapeUtil::MakeShape(F32, {2, 3}), HloOpcode::kAdd, concat5, negate7));
    221   auto const9 = builder.AddInstruction(HloInstruction::CreateConstant(
    222       LiteralUtil::CreateR2<float>({{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}})));
    223   auto const10 = builder.AddInstruction(
    224       HloInstruction::CreateConstant(LiteralUtil::CreateR2<bool>(
    225           {{true, false, true}, {false, true, false}})));
    226   auto select11 = builder.AddInstruction(
    227       HloInstruction::CreateTernary(ShapeUtil::MakeShape(F32, {2, 3}),
    228                                     HloOpcode::kSelect, const10, add8, const9));
    229   auto slice12 = builder.AddInstruction(HloInstruction::CreateSlice(
    230       ShapeUtil::MakeShape(F32, {2, 1}), select11, {0, 1}, {2, 2}, {1, 1}));
    231   // CreateFusionInstruction needs the `instructions_to_fuse` argument in
    232   // reverse topological order, so the first element in `instructions_to_fuse`
    233   // must be the root.
    234   hlo_module->AddEntryComputation(builder.Build())
    235       ->CreateFusionInstruction(
    236           {slice12, select11, const10, const9, add8, negate7, const6, concat5,
    237            const4, reshape3, add2, const1, const0},
    238           HloInstruction::FusionKind::kLoop);
    239 
    240   EXPECT_TRUE(LiteralTestUtil::Near(
    241       LiteralUtil::CreateR2<float>({{0.5}, {2.72}}),
    242       ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
    243 }
    244 
    245 // Test whether we emit appropriate code for parameters of fusion instructions.
    246 XLA_TEST_F(FusionTest, Parameter) {
    247   // Build a computation and fuse part of it so the fusion instruction has an
    248   // operand parameter.
    249   auto builder = HloComputation::Builder(TestName());
    250   auto hlo_module = CreateNewVerifiedModule();
    251   auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
    252       LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}})));
    253   auto copy1 = builder.AddInstruction(HloInstruction::CreateUnary(
    254       ShapeUtil::MakeShape(F32, {1, 3}), HloOpcode::kCopy, const0));
    255   auto const2 = builder.AddInstruction(HloInstruction::CreateConstant(
    256       LiteralUtil::CreateR2<float>({{-2.0, -2.0, -2.0}})));
    257   // add3 = copy1 + const2 = const0 + const2 = {1,2,3} + {-2,-2,-2} = {-1,0,+1}
    258   auto add3 = builder.AddInstruction(HloInstruction::CreateBinary(
    259       ShapeUtil::MakeShape(F32, {1, 3}), HloOpcode::kAdd, copy1, const2));
    260   // CreateFusionInstruction needs `instructions_to_fuse` in reverse topological
    261   // order.
    262   hlo_module->AddEntryComputation(builder.Build())
    263       ->CreateFusionInstruction(/*instructions_to_fuse=*/{add3, const2},
    264                                 HloInstruction::FusionKind::kLoop);
    265 
    266   EXPECT_TRUE(LiteralTestUtil::Near(
    267       LiteralUtil::CreateR2<float>({{-1.0, 0.0, 1.0}}),
    268       ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
    269 }
    270 
    271 XLA_TEST_F(FusionTest, RandomizedParallelPartition) {
    272   // Tests parallel partitioning of a fusion instruction.
    273   // Create shape with random outer dimension size to generate random parallel
    274   // partition counts for each test run.
    275   const int seed = tensorflow::testing::RandomSeed();
    276   LOG(INFO) << "RandomizedParallelPartition seed: " << seed;
    277   std::mt19937 generator(seed);
    278   std::uniform_int_distribution<int> distribution(128, 1024);
    279   const int64 rand_dim0_size = distribution(generator);
    280   const int64 dim1_size = 1024;
    281   Shape shape =
    282       ShapeUtil::MakeShapeWithLayout(F32, {rand_dim0_size, dim1_size}, {1, 0});
    283   // Build simple fusion computation: y = x^2 (elementwise).
    284   auto builder = HloComputation::Builder(TestName());
    285   auto hlo_module = CreateNewVerifiedModule();
    286 
    287   auto two = builder.AddInstruction(
    288       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
    289   auto x =
    290       builder.AddInstruction(HloInstruction::CreateBroadcast(shape, two, {}));
    291   auto y = builder.AddInstruction(
    292       HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, x, x));
    293 
    294   hlo_module->AddEntryComputation(builder.Build())
    295       ->CreateFusionInstruction(/*instructions_to_fuse=*/{y, x, two},
    296                                 HloInstruction::FusionKind::kLoop);
    297   // Compute result.
    298   auto result = ExecuteAndTransfer(std::move(hlo_module), {});
    299   // Every element of result should be y = x^2 = 4.0.
    300   for (int i = 0; i < rand_dim0_size; ++i) {
    301     for (int j = 0; j < dim1_size; ++j) {
    302       EXPECT_EQ(4.0, result.Get<float>({i, j}));
    303     }
    304   }
    305 }
    306 
    307 XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) {
    308   auto builder = HloComputation::Builder(TestName());
    309   auto hlo_module = CreateNewVerifiedModule();
    310   auto const_vector = builder.AddInstruction(HloInstruction::CreateConstant(
    311       LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0})));
    312   auto const_array = builder.AddInstruction(HloInstruction::CreateConstant(
    313       LiteralUtil::CreateR2<float>({{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}})));
    314   auto broadcast = builder.AddInstruction(
    315       HloInstruction::CreateBroadcast(const_array->shape(), const_vector, {1}));
    316   // add2 = broadcast(const_vector) + const_array
    317   //      = broadcast({1,2,3}) + {{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}}
    318   //      = {{1, 2, 3}, {1, 2, 3}} + {{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}}
    319   auto add2 = builder.AddInstruction(
    320       HloInstruction::CreateBinary(ShapeUtil::MakeShape(F32, {2, 3}),
    321                                    HloOpcode::kAdd, broadcast, const_array));
    322   hlo_module->AddEntryComputation(builder.Build())
    323       ->CreateFusionInstruction(/*instructions_to_fuse=*/{add2, broadcast},
    324                                 HloInstruction::FusionKind::kLoop);
    325 
    326   EXPECT_TRUE(LiteralTestUtil::Near(
    327       LiteralUtil::CreateR2<float>({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}),
    328       ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
    329 }
    330 
    331 XLA_TEST_F(FusionTest, ReshapeToScalar) {
    332   auto builder = HloComputation::Builder(TestName());
    333   auto hlo_module = CreateNewVerifiedModule();
    334   auto single_element_array = builder.AddInstruction(
    335       HloInstruction::CreateConstant(LiteralUtil::CreateR2<int32>({{5}})));
    336   auto reshape = builder.AddInstruction(HloInstruction::CreateReshape(
    337       ShapeUtil::MakeShape(S32, {}), single_element_array));
    338   hlo_module->AddEntryComputation(builder.Build())
    339       ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape},
    340                                 HloInstruction::FusionKind::kLoop);
    341   EXPECT_TRUE(
    342       LiteralTestUtil::Equal(LiteralUtil::CreateR0<int32>(5),
    343                              ExecuteAndTransfer(std::move(hlo_module), {})));
    344 }
    345 
    346 XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) {
    347   auto builder = HloComputation::Builder(TestName());
    348   auto hlo_module = CreateNewVerifiedModule();
    349   auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
    350       LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}})));
    351   auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape(
    352       ShapeUtil::MakeShape(S32, {1, 2, 3}), const0));
    353   hlo_module->AddEntryComputation(builder.Build())
    354       ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
    355                                 HloInstruction::FusionKind::kLoop);
    356   EXPECT_TRUE(LiteralTestUtil::Equal(
    357       LiteralUtil::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}}),
    358       ExecuteAndTransfer(std::move(hlo_module), {})));
    359 }
    360 
    361 XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) {
    362   auto builder = HloComputation::Builder(TestName());
    363   auto hlo_module = CreateNewVerifiedModule();
    364   auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
    365       LiteralUtil::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}})));
    366   auto reshape1 = builder.AddInstruction(
    367       HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {3, 2}), const0));
    368   hlo_module->AddEntryComputation(builder.Build())
    369       ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
    370                                 HloInstruction::FusionKind::kLoop);
    371   EXPECT_TRUE(LiteralTestUtil::Equal(
    372       LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}}),
    373       ExecuteAndTransfer(std::move(hlo_module), {})));
    374 }
    375 
    376 XLA_TEST_F(FusionTest, Reshape_1by1by1_) {
    377   auto builder = HloComputation::Builder(TestName());
    378   auto hlo_module = CreateNewVerifiedModule();
    379   auto const0 = builder.AddInstruction(
    380       HloInstruction::CreateConstant(LiteralUtil::CreateR3<int32>({{{7}}})));
    381   auto reshape1 = builder.AddInstruction(
    382       HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), const0));
    383   hlo_module->AddEntryComputation(builder.Build())
    384       ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
    385                                 HloInstruction::FusionKind::kLoop);
    386   EXPECT_TRUE(
    387       LiteralTestUtil::Equal(LiteralUtil::CreateR0<int32>(7),
    388                              ExecuteAndTransfer(std::move(hlo_module), {})));
    389 }
    390 
    391 XLA_TEST_F(FusionTest, Reshape__1by1by1) {
    392   auto builder = HloComputation::Builder(TestName());
    393   auto hlo_module = CreateNewVerifiedModule();
    394   auto const0 = builder.AddInstruction(
    395       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(7)));
    396   auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape(
    397       ShapeUtil::MakeShape(S32, {1, 1, 1}), const0));
    398   hlo_module->AddEntryComputation(builder.Build())
    399       ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
    400                                 HloInstruction::FusionKind::kLoop);
    401   EXPECT_TRUE(
    402       LiteralTestUtil::Equal(LiteralUtil::CreateR3<int32>({{{7}}}),
    403                              ExecuteAndTransfer(std::move(hlo_module), {})));
    404 }
    405 
    406 XLA_TEST_F(FusionTest, Reshape__) {
    407   auto builder = HloComputation::Builder(TestName());
    408   auto hlo_module = CreateNewVerifiedModule();
    409   auto const0 = builder.AddInstruction(
    410       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(7)));
    411   auto reshape1 = builder.AddInstruction(
    412       HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), const0));
    413   hlo_module->AddEntryComputation(builder.Build())
    414       ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
    415                                 HloInstruction::FusionKind::kLoop);
    416   EXPECT_TRUE(
    417       LiteralTestUtil::Equal(LiteralUtil::CreateR0<int32>(7),
    418                              ExecuteAndTransfer(std::move(hlo_module), {})));
    419 }
    420 
    421 XLA_TEST_F(FusionTest, Reshape_3by3_3by3) {
    422   auto builder = HloComputation::Builder(TestName());
    423   auto hlo_module = CreateNewVerifiedModule();
    424   auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
    425       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
    426   auto reshape1 = builder.AddInstruction(
    427       HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {3, 3}), const0));
    428   hlo_module->AddEntryComputation(builder.Build())
    429       ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
    430                                 HloInstruction::FusionKind::kLoop);
    431   EXPECT_TRUE(LiteralTestUtil::Equal(
    432       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}),
    433       ExecuteAndTransfer(std::move(hlo_module), {})));
    434 }
    435 
    436 XLA_TEST_F(FusionTest, Transpose_2by3) {
    437   auto builder = HloComputation::Builder(TestName());
    438   auto hlo_module = CreateNewVerifiedModule();
    439   auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
    440       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}})));
    441   auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose(
    442       ShapeUtil::MakeShape(S32, {3, 2}), const0, {1, 0}));
    443   hlo_module->AddEntryComputation(builder.Build())
    444       ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
    445                                 HloInstruction::FusionKind::kLoop);
    446   EXPECT_TRUE(LiteralTestUtil::Equal(
    447       LiteralUtil::CreateR2<int32>({{1, 4}, {2, 5}, {3, 6}}),
    448       ExecuteAndTransfer(std::move(hlo_module), {})));
    449 }
    450 
    451 XLA_TEST_F(FusionTest, Transpose_3by3) {
    452   auto builder = HloComputation::Builder(TestName());
    453   auto hlo_module = CreateNewVerifiedModule();
    454   auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
    455       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
    456   auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose(
    457       ShapeUtil::MakeShape(S32, {3, 3}), const0, {1, 0}));
    458   hlo_module->AddEntryComputation(builder.Build())
    459       ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
    460                                 HloInstruction::FusionKind::kLoop);
    461   EXPECT_TRUE(LiteralTestUtil::Equal(
    462       LiteralUtil::CreateR2<int32>({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}),
    463       ExecuteAndTransfer(std::move(hlo_module), {})));
    464 }
    465 
    466 XLA_TEST_F(FusionTest, Reverse) {
    467   auto builder = HloComputation::Builder(TestName());
    468   auto hlo_module = CreateNewVerifiedModule();
    469   auto const0 = builder.AddInstruction(
    470       HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({1, 2, 3})));
    471   auto reverse1 = builder.AddInstruction(HloInstruction::CreateReverse(
    472       ShapeUtil::MakeShape(S32, {3}), const0, {0}));
    473   hlo_module->AddEntryComputation(builder.Build())
    474       ->CreateFusionInstruction(/*instructions_to_fuse=*/{reverse1},
    475                                 HloInstruction::FusionKind::kLoop);
    476 
    477   EXPECT_TRUE(
    478       LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({3, 2, 1}),
    479                              ExecuteAndTransfer(std::move(hlo_module), {})));
    480 }
    481 
    482 XLA_TEST_F(FusionTest, ReverseNegate) {
    483   auto builder = HloComputation::Builder(TestName());
    484   auto hlo_module = CreateNewVerifiedModule();
    485   auto const0 = builder.AddInstruction(
    486       HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({1, 2, 3})));
    487   auto reverse1 = builder.AddInstruction(HloInstruction::CreateReverse(
    488       ShapeUtil::MakeShape(S32, {3}), const0, {0}));
    489   auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary(
    490       ShapeUtil::MakeShape(S32, {3}), HloOpcode::kNegate, reverse1));
    491   hlo_module->AddEntryComputation(builder.Build())
    492       ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reverse1},
    493                                 HloInstruction::FusionKind::kLoop);
    494 
    495   EXPECT_TRUE(
    496       LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({-3, -2, -1}),
    497                              ExecuteAndTransfer(std::move(hlo_module), {})));
    498 }
    499 
    500 XLA_TEST_F(FusionTest, BroadcastNegate) {
    501   auto builder = HloComputation::Builder(TestName());
    502   auto hlo_module = CreateNewVerifiedModule();
    503   auto const0 = builder.AddInstruction(
    504       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
    505   auto broadcast1 = builder.AddInstruction(HloInstruction::CreateBroadcast(
    506       ShapeUtil::MakeShape(S32, {2}), const0, {}));
    507   auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary(
    508       ShapeUtil::MakeShape(S32, {2}), HloOpcode::kNegate, broadcast1));
    509   hlo_module->AddEntryComputation(builder.Build())
    510       ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, broadcast1},
    511                                 HloInstruction::FusionKind::kLoop);
    512 
    513   EXPECT_TRUE(
    514       LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({-1, -1}),
    515                              ExecuteAndTransfer(std::move(hlo_module), {})));
    516 }
    517 
    518 XLA_TEST_F(FusionTest, SliceNegate) {
    519   auto builder = HloComputation::Builder(TestName());
    520   auto hlo_module = CreateNewVerifiedModule();
    521   auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
    522       LiteralUtil::CreateR1<int32>({1, 2, 3, 4})));
    523   auto slice1 = builder.AddInstruction(HloInstruction::CreateSlice(
    524       ShapeUtil::MakeShape(S32, {2}), const0, {0}, {4}, {2}));
    525   auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary(
    526       ShapeUtil::MakeShape(S32, {2}), HloOpcode::kNegate, slice1));
    527   hlo_module->AddEntryComputation(builder.Build())
    528       ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, slice1},
    529                                 HloInstruction::FusionKind::kLoop);
    530 
    531   EXPECT_TRUE(
    532       LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({-1, -3}),
    533                              ExecuteAndTransfer(std::move(hlo_module), {})));
    534 }
    535 
    536 XLA_TEST_F(FusionTest, DynamicSliceNegate) {
    537   auto builder = HloComputation::Builder(TestName());
    538   auto hlo_module = CreateNewVerifiedModule();
    539   auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
    540       LiteralUtil::CreateR1<int32>({1, 2, 3, 4})));
    541   auto const1 = builder.AddInstruction(
    542       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
    543   auto dynamic_slice2 =
    544       builder.AddInstruction(HloInstruction::CreateDynamicSlice(
    545           ShapeUtil::MakeShape(S32, {2}), const0, {const1}, {2}));
    546   auto negate3 = builder.AddInstruction(HloInstruction::CreateUnary(
    547       ShapeUtil::MakeShape(S32, {2}), HloOpcode::kNegate, dynamic_slice2));
    548   hlo_module->AddEntryComputation(builder.Build())
    549       ->CreateFusionInstruction(
    550           /*instructions_to_fuse=*/{negate3, dynamic_slice2},
    551           HloInstruction::FusionKind::kLoop);
    552 
    553   EXPECT_TRUE(
    554       LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({-2, -3}),
    555                              ExecuteAndTransfer(std::move(hlo_module), {})));
    556 }
    557 
    558 XLA_TEST_F(FusionTest, ReshapeNegate) {
    559   auto builder = HloComputation::Builder(TestName());
    560   auto hlo_module = CreateNewVerifiedModule();
    561   auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
    562       LiteralUtil::CreateR1<int32>({1, 2, 3, 4})));
    563   auto reshape1 = builder.AddInstruction(
    564       HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {2, 2}), const0));
    565   auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary(
    566       ShapeUtil::MakeShape(S32, {2, 2}), HloOpcode::kNegate, reshape1));
    567   hlo_module->AddEntryComputation(builder.Build())
    568       ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reshape1},
    569                                 HloInstruction::FusionKind::kLoop);
    570 
    571   EXPECT_TRUE(
    572       LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{-1, -2}, {-3, -4}}),
    573                              ExecuteAndTransfer(std::move(hlo_module), {})));
    574 }
    575 
    576 XLA_TEST_F(FusionTest, TransposeNegate) {
    577   auto builder = HloComputation::Builder(TestName());
    578   auto hlo_module = CreateNewVerifiedModule();
    579   auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
    580       LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}})));
    581   auto transpose1 = builder.AddInstruction(HloInstruction::CreateTranspose(
    582       ShapeUtil::MakeShape(S32, {2, 2}), const0, {1, 0}));
    583   auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary(
    584       ShapeUtil::MakeShape(S32, {2, 2}), HloOpcode::kNegate, transpose1));
    585   hlo_module->AddEntryComputation(builder.Build())
    586       ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, transpose1},
    587                                 HloInstruction::FusionKind::kLoop);
    588 
    589   EXPECT_TRUE(
    590       LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{-1, -3}, {-2, -4}}),
    591                              ExecuteAndTransfer(std::move(hlo_module), {})));
    592 }
    593 
    594 std::unique_ptr<HloComputation> MakeReduceTestComputation() {
    595   auto builder = HloComputation::Builder("add");
    596   auto lhs = builder.AddInstruction(HloInstruction::CreateParameter(
    597       /*parameter_number=*/0, ShapeUtil::MakeShape(S32, {}), "lhs"));
    598   auto rhs = builder.AddInstruction(HloInstruction::CreateParameter(
    599       /*parameter_number=*/1, ShapeUtil::MakeShape(S32, {}), "rhs"));
    600   builder.AddInstruction(HloInstruction::CreateBinary(
    601       ShapeUtil::MakeShape(S32, {}), HloOpcode::kAdd, lhs, rhs));
    602   return builder.Build();
    603 }
    604 
    605 XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) {
    606   auto hlo_module = CreateNewVerifiedModule();
    607 
    608   auto builder = HloComputation::Builder(TestName());
    609   auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
    610       LiteralUtil::CreateR1<int32>({1, 2, 4, 8})));
    611   auto const1 = builder.AddInstruction(
    612       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
    613   auto reduce2 = builder.AddInstruction(HloInstruction::CreateReduce(
    614       ShapeUtil::MakeShape(S32, {}), const0, const1, {0},
    615       hlo_module->AddEmbeddedComputation(MakeReduceTestComputation())));
    616   hlo_module->AddEntryComputation(builder.Build())
    617       ->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce2},
    618                                 HloInstruction::FusionKind::kInput);
    619 
    620   EXPECT_TRUE(
    621       LiteralTestUtil::Equal(LiteralUtil::CreateR0<int32>(15),
    622                              ExecuteAndTransfer(std::move(hlo_module), {})));
    623 }
    624 
    625 XLA_TEST_F(FusionTest, ReduceImplicitBroadcast) {
    626   auto hlo_module = CreateNewVerifiedModule();
    627 
    628   auto builder = HloComputation::Builder(TestName());
    629   auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
    630       LiteralUtil::CreateR1<int32>({1, 2, 4, 8})));
    631   auto const1 = builder.AddInstruction(
    632       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
    633   auto reduce2 = builder.AddInstruction(HloInstruction::CreateReduce(
    634       ShapeUtil::MakeShape(S32, {}), const0, const1, {0},
    635       hlo_module->AddEmbeddedComputation(MakeReduceTestComputation())));
    636   auto negate3 = builder.AddInstruction(HloInstruction::CreateUnary(
    637       ShapeUtil::MakeShape(S32, {}), HloOpcode::kNegate, reduce2));
    638   hlo_module->AddEntryComputation(builder.Build())
    639       ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate3, reduce2},
    640                                 HloInstruction::FusionKind::kLoop);
    641 
    642   EXPECT_TRUE(
    643       LiteralTestUtil::Equal(LiteralUtil::CreateR0<int32>(-15),
    644                              ExecuteAndTransfer(std::move(hlo_module), {})));
    645 }
    646 
    647 XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) {
    648   auto builder = HloComputation::Builder(TestName());
    649   auto hlo_module = CreateNewVerifiedModule();
    650   auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
    651       LiteralUtil::CreateR2<int32>({{2, 3, 5}, {7, 11, 13}, {17, 19, 23}})));
    652   auto const1 = builder.AddInstruction(
    653       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
    654   Window window;
    655   ASSERT_TRUE(
    656       tensorflow::protobuf::TextFormat::ParseFromString("dimensions:{\n"
    657                                                         "size:2\n"
    658                                                         "stride:1\n"
    659                                                         "padding_low:0\n"
    660                                                         "padding_high:0\n"
    661                                                         "window_dilation:1\n"
    662                                                         "base_dilation:1\n"
    663                                                         "}\n"
    664                                                         "dimensions:{\n"
    665                                                         "size:2\n"
    666                                                         "stride:1\n"
    667                                                         "padding_low:0\n"
    668                                                         "padding_high:0\n"
    669                                                         "window_dilation:1\n"
    670                                                         "base_dilation:1\n"
    671                                                         "}\n",
    672                                                         &window));
    673   auto nested_builder = HloComputation::Builder("mul");
    674   {
    675     auto x = nested_builder.AddInstruction(
    676         HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}), "x"));
    677     auto y = nested_builder.AddInstruction(
    678         HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(S32, {}), "y"));
    679     nested_builder.AddInstruction(HloInstruction::CreateBinary(
    680         ShapeUtil::MakeShape(S32, {}), HloOpcode::kMultiply, x, y));
    681   }
    682   auto nested_computation =
    683       hlo_module->AddEmbeddedComputation(nested_builder.Build());
    684   auto reduce_window2 =
    685       builder.AddInstruction(HloInstruction::CreateReduceWindow(
    686           ShapeUtil::MakeShape(S32, {2, 2}), const0, const1, window,
    687           nested_computation));
    688   hlo_module->AddEntryComputation(builder.Build())
    689       ->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce_window2},
    690                                 HloInstruction::FusionKind::kLoop);
    691 
    692   EXPECT_TRUE(LiteralTestUtil::Equal(
    693       LiteralUtil::CreateR2<int32>({{462, 2145}, {24871, 62491}}),
    694       ExecuteAndTransfer(std::move(hlo_module), {})));
    695 }
    696 
    697 // When a constant (or other op) which has multiple users is imported
    698 // into a fusion, it should remain shared, rather than being duplicated
    699 // within the fusion.
    700 XLA_TEST_F(FusionTest, SharedConstant) {
    701   auto hlo_module = CreateNewVerifiedModule();
    702 
    703   auto builder = HloComputation::Builder(TestName());
    704   auto const0 = builder.AddInstruction(
    705       HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({0})));
    706   auto const1 = builder.AddInstruction(
    707       HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({2})));
    708   auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
    709       ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, const0));
    710   auto add2 = builder.AddInstruction(HloInstruction::CreateBinary(
    711       ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add1));
    712   auto add3 = builder.AddInstruction(HloInstruction::CreateBinary(
    713       ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add2));
    714   auto add4 = builder.AddInstruction(HloInstruction::CreateBinary(
    715       ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add3));
    716   hlo_module->AddEntryComputation(builder.Build())
    717       ->CreateFusionInstruction({add4, add3, add2, add1, const1},
    718                                 HloInstruction::FusionKind::kLoop);
    719 
    720   HloComputation* entry_comp = hlo_module->entry_computation();
    721 
    722   // entry computation contains the constant(0) and the fusion
    723   EXPECT_EQ(entry_comp->instruction_count(), 2);
    724 
    725   // fused instruction contains the constant(2), the parameter, and 4 adds
    726   EXPECT_EQ(entry_comp->root_instruction()->fused_instruction_count(), 6);
    727 
    728   EXPECT_TRUE(
    729       LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({8}),
    730                              ExecuteAndTransfer(std::move(hlo_module), {})));
    731 }
    732 
    733 XLA_TEST_F(FusionTest, Add2D) { TestElementwise2D<float, 2>(HloOpcode::kAdd); }
    734 
    735 XLA_TEST_F(FusionTest, Subtract2D) {
    736   TestElementwise2D<float, 2>(HloOpcode::kSubtract);
    737 }
    738 
    739 XLA_TEST_F(FusionTest, Multiply2D) {
    740   TestElementwise2D<float, 2>(HloOpcode::kMultiply);
    741 }
    742 
    743 XLA_TEST_F(FusionTest, Divide2D) {
    744   TestElementwise2D<float, 2>(HloOpcode::kDivide);
    745 }
    746 
    747 XLA_TEST_F(FusionTest, Power2D) {
    748   TestElementwise2D<float, 2>(HloOpcode::kPower);
    749 }
    750 
    751 XLA_TEST_F(FusionTest, Minimum2D) {
    752   TestElementwise2D<float, 2>(HloOpcode::kMinimum);
    753 }
    754 
    755 XLA_TEST_F(FusionTest, Maximum2D) {
    756   TestElementwise2D<float, 2>(HloOpcode::kMaximum);
    757 }
    758 
    759 XLA_TEST_F(FusionTest, Equal2D) {
    760   TestElementwise2D<bool, 2>(HloOpcode::kCompare, ComparisonDirection::kEq);
    761 }
    762 
    763 XLA_TEST_F(FusionTest, Inequal2D) {
    764   TestElementwise2D<bool, 2>(HloOpcode::kCompare, ComparisonDirection::kNe);
    765 }
    766 
    767 XLA_TEST_F(FusionTest, Greater2D) {
    768   TestElementwise2D<bool, 2>(HloOpcode::kCompare, ComparisonDirection::kGt);
    769 }
    770 
    771 XLA_TEST_F(FusionTest, Lesser2D) {
    772   TestElementwise2D<bool, 2>(HloOpcode::kCompare, ComparisonDirection::kLt);
    773 }
    774 
    775 XLA_TEST_F(FusionTest, GreaterOrEqual2D) {
    776   TestElementwise2D<bool, 2>(HloOpcode::kCompare, ComparisonDirection::kGe);
    777 }
    778 
    779 XLA_TEST_F(FusionTest, LesserOrEqual2D) {
    780   TestElementwise2D<bool, 2>(HloOpcode::kCompare, ComparisonDirection::kLe);
    781 }
    782 
    783 XLA_TEST_F(FusionTest, Clamp2D) {
    784   TestElementwise2D<float, 3>(HloOpcode::kClamp);
    785 }
    786 
    787 class FusionClientLibraryTest : public ClientLibraryTestBase {};
    788 
    789 XLA_TEST_F(FusionClientLibraryTest, ManyLayoutTransformations) {
    790   // On the GPU backend, it's possible to have too many transposes within one
    791   // fusion, causing the kernel to run out shared memory and thus not compile.
    792   // We want to check that doesn't happen.
    793   //
    794   // To do this, we create a computation that computes
    795   //
    796   //   P0 + P0*P1*P1 + P0*P2*P2 ...
    797   //
    798   // where even parameters have layout 1 and odd parameters have layout 2.
    799   //
    800   // Our goal is to tempt the backend into creating one giant multi-output
    801   // fusion for the whole computation, including the transposes.  Currently
    802   // multi-output fusion only fuses fusions, so each of the terms in the sum
    803   // needs to be a fusion itself, thus the contortions above.
    804   constexpr int kNumParams = 25;
    805   XlaBuilder b("ManyLayoutTransformations");
    806 
    807   // This test produces values that overflow int32, which is UB, so use uint32,
    808   // where overflow is OK.
    809   Array2D<uint32> arr(32, 32);
    810   arr.FillUnique();
    811   Literal l1 = LiteralUtil::CreateR2FromArray2D(arr).Relayout(
    812       LayoutUtil::MakeLayout({0, 1}));
    813 
    814   Literal l2 = LiteralUtil::CreateR2FromArray2D(arr).Relayout(
    815       LayoutUtil::MakeLayout({1, 0}));
    816 
    817   XlaOp p0 = AddParam(l1, &b);
    818   XlaOp sum = p0;
    819   for (int i = 1; i < kNumParams; ++i) {
    820     auto pN = AddParam((i % 2 == 0 ? l1 : l2), &b);
    821     sum = sum + p0 * pN * pN;
    822   }
    823 
    824   ComputeAndCompare(&b, {});
    825 }
    826 
    827 void BM_ParallelFusion(int num_iters) {
    828   // Simple element-wise computation to benchmark parallel task partitioning.
    829   tensorflow::testing::StopTiming();
    830 
    831   se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie();
    832   auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie();
    833   StreamExecutorMemoryAllocator allocator(platform, executors);
    834 
    835   const int64 intra_op_parallelism_threads = 24;
    836   xla::LocalClientOptions client_options;
    837   client_options.set_platform(platform);
    838   client_options.set_intra_op_parallelism_threads(intra_op_parallelism_threads);
    839   auto client =
    840       ClientLibrary::GetOrCreateLocalClient(client_options).ValueOrDie();
    841 
    842   int device_ordinal = client->default_device_ordinal();
    843 
    844   // Computation shape parameters.
    845   const int64 param0_dim0 = 1024;
    846   const int64 param0_dim1 = 1024;
    847   const int64 param1_dim0 = 1024;
    848   const int64 param1_dim1 = 1024;
    849   const int64 param2_dim0 = 1024;
    850   const int64 param2_dim1 = 1024;
    851 
    852   // Create computation.
    853   XlaBuilder builder("ParallelFusion");
    854   Shape shape0 = ShapeUtil::MakeShape(F32, {param0_dim0, param0_dim1});
    855   auto param0 = Parameter(&builder, 0, shape0, "param0");
    856   Shape shape1 = ShapeUtil::MakeShape(F32, {param1_dim0, param1_dim1});
    857   auto param1 = Parameter(&builder, 1, shape1, "param1");
    858   Shape shape2 = ShapeUtil::MakeShape(F32, {param2_dim0, param2_dim1});
    859   auto param2 = Parameter(&builder, 2, shape2, "param2");
    860 
    861   auto x = Mul(param0, param1);
    862   Add(x, param2);
    863   auto computation = builder.Build().ConsumeValueOrDie();
    864 
    865   // Transfer literals to device.
    866   auto param0_literal =
    867       LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param0_dim0, param0_dim1);
    868   ScopedShapedBuffer buffer0 =
    869       client->LiteralToShapedBuffer(param0_literal, device_ordinal)
    870           .ConsumeValueOrDie();
    871 
    872   auto param1_literal =
    873       LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param1_dim0, param1_dim1);
    874   ScopedShapedBuffer buffer1 =
    875       client->LiteralToShapedBuffer(param1_literal, device_ordinal)
    876           .ConsumeValueOrDie();
    877 
    878   auto param2_literal =
    879       LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param2_dim0, param2_dim1);
    880   ScopedShapedBuffer buffer2 =
    881       client->LiteralToShapedBuffer(param2_literal, device_ordinal)
    882           .ConsumeValueOrDie();
    883 
    884   // Build executable.
    885   std::unique_ptr<LocalExecutable> executable =
    886       client
    887           ->Compile(computation,
    888                     {&buffer0.on_host_shape(), &buffer1.on_host_shape(),
    889                      &buffer2.on_host_shape()},
    890                     ExecutableBuildOptions())
    891           .ConsumeValueOrDie();
    892 
    893   se::Stream stream(executors[device_ordinal]);
    894   stream.Init();
    895 
    896   // Initialize thread pool.
    897   tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "XLAEigen",
    898                                       intra_op_parallelism_threads);
    899   tensorflow::EigenThreadPoolWrapper tp(&pool);
    900   Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
    901 
    902   // Initialize ExecutableRunOptions.
    903   ExecutableRunOptions options;
    904   options.set_allocator(&allocator).set_stream(&stream);
    905   options.set_intra_op_thread_pool(&device);
    906 
    907   // Run some warm-up executions.
    908   const int kWarmups = 2;
    909   for (int i = 0; i < kWarmups; ++i) {
    910     auto result = executable->Run({&buffer0, &buffer1, &buffer2}, options);
    911     ASSERT_TRUE(result.ok());
    912   }
    913 
    914   // Run benchmark.
    915   const int64 total_bytes = param0_dim0 * param0_dim0 +
    916                             param1_dim0 * param1_dim0 +
    917                             param2_dim0 * param2_dim0;
    918   tensorflow::testing::BytesProcessed(static_cast<int64>(num_iters) *
    919                                       total_bytes * sizeof(float));
    920   tensorflow::testing::UseRealTime();
    921   tensorflow::testing::StartTiming();
    922   for (int i = 0; i < num_iters; ++i) {
    923     auto result = executable->Run({&buffer0, &buffer1, &buffer2}, options);
    924     ASSERT_TRUE(result.ok());
    925   }
    926 }
    927 
    928 BENCHMARK(BM_ParallelFusion);
    929 
    930 }  // namespace
    931 }  // namespace xla
    932