1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/tests/client_library_test_base.h" 17 18 #include <memory> 19 #include <string> 20 21 #include "absl/memory/memory.h" 22 #include "absl/strings/str_cat.h" 23 #include "tensorflow/compiler/xla/client/client_library.h" 24 #include "tensorflow/compiler/xla/client/local_client.h" 25 #include "tensorflow/compiler/xla/client/xla_builder.h" 26 #include "tensorflow/compiler/xla/execution_options_util.h" 27 #include "tensorflow/compiler/xla/literal_util.h" 28 #include "tensorflow/compiler/xla/service/platform_util.h" 29 #include "tensorflow/compiler/xla/shape_util.h" 30 #include "tensorflow/compiler/xla/status_macros.h" 31 #include "tensorflow/compiler/xla/statusor.h" 32 #include "tensorflow/compiler/xla/test_helpers.h" 33 #include "tensorflow/core/platform/logging.h" 34 #include "tensorflow/core/platform/types.h" 35 36 namespace xla { 37 namespace { 38 39 // Name of the interpreter backend. 40 constexpr char kInterpreter[] = "interpreter"; 41 42 // Wrapper function that creates a nicer error message (than a bare 43 // ValueOrDie()) if the platform we intend to test is not available. 44 LocalClient* GetOrCreateLocalClientOrDie( 45 const LocalClientOptions& client_options) { 46 StatusOr<LocalClient*> result = 47 ClientLibrary::GetOrCreateLocalClient(client_options); 48 TF_CHECK_OK(result.status()) << " could not create local client for testing"; 49 return result.ValueOrDie(); 50 } 51 52 // Helper functions to get the reference platform. 53 se::Platform* GetReferencePlatform() { 54 auto result = PlatformUtil::GetPlatform(kInterpreter); 55 TF_CHECK_OK(result.status()) << "could not get interpreter platform"; 56 return result.ValueOrDie(); 57 } 58 59 } // namespace 60 61 ClientLibraryTestBase::ClientLibraryTestBase( 62 se::Platform* platform, const LocalClientOptions& client_options) 63 : client_(GetOrCreateLocalClientOrDie(client_options)), 64 execution_options_(CreateDefaultExecutionOptions()) { 65 CHECK_EQ(platform, client_options.platform()); 66 67 LocalClientOptions ref_options; 68 ref_options.set_platform(GetReferencePlatform()); 69 ref_client_ = GetOrCreateLocalClientOrDie(ref_options); 70 71 // Disabling constant_folding so that tests (usually written using Constants) 72 // will exercise the intended code paths, instead of being constant folded. 73 // 74 // TODO(b/38354253): Constant folding is currently disabled. Change tests to 75 // use Parameters instead of Constants, and re-enable constant folding by 76 // default. 77 execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( 78 "constant_folding"); 79 80 execution_options_.mutable_debug_options() 81 ->set_xla_hlo_evaluator_use_fast_path(true); 82 } 83 84 ClientLibraryTestBase::ClientLibraryTestBase(se::Platform* platform) 85 : execution_options_(CreateDefaultExecutionOptions()) { 86 LocalClientOptions default_options; 87 default_options.set_platform(platform); 88 client_ = GetOrCreateLocalClientOrDie(default_options); 89 90 LocalClientOptions ref_options; 91 ref_options.set_platform(GetReferencePlatform()); 92 ref_client_ = GetOrCreateLocalClientOrDie(ref_options); 93 94 execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( 95 "constant_folding"); 96 97 execution_options_.mutable_debug_options() 98 ->set_xla_hlo_evaluator_use_fast_path(true); 99 } 100 101 string ClientLibraryTestBase::TestName() const { 102 return ::testing::UnitTest::GetInstance()->current_test_info()->name(); 103 } 104 105 StatusOr<std::unique_ptr<GlobalData>> ClientLibraryTestBase::Execute( 106 XlaBuilder* builder, absl::Span<GlobalData* const> arguments) { 107 // Build the computation, as a convenience. 108 TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); 109 return client_->Execute(computation, arguments, &execution_options_); 110 } 111 112 StatusOr<Literal> ClientLibraryTestBase::ExecuteAndTransfer( 113 const XlaComputation& computation, absl::Span<GlobalData* const> arguments, 114 const Shape* shape_with_output_layout) { 115 ExecutionOptions execution_options = execution_options_; 116 if (shape_with_output_layout != nullptr) { 117 *execution_options.mutable_shape_with_output_layout() = 118 shape_with_output_layout->ToProto(); 119 } 120 return client_->ExecuteAndTransfer(computation, arguments, 121 &execution_options); 122 } 123 124 StatusOr<Literal> ClientLibraryTestBase::ExecuteAndTransfer( 125 XlaBuilder* builder, absl::Span<GlobalData* const> arguments, 126 const Shape* shape_with_output_layout) { 127 // Build the computation, as a convenience. 128 TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); 129 return ExecuteAndTransfer(computation, arguments, shape_with_output_layout); 130 } 131 132 StatusOr<Literal> ClientLibraryTestBase::ExecuteAndTransferReference( 133 const XlaComputation& computation, absl::Span<GlobalData* const> arguments, 134 const Shape* shape_with_output_layout) { 135 ExecutionOptions execution_options = execution_options_; 136 if (shape_with_output_layout != nullptr) { 137 *execution_options.mutable_shape_with_output_layout() = 138 shape_with_output_layout->ToProto(); 139 } 140 execution_options.clear_device_handles(); 141 return ref_client_->ExecuteAndTransfer(computation, arguments, 142 &execution_options); 143 } 144 145 string ClientLibraryTestBase::ExecuteToString( 146 XlaBuilder* builder, absl::Span<GlobalData* const> arguments) { 147 auto computation_status = builder->Build(); 148 if (!computation_status.ok()) { 149 return computation_status.status().ToString(); 150 } 151 auto computation = computation_status.ConsumeValueOrDie(); 152 153 auto result = 154 client_->ExecuteAndTransfer(computation, arguments, &execution_options_); 155 if (!result.ok()) { 156 return result.status().ToString(); 157 } else { 158 return result.ValueOrDie().ToString(); 159 } 160 } 161 162 void ClientLibraryTestBase::ComputeAndCompareR1( 163 XlaBuilder* builder, const tensorflow::core::Bitmap& expected, 164 absl::Span<GlobalData* const> arguments) { 165 Literal expected_literal = LiteralUtil::CreateR1(expected); 166 ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, 167 arguments); 168 } 169 170 void ClientLibraryTestBase::ComputeAndCompareLiteral( 171 XlaBuilder* builder, const Literal& expected, 172 absl::Span<GlobalData* const> arguments, const Shape* shape_with_layout) { 173 EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments, 174 shape_with_layout)); 175 } 176 177 void ClientLibraryTestBase::ComputeAndCompareLiteral( 178 XlaBuilder* builder, const Literal& expected, 179 absl::Span<GlobalData* const> arguments, ErrorSpec error, 180 const Shape* shape_with_layout) { 181 EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments, 182 error, shape_with_layout)); 183 } 184 185 Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( 186 const xla::XlaComputation& computation, const Literal& expected, 187 absl::Span<GlobalData* const> arguments, 188 const std::function<void(const Literal& actual, 189 const string& error_message)>& verify_output) { 190 // Try with no layout requirement. 191 TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments)); 192 verify_output(actual, ""); 193 194 // Try with all output layouts. 195 std::vector<int64> minor_to_major(expected.shape().rank()); 196 std::iota(minor_to_major.begin(), minor_to_major.end(), 0); 197 do { 198 auto layout = ShapeUtil::MakeShapeWithLayout( 199 expected.shape().element_type(), 200 AsInt64Slice(expected.shape().dimensions()), minor_to_major); 201 TF_ASSIGN_OR_RETURN(auto actual, 202 ExecuteAndTransfer(computation, arguments, &layout)); 203 verify_output(actual, 204 absl::StrCat("Test with output layout: ", 205 ShapeUtil::HumanStringWithLayout(layout))); 206 } while (std::next_permutation(minor_to_major.begin(), minor_to_major.end())); 207 return Status::OK(); 208 } 209 210 Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( 211 const xla::XlaComputation& computation, const Literal& /*expected*/, 212 absl::Span<GlobalData* const> arguments, 213 const std::function<void(const Literal& actual, 214 const string& error_message)>& verify_output, 215 const Shape* output_with_layout) { 216 std::vector<GlobalData*> arguments_with_layout; 217 std::vector<string> layout_strings; 218 // This is a recursive function. It's an std::function instead of a lambda 219 // because it needs to capture itself. The index is the index of the argument 220 // to try all layouts for. 221 std::function<Status(int64)> choose; 222 choose = [&, this](int64 index) -> Status { 223 if (index < arguments.size()) { 224 // Try out all layouts for the operand. 225 TF_ASSIGN_OR_RETURN(auto literal, 226 client_->Transfer(*arguments[index], nullptr)); 227 // Skip tuples because they don't have a rank. 228 if (literal.shape().IsTuple()) { 229 layout_strings.push_back( 230 ShapeUtil::HumanStringWithLayout(literal.shape())); 231 arguments_with_layout.push_back(arguments[index]); 232 TF_RETURN_IF_ERROR(choose(index + 1)); 233 arguments_with_layout.pop_back(); 234 layout_strings.pop_back(); 235 return Status::OK(); 236 } 237 238 std::vector<int64> minor_to_major(literal.shape().rank()); 239 std::iota(minor_to_major.begin(), minor_to_major.end(), 0); 240 do { 241 auto literal_relayout = 242 literal.Relayout(LayoutUtil::MakeLayout(minor_to_major)); 243 layout_strings.push_back( 244 ShapeUtil::HumanStringWithLayout(literal_relayout.shape())); 245 TF_ASSIGN_OR_RETURN(auto data, 246 client_->TransferToServer(literal_relayout)); 247 arguments_with_layout.push_back(data.get()); 248 TF_RETURN_IF_ERROR(choose(index + 1)); 249 arguments_with_layout.pop_back(); 250 layout_strings.pop_back(); 251 } while ( 252 std::next_permutation(minor_to_major.begin(), minor_to_major.end())); 253 return Status::OK(); 254 } 255 256 // Every argument has an assigned layout. 257 TF_ASSIGN_OR_RETURN( 258 auto actual, 259 ExecuteAndTransfer(computation, 260 absl::Span<GlobalData* const>(arguments_with_layout), 261 output_with_layout)); 262 string error_message = "Test with input layouts: "; 263 for (const auto& str : layout_strings) { 264 absl::StrAppend(&error_message, str, " "); 265 } 266 verify_output(actual, error_message); 267 return Status::OK(); 268 }; 269 270 return choose(0); 271 } 272 273 StatusOr<Literal> ClientLibraryTestBase::ComputeAndTransfer( 274 XlaBuilder* builder, absl::Span<GlobalData* const> arguments_passed_in, 275 const Shape* shape_with_layout) { 276 std::vector<GlobalData*> arguments(arguments_passed_in.begin(), 277 arguments_passed_in.end()); 278 279 // Transfer and use elements of arguments_, if the AddParam() API was used. 280 std::vector<std::unique_ptr<GlobalData>> owning_arguments; 281 if (!arguments_.empty()) { 282 CHECK(arguments.empty()); 283 for (const auto& argument : arguments_) { 284 TF_ASSIGN_OR_RETURN( 285 std::unique_ptr<GlobalData> owned_argument, 286 client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument))); 287 owning_arguments.push_back(std::move(owned_argument)); 288 arguments.push_back(owning_arguments.back().get()); 289 } 290 } 291 292 TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); 293 return ExecuteAndTransfer(computation, arguments, shape_with_layout); 294 } 295 296 Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( 297 XlaBuilder* builder, const Literal& expected, 298 absl::Span<GlobalData* const> arguments_passed_in, 299 const Shape* shape_with_layout) { 300 std::vector<GlobalData*> arguments(arguments_passed_in.begin(), 301 arguments_passed_in.end()); 302 303 // Transfer and use elements of arguments_, if the AddParam() API was used. 304 std::vector<std::unique_ptr<GlobalData>> owning_arguments; 305 if (!arguments_.empty()) { 306 CHECK(arguments.empty()); 307 for (const auto& argument : arguments_) { 308 TF_ASSIGN_OR_RETURN( 309 std::unique_ptr<GlobalData> owned_argument, 310 client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument))); 311 owning_arguments.push_back(std::move(owned_argument)); 312 arguments.push_back(owning_arguments.back().get()); 313 } 314 } 315 316 TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); 317 if (ShapeUtil::ElementIsFloating(expected.shape()) || 318 ShapeUtil::ElementIsComplex(expected.shape())) { 319 LOG(WARNING) << "performing exact comparison of floating point numbers"; 320 } 321 // We allow using a float expected literal for a bfloat16 output. In this 322 // case, we need to convert the expected literal to bfloat16. 323 const Literal* expected_ptr = &expected; 324 Literal converted_expected; 325 Shape layout_shape; 326 if (use_bfloat16_) { 327 converted_expected = LiteralUtil::ConvertF32ToBF16(expected); 328 expected_ptr = &converted_expected; 329 if (shape_with_layout != nullptr) { 330 layout_shape = *shape_with_layout; 331 ShapeUtil::ForEachMutableSubshape( 332 &layout_shape, [&](Shape* subshape, const ShapeIndex& /*index*/) { 333 if (subshape->element_type() == F32) { 334 subshape->set_element_type(BF16); 335 } 336 }); 337 shape_with_layout = &layout_shape; 338 } 339 } 340 auto expect_equal = [&](const Literal& actual, const string& error_message) { 341 EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, actual)) << error_message; 342 }; 343 if (execution_options_.debug_options().xla_test_all_output_layouts()) { 344 return ComputeAndCompareLiteralWithAllOutputLayouts( 345 computation, *expected_ptr, arguments, expect_equal); 346 } 347 if (execution_options_.debug_options().xla_test_all_input_layouts()) { 348 return ComputeAndCompareLiteralWithAllInputLayouts( 349 computation, *expected_ptr, arguments, expect_equal, shape_with_layout); 350 } 351 TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments, 352 shape_with_layout)); 353 EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, actual)); 354 return Status::OK(); 355 } 356 357 Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( 358 XlaBuilder* builder, const Literal& expected, 359 absl::Span<GlobalData* const> arguments_passed_in, ErrorSpec error, 360 const Shape* shape_with_layout) { 361 std::vector<GlobalData*> arguments(arguments_passed_in.begin(), 362 arguments_passed_in.end()); 363 364 // Transfer and use elements of arguments_, if the AddParam() API was used. 365 std::vector<std::unique_ptr<GlobalData>> owning_arguments; 366 if (!arguments_.empty()) { 367 CHECK(arguments.empty()); 368 for (const auto& argument : arguments_) { 369 TF_ASSIGN_OR_RETURN( 370 std::unique_ptr<GlobalData> owned_argument, 371 client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument))); 372 owning_arguments.push_back(std::move(owned_argument)); 373 arguments.push_back(owning_arguments.back().get()); 374 } 375 } 376 377 TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); 378 // We allow using a float expected literal for a bfloat16 output. In this 379 // case, we need to convert the expected literal to bfloat16. 380 const Literal* expected_ptr = &expected; 381 Literal converted_expected; 382 Shape layout_shape; 383 if (use_bfloat16_) { 384 converted_expected = LiteralUtil::ConvertF32ToBF16(expected); 385 expected_ptr = &converted_expected; 386 if (shape_with_layout != nullptr) { 387 layout_shape = *shape_with_layout; 388 ShapeUtil::ForEachMutableSubshape( 389 &layout_shape, [&](Shape* subshape, const ShapeIndex& /*index*/) { 390 if (subshape->element_type() == F32) { 391 subshape->set_element_type(BF16); 392 } 393 }); 394 shape_with_layout = &layout_shape; 395 } 396 } 397 auto expect_near = [&](const Literal& actual, const string& error_message) { 398 EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, actual, error)) 399 << error_message; 400 }; 401 if (execution_options_.debug_options().xla_test_all_output_layouts()) { 402 return ComputeAndCompareLiteralWithAllOutputLayouts( 403 computation, *expected_ptr, arguments, expect_near); 404 } 405 if (execution_options_.debug_options().xla_test_all_input_layouts()) { 406 return ComputeAndCompareLiteralWithAllInputLayouts( 407 computation, *expected_ptr, arguments, expect_near, shape_with_layout); 408 } 409 TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments, 410 shape_with_layout)); 411 EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, actual, error)); 412 return Status::OK(); 413 } 414 415 void ClientLibraryTestBase::ComputeAndCompareR1U8( 416 XlaBuilder* builder, absl::string_view expected, 417 absl::Span<GlobalData* const> arguments) { 418 auto actual_status = ExecuteAndTransfer(builder, arguments); 419 EXPECT_IS_OK(actual_status.status()); 420 if (!actual_status.ok()) { 421 return; 422 } 423 auto actual = actual_status.ConsumeValueOrDie(); 424 425 // Turn the expected value into a literal. 426 Literal expected_literal = LiteralUtil::CreateR1U8(expected); 427 428 VLOG(1) << "expected: " << expected_literal.ToString(); 429 VLOG(1) << "actual: " << actual.ToString(); 430 431 EXPECT_EQ(expected, actual.GetR1U8AsString()); 432 } 433 434 void ClientLibraryTestBase::ComputeAndCompareTuple( 435 XlaBuilder* builder, const Literal& expected, 436 absl::Span<GlobalData* const> arguments) { 437 auto actual_status = ExecuteAndTransfer(builder, arguments); 438 EXPECT_IS_OK(actual_status.status()); 439 if (!actual_status.ok()) { 440 return; 441 } 442 auto actual = actual_status.ConsumeValueOrDie(); 443 EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual)); 444 } 445 446 void ClientLibraryTestBase::ComputeAndCompareTuple( 447 XlaBuilder* builder, const Literal& expected, 448 absl::Span<GlobalData* const> arguments, ErrorSpec error) { 449 auto actual_status = ExecuteAndTransfer(builder, arguments); 450 EXPECT_IS_OK(actual_status.status()); 451 if (!actual_status.ok()) { 452 return; 453 } 454 auto actual = actual_status.ConsumeValueOrDie(); 455 EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, error)); 456 } 457 458 void ClientLibraryTestBase::ComputeAndCompare( 459 XlaBuilder* builder, absl::Span<const Literal> arguments) { 460 auto status_or_data = ComputeValueAndReference(builder, arguments); 461 EXPECT_IS_OK(status_or_data); 462 if (!status_or_data.ok()) { 463 return; 464 } 465 Literal reference, result; 466 std::tie(reference, result) = status_or_data.ConsumeValueOrDie(); 467 EXPECT_TRUE(LiteralTestUtil::Equal(reference, result)); 468 } 469 470 void ClientLibraryTestBase::ComputeAndCompare( 471 XlaBuilder* builder, absl::Span<const Literal> arguments, ErrorSpec error) { 472 auto status_or_data = ComputeValueAndReference(builder, arguments); 473 EXPECT_IS_OK(status_or_data); 474 if (!status_or_data.ok()) { 475 return; 476 } 477 Literal reference, result; 478 std::tie(reference, result) = status_or_data.ConsumeValueOrDie(); 479 EXPECT_TRUE(LiteralTestUtil::Near(reference, result, error)); 480 } 481 482 StatusOr<std::pair<Literal, Literal>> 483 ClientLibraryTestBase::ComputeValueAndReference( 484 XlaBuilder* builder, absl::Span<const Literal> arguments) { 485 // Transfer the arguments to the executor service. We put the unique_ptr's 486 // into a vector to keep the data alive on the service until the end of this 487 // function. 488 std::vector<std::unique_ptr<GlobalData>> argument_data; 489 std::vector<std::unique_ptr<GlobalData>> ref_argument_data; 490 491 // Use `arguments_` if the AddParam() API was used. Otherwise, use 492 // plain `arguments`. 493 if (!arguments_.empty()) { 494 CHECK_EQ(arguments.size(), 0); 495 arguments = arguments_; 496 } 497 498 for (const auto& arg : arguments) { 499 TF_ASSIGN_OR_RETURN(auto data, client_->TransferToServer(arg.Clone())); 500 TF_ASSIGN_OR_RETURN(auto ref_data, ref_client_->TransferToServer(arg)); 501 argument_data.push_back(std::move(data)); 502 ref_argument_data.push_back(std::move(ref_data)); 503 } 504 505 // Create raw pointers to the GlobalData for the rest of the call stack. 506 std::vector<GlobalData*> argument_data_ptr; 507 std::transform( 508 argument_data.begin(), argument_data.end(), 509 std::back_inserter(argument_data_ptr), 510 [](const std::unique_ptr<GlobalData>& data) { return data.get(); }); 511 std::vector<GlobalData*> ref_argument_data_ptr; 512 std::transform( 513 ref_argument_data.begin(), ref_argument_data.end(), 514 std::back_inserter(ref_argument_data_ptr), 515 [](const std::unique_ptr<GlobalData>& data) { return data.get(); }); 516 517 TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); 518 519 TF_ASSIGN_OR_RETURN(auto result, 520 ExecuteAndTransfer(computation, argument_data_ptr)); 521 522 TF_ASSIGN_OR_RETURN(auto reference, ExecuteAndTransferReference( 523 computation, ref_argument_data_ptr)); 524 525 return std::make_pair(std::move(reference), std::move(result)); 526 } 527 528 XlaComputation ClientLibraryTestBase::CreateScalarRelu() { 529 XlaBuilder builder("relu"); 530 auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {}); 531 auto z_value = Parameter(&builder, 0, shape, "z_value"); 532 auto zero = use_bfloat16_ 533 ? ConstantR0<bfloat16>(&builder, static_cast<bfloat16>(0.0f)) 534 : ConstantR0<float>(&builder, 0.0f); 535 Max(z_value, zero); 536 auto computation_status = builder.Build(); 537 TF_CHECK_OK(computation_status.status()); 538 return computation_status.ConsumeValueOrDie(); 539 } 540 541 XlaComputation ClientLibraryTestBase::CreateScalarMax() { 542 XlaBuilder builder("max"); 543 auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {}); 544 auto x = Parameter(&builder, 0, shape, "x"); 545 auto y = Parameter(&builder, 1, shape, "y"); 546 Max(x, y); 547 auto computation_status = builder.Build(); 548 TF_CHECK_OK(computation_status.status()); 549 return computation_status.ConsumeValueOrDie(); 550 } 551 552 XlaComputation ClientLibraryTestBase::CreateScalarReluSensitivity() { 553 XlaBuilder builder("relu_sensitivity"); 554 auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {}); 555 auto activation = Parameter(&builder, 0, shape, "activation"); 556 auto backprop = Parameter(&builder, 1, shape, "backprop"); 557 auto zero = use_bfloat16_ 558 ? ConstantR0<bfloat16>(&builder, static_cast<bfloat16>(0.0f)) 559 : ConstantR0<float>(&builder, 0.0f); 560 auto activation_gtz = Gt(activation, zero); 561 Select(activation_gtz, /*on_true=*/backprop, /*on_false=*/zero); 562 563 auto computation_status = builder.Build(); 564 TF_CHECK_OK(computation_status.status()); 565 return computation_status.ConsumeValueOrDie(); 566 } 567 568 std::unique_ptr<Array2D<float>> ClientLibraryTestBase::CreatePatternedMatrix( 569 int rows, int cols, float offset) { 570 auto array = absl::make_unique<Array2D<float>>(rows, cols); 571 for (int64 row = 0; row < rows; ++row) { 572 for (int64 col = 0; col < cols; ++col) { 573 (*array)(row, col) = col + (row * 1000.0f) + offset; 574 } 575 } 576 return array; 577 } 578 579 std::unique_ptr<Array2D<float>> 580 ClientLibraryTestBase::CreatePatternedMatrixWithZeroPadding(int rows, int cols, 581 int rows_padded, 582 int cols_padded) { 583 CHECK_GE(rows_padded, rows); 584 CHECK_GE(cols_padded, cols); 585 auto array = absl::make_unique<Array2D<float>>(rows_padded, cols_padded, 0.0); 586 for (int64 row = 0; row < rows; ++row) { 587 for (int64 col = 0; col < cols; ++col) { 588 (*array)(row, col) = col + (row * 1000.0f); 589 } 590 } 591 return array; 592 } 593 594 XlaOp ClientLibraryTestBase::AddParam(const Literal& argument, 595 XlaBuilder* builder) { 596 arguments_.push_back(argument.Clone()); 597 return Parameter(builder, /*parameter_number=*/arguments_.size() - 1, 598 MaybeConvertShapeToBfloat16(argument.shape()), ""); 599 } 600 601 XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal, 602 XlaBuilder* builder) { 603 return ConstantLiteral(builder, use_bfloat16_ 604 ? LiteralUtil::ConvertF32ToBF16(literal) 605 : LiteralSlice(literal)); 606 } 607 608 std::unique_ptr<GlobalData> 609 ClientLibraryTestBase::CreateParameterAndTransferLiteral(int64 parameter_number, 610 const Literal& literal, 611 const string& name, 612 XlaBuilder* builder, 613 XlaOp* data_handle) { 614 return CreateParameterAndTransferLiteral(parameter_number, literal, name, 615 nullptr, builder, data_handle); 616 } 617 618 Shape ClientLibraryTestBase::MaybeConvertShapeToBfloat16(const Shape& shape) { 619 if (!use_bfloat16_) { 620 return shape; 621 } 622 Shape new_shape = shape; 623 ShapeUtil::ForEachMutableSubshape(&new_shape, 624 [](Shape* subshape, const ShapeIndex&) { 625 if (subshape->element_type() == F32) { 626 subshape->set_element_type(BF16); 627 } 628 }); 629 return new_shape; 630 } 631 632 Literal ClientLibraryTestBase::MaybeConvertLiteralToBfloat16( 633 const Literal& literal) { 634 if (use_bfloat16_) { 635 return LiteralUtil::ConvertF32ToBF16(literal); 636 } 637 return literal.Clone(); 638 } 639 640 std::unique_ptr<GlobalData> 641 ClientLibraryTestBase::CreateParameterAndTransferLiteral( 642 int64 parameter_number, const Literal& literal, const string& name, 643 const DeviceHandle* device_handle, XlaBuilder* builder, 644 XlaOp* data_handle) { 645 Literal param_literal = MaybeConvertLiteralToBfloat16(literal); 646 std::unique_ptr<GlobalData> data = 647 client_->TransferToServer(param_literal, device_handle) 648 .ConsumeValueOrDie(); 649 *data_handle = 650 Parameter(builder, parameter_number, param_literal.shape(), name); 651 return data; 652 } 653 654 } // namespace xla 655