/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <math.h>
#include <algorithm>
#include <memory>
#include <new>
#include <random>
#include <utility>

#define EIGEN_USE_THREADS

#include "absl/memory/memory.h"
#include "absl/types/span.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace {

const int test_width = 2, test_height = 3;

const float test_float_vals[3][test_width][test_height] = {
    {{-1.0, -1.0, 1.0}, {-3.0, 0.0, -1.0}},
    {{-3.0, 2.0, 1.0}, {0.0, -3.0, 1.0}},
    {{-3.0, 0.0, -3.0}, {-1.0, -2.0, 1.0}}};

// Test whether fusion operations are emitted with no errors and compute
// accurate outputs.
class FusionTest : public HloTestBase {
 protected:
  template <typename T, int Arity>
  void TestElementwise2D(
      HloOpcode opcode,
      absl::optional<ComparisonDirection> direction = absl::nullopt) {
    // Create a variable for comparisons since they require the direction.
    bool is_compare = std::is_same<T, bool>::value;
    // Default-constructed Array2Ds are empty, so re-construct each operand in
    // place with the test dimensions.
    Array2D<float> operand_data[Arity];
    for (int i = 0; i < Arity; ++i) {
      new (&operand_data[i]) Array2D<float>(test_width, test_height);
    }
    Array2D<T> answer_data(test_width, test_height);
    for (int i = 0; i < test_width; ++i) {
      for (int j = 0; j < test_height; ++j) {
        float xs[Arity];
        for (int k = 0; k < Arity; ++k) {
          xs[k] = test_float_vals[k][i][j];
          operand_data[k](i, j) = xs[k];
        }
        if (is_compare) {
          answer_data(i, j) = ComputeElementwiseAnswerCompare(*direction, xs);
        } else {
          answer_data(i, j) = ComputeElementwiseAnswerFloat(opcode, xs);
        }
      }
    }

    auto builder = HloComputation::Builder(TestName());
    auto hlo_module = CreateNewVerifiedModule();

    auto prim_type = primitive_util::NativeToPrimitiveType<T>();

    HloInstruction* hlos[4];
    for (int i = 0; i < Arity; ++i) {
      hlos[i + 1] = builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2FromArray2D(operand_data[i])));
    }
    auto answer_shape =
        ShapeUtil::MakeShape(prim_type, {test_width, test_height});
    std::unique_ptr<HloInstruction> root_hlo;
    switch (Arity) {
      case 1:
        root_hlo = HloInstruction::CreateUnary(answer_shape, opcode, hlos[1]);
        break;
      case 2:
        if (is_compare) {
          root_hlo = HloInstruction::CreateCompare(answer_shape, hlos[1],
                                                   hlos[2], *direction);
        } else {
          root_hlo = HloInstruction::CreateBinary(answer_shape, opcode,
                                                  hlos[1], hlos[2]);
        }
        break;
      case 3:
        root_hlo = HloInstruction::CreateTernary(answer_shape, opcode, hlos[1],
                                                 hlos[2], hlos[3]);
        break;
      default:
        LOG(FATAL) << "Bad arity: " << Arity;
    }
    hlos[0] = builder.AddInstruction(std::move(root_hlo));
    hlo_module->AddEntryComputation(builder.Build())
        ->CreateFusionInstruction(
            absl::Span<HloInstruction* const>(hlos).subspan(0, Arity + 1),
            HloInstruction::FusionKind::kLoop);

    auto expected = LiteralUtil::CreateR2FromArray2D(answer_data);
    auto actual = ExecuteAndTransfer(std::move(hlo_module), {});
    if (primitive_util::IsFloatingPointType(prim_type)) {
      EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, ErrorSpec(1e-4)));
    } else {
      EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
    }
  }

 private:
  float ComputeElementwiseAnswerFloat(HloOpcode opcode,
                                      absl::Span<const float> xs);
  bool ComputeElementwiseAnswerCompare(ComparisonDirection direction,
                                       absl::Span<const float> xs);
  DebugOptions GetDebugOptionsForTest() override {
    DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest();
    debug_options.add_xla_disable_hlo_passes("layout-assignment");
    return debug_options;
  }
};

float FusionTest::ComputeElementwiseAnswerFloat(HloOpcode opcode,
                                                absl::Span<const float> xs) {
  switch (opcode) {
    case HloOpcode::kAdd:
      return xs[0] + xs[1];
    case HloOpcode::kSubtract:
      return xs[0] - xs[1];
    case HloOpcode::kMultiply:
      return xs[0] * xs[1];
    case HloOpcode::kDivide:
      return xs[0] / xs[1];
    case HloOpcode::kPower:
      return powf(xs[0], xs[1]);
    case HloOpcode::kMinimum:
      return std::min(xs[0], xs[1]);
    case HloOpcode::kMaximum:
      return std::max(xs[0], xs[1]);
    case HloOpcode::kClamp:
      return std::min(xs[2], std::max(xs[1], xs[0]));
    default:
      LOG(FATAL) << "No elementwise opcode: " << opcode;
  }
}

bool FusionTest::ComputeElementwiseAnswerCompare(ComparisonDirection direction,
                                                 absl::Span<const float> xs) {
  switch (direction) {
    case ComparisonDirection::kEq:
      return xs[0] == xs[1];
    case ComparisonDirection::kNe:
      return xs[0] != xs[1];
    case ComparisonDirection::kGt:
      return xs[0] > xs[1];
    case ComparisonDirection::kLt:
      return xs[0] < xs[1];
    case ComparisonDirection::kGe:
      return xs[0] >= xs[1];
    case ComparisonDirection::kLe:
      return xs[0] <= xs[1];
  }
}

XLA_TEST_F(FusionTest, Test) {
  // test expression:
  // slice(select({{T, F, T}, {F, T, F}},
  //              concat(transpose({{1.0}, {2.0}, {3.0}} +
  //                               {{-1.0}, {-1.0}, {-1.0}}),
  //                     {{1.62, 2.72, 3.14}}) +
  //              (-{{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}}),
  //              {{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}})) = {{0.5}, {2.72}}
  auto builder = HloComputation::Builder(TestName());
  auto hlo_module = CreateNewVerifiedModule();
  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{1.0}, {2.0}, {3.0}})));
  auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{-1.0}, {-1.0}, {-1.0}})));
  auto add2 = builder.AddInstruction(HloInstruction::CreateBinary(
      ShapeUtil::MakeShape(F32, {3, 1}), HloOpcode::kAdd, const0, const1));
  auto reshape3 = builder.AddInstruction(HloInstruction::CreateTranspose(
      ShapeUtil::MakeShape(F32, {1, 3}), add2, {1, 0}));
  auto const4 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{1.62, 2.72, 3.14}})));
  auto concat5 = builder.AddInstruction(HloInstruction::CreateConcatenate(
      ShapeUtil::MakeShape(F32, {2, 3}), {reshape3, const4}, 0));
  auto const6 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}})));
  auto negate7 = builder.AddInstruction(HloInstruction::CreateUnary(
      ShapeUtil::MakeShape(F32, {2, 3}), HloOpcode::kNegate, const6));
  auto add8 = builder.AddInstruction(HloInstruction::CreateBinary(
      ShapeUtil::MakeShape(F32, {2, 3}), HloOpcode::kAdd, concat5, negate7));
  auto const9 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}})));
  auto const10 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2<bool>(
          {{true, false, true}, {false, true, false}})));
  auto select11 = builder.AddInstruction(
      HloInstruction::CreateTernary(ShapeUtil::MakeShape(F32, {2, 3}),
                                    HloOpcode::kSelect, const10, add8, const9));
  auto slice12 = builder.AddInstruction(HloInstruction::CreateSlice(
      ShapeUtil::MakeShape(F32, {2, 1}), select11, {0, 1}, {2, 2}, {1, 1}));
  // CreateFusionInstruction needs the `instructions_to_fuse` argument in
  // reverse topological order, so the first element in `instructions_to_fuse`
  // must be the root.
  hlo_module->AddEntryComputation(builder.Build())
      ->CreateFusionInstruction(
          {slice12, select11, const10, const9, add8, negate7, const6, concat5,
           const4, reshape3, add2, const1, const0},
          HloInstruction::FusionKind::kLoop);

  EXPECT_TRUE(LiteralTestUtil::Near(
      LiteralUtil::CreateR2<float>({{0.5}, {2.72}}),
      ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
}

// Test whether we emit appropriate code for parameters of fusion instructions.
XLA_TEST_F(FusionTest, Parameter) {
  // Build a computation and fuse part of it so the fusion instruction has an
  // operand parameter.
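  // Only {add3, const2} are fused below; copy1 stays in the entry computation
  // and becomes an operand of the fusion instruction, i.e. a parameter of the
  // fused computation.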
  auto builder = HloComputation::Builder(TestName());
  auto hlo_module = CreateNewVerifiedModule();
  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}})));
  auto copy1 = builder.AddInstruction(HloInstruction::CreateUnary(
      ShapeUtil::MakeShape(F32, {1, 3}), HloOpcode::kCopy, const0));
  auto const2 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{-2.0, -2.0, -2.0}})));
  // add3 = copy1 + const2 = const0 + const2 = {1,2,3} + {-2,-2,-2} = {-1,0,+1}
  auto add3 = builder.AddInstruction(HloInstruction::CreateBinary(
      ShapeUtil::MakeShape(F32, {1, 3}), HloOpcode::kAdd, copy1, const2));
  // CreateFusionInstruction needs `instructions_to_fuse` in reverse
  // topological order.
  hlo_module->AddEntryComputation(builder.Build())
      ->CreateFusionInstruction(/*instructions_to_fuse=*/{add3, const2},
                                HloInstruction::FusionKind::kLoop);

  EXPECT_TRUE(LiteralTestUtil::Near(
      LiteralUtil::CreateR2<float>({{-1.0, 0.0, 1.0}}),
      ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
}

XLA_TEST_F(FusionTest, RandomizedParallelPartition) {
  // Tests parallel partitioning of a fusion instruction.
  // Create shape with random outer dimension size to generate random parallel
  // partition counts for each test run.
  const int seed = tensorflow::testing::RandomSeed();
  LOG(INFO) << "RandomizedParallelPartition seed: " << seed;
  std::mt19937 generator(seed);
  std::uniform_int_distribution<int> distribution(128, 1024);
  const int64 rand_dim0_size = distribution(generator);
  const int64 dim1_size = 1024;
  Shape shape =
      ShapeUtil::MakeShapeWithLayout(F32, {rand_dim0_size, dim1_size}, {1, 0});
  // Build simple fusion computation: y = x^2 (elementwise).
  auto builder = HloComputation::Builder(TestName());
  auto hlo_module = CreateNewVerifiedModule();

  auto two = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  auto x =
      builder.AddInstruction(HloInstruction::CreateBroadcast(shape, two, {}));
  auto y = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, x, x));

  hlo_module->AddEntryComputation(builder.Build())
      ->CreateFusionInstruction(/*instructions_to_fuse=*/{y, x, two},
                                HloInstruction::FusionKind::kLoop);
  // Compute result.
  auto result = ExecuteAndTransfer(std::move(hlo_module), {});
  // Every element of result should be y = x^2 = 4.0.
  for (int i = 0; i < rand_dim0_size; ++i) {
    for (int j = 0; j < dim1_size; ++j) {
      EXPECT_EQ(4.0, result.Get<float>({i, j}));
    }
  }
}

XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) {
  auto builder = HloComputation::Builder(TestName());
  auto hlo_module = CreateNewVerifiedModule();
  auto const_vector = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0})));
  auto const_array = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}})));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(const_array->shape(), const_vector, {1}));
  // add2 = broadcast(const_vector) + const_array
  //      = broadcast({1,2,3}) + {{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}}
  //      = {{1, 2, 3}, {1, 2, 3}} + {{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}}
  auto add2 = builder.AddInstruction(
      HloInstruction::CreateBinary(ShapeUtil::MakeShape(F32, {2, 3}),
                                   HloOpcode::kAdd, broadcast, const_array));
  hlo_module->AddEntryComputation(builder.Build())
      ->CreateFusionInstruction(/*instructions_to_fuse=*/{add2, broadcast},
                                HloInstruction::FusionKind::kLoop);

  EXPECT_TRUE(LiteralTestUtil::Near(
      LiteralUtil::CreateR2<float>({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}),
      ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
}

XLA_TEST_F(FusionTest, ReshapeToScalar) {
  auto builder = HloComputation::Builder(TestName());
  auto hlo_module = CreateNewVerifiedModule();
  auto single_element_array = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2<int32>({{5}})));
  auto reshape = builder.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(S32, {}), single_element_array));
  hlo_module->AddEntryComputation(builder.Build())
      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape},
                                HloInstruction::FusionKind::kLoop);
  EXPECT_TRUE(
      LiteralTestUtil::Equal(LiteralUtil::CreateR0<int32>(5),
                             ExecuteAndTransfer(std::move(hlo_module), {})));
}

XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) {
  auto builder = HloComputation::Builder(TestName());
  auto hlo_module = CreateNewVerifiedModule();
  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}})));
  auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(S32, {1, 2, 3}), const0));
  hlo_module->AddEntryComputation(builder.Build())
      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
                                HloInstruction::FusionKind::kLoop);
  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}}),
      ExecuteAndTransfer(std::move(hlo_module), {})));
}

XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) {
  auto builder = HloComputation::Builder(TestName());
  auto hlo_module = CreateNewVerifiedModule();
  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}})));
  auto reshape1 = builder.AddInstruction(
      HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {3, 2}), const0));
  hlo_module->AddEntryComputation(builder.Build())
      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
                                HloInstruction::FusionKind::kLoop);
  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::CreateR2<int32>({{1, 2},
                                    {3, 4}, {5, 6}}),
      ExecuteAndTransfer(std::move(hlo_module), {})));
}

XLA_TEST_F(FusionTest, Reshape_1by1by1_) {
  auto builder = HloComputation::Builder(TestName());
  auto hlo_module = CreateNewVerifiedModule();
  auto const0 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR3<int32>({{{7}}})));
  auto reshape1 = builder.AddInstruction(
      HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), const0));
  hlo_module->AddEntryComputation(builder.Build())
      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
                                HloInstruction::FusionKind::kLoop);
  EXPECT_TRUE(
      LiteralTestUtil::Equal(LiteralUtil::CreateR0<int32>(7),
                             ExecuteAndTransfer(std::move(hlo_module), {})));
}

XLA_TEST_F(FusionTest, Reshape__1by1by1) {
  auto builder = HloComputation::Builder(TestName());
  auto hlo_module = CreateNewVerifiedModule();
  auto const0 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(7)));
  auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(S32, {1, 1, 1}), const0));
  hlo_module->AddEntryComputation(builder.Build())
      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
                                HloInstruction::FusionKind::kLoop);
  EXPECT_TRUE(
      LiteralTestUtil::Equal(LiteralUtil::CreateR3<int32>({{{7}}}),
                             ExecuteAndTransfer(std::move(hlo_module), {})));
}

XLA_TEST_F(FusionTest, Reshape__) {
  auto builder = HloComputation::Builder(TestName());
  auto hlo_module = CreateNewVerifiedModule();
  auto const0 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(7)));
  auto reshape1 = builder.AddInstruction(
      HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), const0));
  hlo_module->AddEntryComputation(builder.Build())
      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
                                HloInstruction::FusionKind::kLoop);
  EXPECT_TRUE(
      LiteralTestUtil::Equal(LiteralUtil::CreateR0<int32>(7),
                             ExecuteAndTransfer(std::move(hlo_module), {})));
}

XLA_TEST_F(FusionTest, Reshape_3by3_3by3) {
  auto builder = HloComputation::Builder(TestName());
  auto hlo_module = CreateNewVerifiedModule();
  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
  auto reshape1 = builder.AddInstruction(
      HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {3, 3}), const0));
  hlo_module->AddEntryComputation(builder.Build())
      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
                                HloInstruction::FusionKind::kLoop);
  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}),
      ExecuteAndTransfer(std::move(hlo_module), {})));
}

XLA_TEST_F(FusionTest, Transpose_2by3) {
  auto builder = HloComputation::Builder(TestName());
  auto hlo_module = CreateNewVerifiedModule();
  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}})));
  auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose(
      ShapeUtil::MakeShape(S32, {3, 2}), const0, {1, 0}));
  hlo_module->AddEntryComputation(builder.Build())
      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
                                HloInstruction::FusionKind::kLoop);
  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::CreateR2<int32>({{1, 4},
                                    {2, 5}, {3, 6}}),
      ExecuteAndTransfer(std::move(hlo_module), {})));
}

XLA_TEST_F(FusionTest, Transpose_3by3) {
  auto builder = HloComputation::Builder(TestName());
  auto hlo_module = CreateNewVerifiedModule();
  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
  auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose(
      ShapeUtil::MakeShape(S32, {3, 3}), const0, {1, 0}));
  hlo_module->AddEntryComputation(builder.Build())
      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
                                HloInstruction::FusionKind::kLoop);
  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::CreateR2<int32>({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}),
      ExecuteAndTransfer(std::move(hlo_module), {})));
}

XLA_TEST_F(FusionTest, Reverse) {
  auto builder = HloComputation::Builder(TestName());
  auto hlo_module = CreateNewVerifiedModule();
  auto const0 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({1, 2, 3})));
  auto reverse1 = builder.AddInstruction(HloInstruction::CreateReverse(
      ShapeUtil::MakeShape(S32, {3}), const0, {0}));
  hlo_module->AddEntryComputation(builder.Build())
      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reverse1},
                                HloInstruction::FusionKind::kLoop);

  EXPECT_TRUE(
      LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({3, 2, 1}),
                             ExecuteAndTransfer(std::move(hlo_module), {})));
}

XLA_TEST_F(FusionTest, ReverseNegate) {
  auto builder = HloComputation::Builder(TestName());
  auto hlo_module = CreateNewVerifiedModule();
  auto const0 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({1, 2, 3})));
  auto reverse1 = builder.AddInstruction(HloInstruction::CreateReverse(
      ShapeUtil::MakeShape(S32, {3}), const0, {0}));
  auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary(
      ShapeUtil::MakeShape(S32, {3}), HloOpcode::kNegate, reverse1));
  hlo_module->AddEntryComputation(builder.Build())
      ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reverse1},
                                HloInstruction::FusionKind::kLoop);

  EXPECT_TRUE(
      LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({-3, -2, -1}),
                             ExecuteAndTransfer(std::move(hlo_module), {})));
}

XLA_TEST_F(FusionTest, BroadcastNegate) {
  auto builder = HloComputation::Builder(TestName());
  auto hlo_module = CreateNewVerifiedModule();
  auto const0 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
  auto broadcast1 = builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(S32, {2}), const0, {}));
  auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary(
      ShapeUtil::MakeShape(S32, {2}), HloOpcode::kNegate, broadcast1));
  hlo_module->AddEntryComputation(builder.Build())
      ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, broadcast1},
                                HloInstruction::FusionKind::kLoop);

  EXPECT_TRUE(
      LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({-1, -1}),
                             ExecuteAndTransfer(std::move(hlo_module), {})));
}

XLA_TEST_F(FusionTest, SliceNegate) {
  auto builder = HloComputation::Builder(TestName());
  auto hlo_module = CreateNewVerifiedModule();
  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<int32>({1, 2, 3, 4})));
  auto slice1 =
      builder.AddInstruction(HloInstruction::CreateSlice(
          ShapeUtil::MakeShape(S32, {2}), const0, {0}, {4}, {2}));
  auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary(
      ShapeUtil::MakeShape(S32, {2}), HloOpcode::kNegate, slice1));
  hlo_module->AddEntryComputation(builder.Build())
      ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, slice1},
                                HloInstruction::FusionKind::kLoop);

  EXPECT_TRUE(
      LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({-1, -3}),
                             ExecuteAndTransfer(std::move(hlo_module), {})));
}

XLA_TEST_F(FusionTest, DynamicSliceNegate) {
  auto builder = HloComputation::Builder(TestName());
  auto hlo_module = CreateNewVerifiedModule();
  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<int32>({1, 2, 3, 4})));
  auto const1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
  auto dynamic_slice2 =
      builder.AddInstruction(HloInstruction::CreateDynamicSlice(
          ShapeUtil::MakeShape(S32, {2}), const0, {const1}, {2}));
  auto negate3 = builder.AddInstruction(HloInstruction::CreateUnary(
      ShapeUtil::MakeShape(S32, {2}), HloOpcode::kNegate, dynamic_slice2));
  hlo_module->AddEntryComputation(builder.Build())
      ->CreateFusionInstruction(
          /*instructions_to_fuse=*/{negate3, dynamic_slice2},
          HloInstruction::FusionKind::kLoop);

  EXPECT_TRUE(
      LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({-2, -3}),
                             ExecuteAndTransfer(std::move(hlo_module), {})));
}

XLA_TEST_F(FusionTest, ReshapeNegate) {
  auto builder = HloComputation::Builder(TestName());
  auto hlo_module = CreateNewVerifiedModule();
  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<int32>({1, 2, 3, 4})));
  auto reshape1 = builder.AddInstruction(
      HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {2, 2}), const0));
  auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary(
      ShapeUtil::MakeShape(S32, {2, 2}), HloOpcode::kNegate, reshape1));
  hlo_module->AddEntryComputation(builder.Build())
      ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reshape1},
                                HloInstruction::FusionKind::kLoop);

  EXPECT_TRUE(
      LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{-1, -2}, {-3, -4}}),
                             ExecuteAndTransfer(std::move(hlo_module), {})));
}

XLA_TEST_F(FusionTest, TransposeNegate) {
  auto builder = HloComputation::Builder(TestName());
  auto hlo_module = CreateNewVerifiedModule();
  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}})));
  auto transpose1 = builder.AddInstruction(HloInstruction::CreateTranspose(
      ShapeUtil::MakeShape(S32, {2, 2}), const0, {1, 0}));
  auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary(
      ShapeUtil::MakeShape(S32, {2, 2}), HloOpcode::kNegate, transpose1));
  hlo_module->AddEntryComputation(builder.Build())
      ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, transpose1},
                                HloInstruction::FusionKind::kLoop);

  EXPECT_TRUE(
      LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{-1, -3}, {-2, -4}}),
                             ExecuteAndTransfer(std::move(hlo_module), {})));
}

std::unique_ptr<HloComputation> MakeReduceTestComputation() {
  auto builder = HloComputation::Builder("add");
  auto lhs = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, ShapeUtil::MakeShape(S32, {}), "lhs"));
  auto rhs = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, ShapeUtil::MakeShape(S32, {}), "rhs"));
  builder.AddInstruction(HloInstruction::CreateBinary(
      ShapeUtil::MakeShape(S32, {}), HloOpcode::kAdd, lhs, rhs));
  return builder.Build();
}

XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) {
  auto hlo_module = CreateNewVerifiedModule();

  auto builder = HloComputation::Builder(TestName());
  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<int32>({1, 2, 4, 8})));
  auto const1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
  auto reduce2 = builder.AddInstruction(HloInstruction::CreateReduce(
      ShapeUtil::MakeShape(S32, {}), const0, const1, {0},
      hlo_module->AddEmbeddedComputation(MakeReduceTestComputation())));
  hlo_module->AddEntryComputation(builder.Build())
      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce2},
                                HloInstruction::FusionKind::kInput);

  EXPECT_TRUE(
      LiteralTestUtil::Equal(LiteralUtil::CreateR0<int32>(15),
                             ExecuteAndTransfer(std::move(hlo_module), {})));
}

XLA_TEST_F(FusionTest, ReduceImplicitBroadcast) {
  auto hlo_module = CreateNewVerifiedModule();

  auto builder = HloComputation::Builder(TestName());
  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<int32>({1, 2, 4, 8})));
  auto const1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
  auto reduce2 = builder.AddInstruction(HloInstruction::CreateReduce(
      ShapeUtil::MakeShape(S32, {}), const0, const1, {0},
      hlo_module->AddEmbeddedComputation(MakeReduceTestComputation())));
  auto negate3 = builder.AddInstruction(HloInstruction::CreateUnary(
      ShapeUtil::MakeShape(S32, {}), HloOpcode::kNegate, reduce2));
  hlo_module->AddEntryComputation(builder.Build())
      ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate3, reduce2},
                                HloInstruction::FusionKind::kLoop);

  EXPECT_TRUE(
      LiteralTestUtil::Equal(LiteralUtil::CreateR0<int32>(-15),
                             ExecuteAndTransfer(std::move(hlo_module), {})));
}

XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) {
  auto builder = HloComputation::Builder(TestName());
  auto hlo_module = CreateNewVerifiedModule();
  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<int32>({{2, 3, 5}, {7, 11, 13}, {17, 19, 23}})));
  auto const1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
  Window window;
  ASSERT_TRUE(
      tensorflow::protobuf::TextFormat::ParseFromString("dimensions:{\n"
                                                        "size:2\n"
                                                        "stride:1\n"
                                                        "padding_low:0\n"
                                                        "padding_high:0\n"
                                                        "window_dilation:1\n"
                                                        "base_dilation:1\n"
                                                        "}\n"
                                                        "dimensions:{\n"
                                                        "size:2\n"
                                                        "stride:1\n"
                                                        "padding_low:0\n"
                                                        "padding_high:0\n"
                                                        "window_dilation:1\n"
                                                        "base_dilation:1\n"
                                                        "}\n",
                                                        &window));
  auto nested_builder = HloComputation::Builder("mul");
  {
    auto x = nested_builder.AddInstruction(
        HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}), "x"));
    auto y = nested_builder.AddInstruction(
        HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(S32, {}), "y"));
    nested_builder.AddInstruction(HloInstruction::CreateBinary(
        ShapeUtil::MakeShape(S32, {}), HloOpcode::kMultiply, x, y));
  }
  auto nested_computation =
      hlo_module->AddEmbeddedComputation(nested_builder.Build());
  auto reduce_window2 =
      builder.AddInstruction(HloInstruction::CreateReduceWindow(
          ShapeUtil::MakeShape(S32, {2, 2}), const0, const1, window,
          nested_computation));
  hlo_module->AddEntryComputation(builder.Build())
      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce_window2},
                                HloInstruction::FusionKind::kLoop);

  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::CreateR2<int32>({{462, 2145}, {24871, 62491}}),
      ExecuteAndTransfer(std::move(hlo_module), {})));
}

// When a constant (or other op) which has multiple users is imported
// into a fusion, it should remain shared, rather than being duplicated
// within the fusion.
XLA_TEST_F(FusionTest, SharedConstant) {
  auto hlo_module = CreateNewVerifiedModule();

  auto builder = HloComputation::Builder(TestName());
  auto const0 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({0})));
  auto const1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({2})));
  auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
      ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, const0));
  auto add2 = builder.AddInstruction(HloInstruction::CreateBinary(
      ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add1));
  auto add3 = builder.AddInstruction(HloInstruction::CreateBinary(
      ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add2));
  auto add4 = builder.AddInstruction(HloInstruction::CreateBinary(
      ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add3));
  hlo_module->AddEntryComputation(builder.Build())
      ->CreateFusionInstruction({add4, add3, add2, add1, const1},
                                HloInstruction::FusionKind::kLoop);

  HloComputation* entry_comp = hlo_module->entry_computation();

  // entry computation contains the constant(0) and the fusion
  EXPECT_EQ(entry_comp->instruction_count(), 2);

  // fused instruction contains the constant(2), the parameter, and 4 adds
  EXPECT_EQ(entry_comp->root_instruction()->fused_instruction_count(), 6);

  EXPECT_TRUE(
      LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({8}),
                             ExecuteAndTransfer(std::move(hlo_module), {})));
}

XLA_TEST_F(FusionTest, Add2D) { TestElementwise2D<float, 2>(HloOpcode::kAdd); }

XLA_TEST_F(FusionTest, Subtract2D) {
  TestElementwise2D<float, 2>(HloOpcode::kSubtract);
}

XLA_TEST_F(FusionTest, Multiply2D) {
  TestElementwise2D<float, 2>(HloOpcode::kMultiply);
}

XLA_TEST_F(FusionTest, Divide2D) {
  TestElementwise2D<float, 2>(HloOpcode::kDivide);
}

XLA_TEST_F(FusionTest, Power2D) {
  TestElementwise2D<float, 2>(HloOpcode::kPower);
}

XLA_TEST_F(FusionTest, Minimum2D) {
  TestElementwise2D<float, 2>(HloOpcode::kMinimum);
}

XLA_TEST_F(FusionTest, Maximum2D) {
  TestElementwise2D<float, 2>(HloOpcode::kMaximum);
}

XLA_TEST_F(FusionTest, Equal2D) {
  TestElementwise2D<bool, 2>(HloOpcode::kCompare, ComparisonDirection::kEq);
}

XLA_TEST_F(FusionTest, Inequal2D) {
  TestElementwise2D<bool, 2>(HloOpcode::kCompare, ComparisonDirection::kNe);
}

XLA_TEST_F(FusionTest, Greater2D) {
  TestElementwise2D<bool, 2>(HloOpcode::kCompare, ComparisonDirection::kGt);
}

XLA_TEST_F(FusionTest, Lesser2D) {
  TestElementwise2D<bool, 2>(HloOpcode::kCompare, ComparisonDirection::kLt);
}

XLA_TEST_F(FusionTest, GreaterOrEqual2D) {
  TestElementwise2D<bool, 2>(HloOpcode::kCompare, ComparisonDirection::kGe);
}

XLA_TEST_F(FusionTest, LesserOrEqual2D) {
  TestElementwise2D<bool, 2>(HloOpcode::kCompare, ComparisonDirection::kLe);
}

XLA_TEST_F(FusionTest, Clamp2D) {
  TestElementwise2D<float, 3>(HloOpcode::kClamp);
}

class FusionClientLibraryTest : public ClientLibraryTestBase {};

XLA_TEST_F(FusionClientLibraryTest, ManyLayoutTransformations) {
  // On the GPU backend, it's possible to have too many transposes within one
  // fusion, causing the kernel to run out of shared memory and thus not
  // compile. We want to check that doesn't happen.
  //
  // To do this, we create a computation that computes
  //
  //   P0 + P0*P1*P1 + P0*P2*P2 ...
  //
  // where even parameters have layout 1 and odd parameters have layout 2.
  //
  // Our goal is to tempt the backend into creating one giant multi-output
  // fusion for the whole computation, including the transposes. Currently
  // multi-output fusion only fuses fusions, so each of the terms in the sum
  // needs to be a fusion itself, thus the contortions above.
  constexpr int kNumParams = 25;
  XlaBuilder b("ManyLayoutTransformations");

  // This test produces values that overflow int32, which is UB, so use uint32,
  // where overflow is OK.
  Array2D<uint32> arr(32, 32);
  arr.FillUnique();
  Literal l1 = LiteralUtil::CreateR2FromArray2D(arr).Relayout(
      LayoutUtil::MakeLayout({0, 1}));

  Literal l2 = LiteralUtil::CreateR2FromArray2D(arr).Relayout(
      LayoutUtil::MakeLayout({1, 0}));

  XlaOp p0 = AddParam(l1, &b);
  XlaOp sum = p0;
  for (int i = 1; i < kNumParams; ++i) {
    auto pN = AddParam((i % 2 == 0 ? l1 : l2), &b);
    sum = sum + p0 * pN * pN;
  }

  ComputeAndCompare(&b, {});
}

void BM_ParallelFusion(int num_iters) {
  // Simple element-wise computation to benchmark parallel task partitioning.
  tensorflow::testing::StopTiming();

  se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie();
  auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie();
  StreamExecutorMemoryAllocator allocator(platform, executors);

  const int64 intra_op_parallelism_threads = 24;
  xla::LocalClientOptions client_options;
  client_options.set_platform(platform);
  client_options.set_intra_op_parallelism_threads(
      intra_op_parallelism_threads);
  auto client =
      ClientLibrary::GetOrCreateLocalClient(client_options).ValueOrDie();

  int device_ordinal = client->default_device_ordinal();

  // Computation shape parameters.
  const int64 param0_dim0 = 1024;
  const int64 param0_dim1 = 1024;
  const int64 param1_dim0 = 1024;
  const int64 param1_dim1 = 1024;
  const int64 param2_dim0 = 1024;
  const int64 param2_dim1 = 1024;

  // Create computation.
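  // The computation below is an elementwise multiply-add over F32[1024, 1024]
  // operands: result = param0 * param1 + param2.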
  XlaBuilder builder("ParallelFusion");
  Shape shape0 = ShapeUtil::MakeShape(F32, {param0_dim0, param0_dim1});
  auto param0 = Parameter(&builder, 0, shape0, "param0");
  Shape shape1 = ShapeUtil::MakeShape(F32, {param1_dim0, param1_dim1});
  auto param1 = Parameter(&builder, 1, shape1, "param1");
  Shape shape2 = ShapeUtil::MakeShape(F32, {param2_dim0, param2_dim1});
  auto param2 = Parameter(&builder, 2, shape2, "param2");

  auto x = Mul(param0, param1);
  Add(x, param2);
  auto computation = builder.Build().ConsumeValueOrDie();

  // Transfer literals to device.
  auto param0_literal =
      LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param0_dim0, param0_dim1);
  ScopedShapedBuffer buffer0 =
      client->LiteralToShapedBuffer(param0_literal, device_ordinal)
          .ConsumeValueOrDie();

  auto param1_literal =
      LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param1_dim0, param1_dim1);
  ScopedShapedBuffer buffer1 =
      client->LiteralToShapedBuffer(param1_literal, device_ordinal)
          .ConsumeValueOrDie();

  auto param2_literal =
      LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param2_dim0, param2_dim1);
  ScopedShapedBuffer buffer2 =
      client->LiteralToShapedBuffer(param2_literal, device_ordinal)
          .ConsumeValueOrDie();

  // Build executable.
  std::unique_ptr<LocalExecutable> executable =
      client
          ->Compile(computation,
                    {&buffer0.on_host_shape(), &buffer1.on_host_shape(),
                     &buffer2.on_host_shape()},
                    ExecutableBuildOptions())
          .ConsumeValueOrDie();

  se::Stream stream(executors[device_ordinal]);
  stream.Init();

  // Initialize thread pool.
  tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "XLAEigen",
                                      intra_op_parallelism_threads);
  tensorflow::EigenThreadPoolWrapper tp(&pool);
  Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());

  // Initialize ExecutableRunOptions.
  ExecutableRunOptions options;
  options.set_allocator(&allocator).set_stream(&stream);
  options.set_intra_op_thread_pool(&device);

  // Run some warm-up executions.
  const int kWarmups = 2;
  for (int i = 0; i < kWarmups; ++i) {
    auto result = executable->Run({&buffer0, &buffer1, &buffer2}, options);
    ASSERT_TRUE(result.ok());
  }

  // Run benchmark.
  const int64 total_bytes = param0_dim0 * param0_dim1 +
                            param1_dim0 * param1_dim1 +
                            param2_dim0 * param2_dim1;
  tensorflow::testing::BytesProcessed(static_cast<int64>(num_iters) *
                                      total_bytes * sizeof(float));
  tensorflow::testing::UseRealTime();
  tensorflow::testing::StartTiming();
  for (int i = 0; i < num_iters; ++i) {
    auto result = executable->Run({&buffer0, &buffer1, &buffer2}, options);
    ASSERT_TRUE(result.ok());
  }
}

BENCHMARK(BM_ParallelFusion);

}  // namespace
}  // namespace xla