1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/tests/client_library_test_base.h" 17 18 #include <string> 19 20 #include "tensorflow/compiler/xla/client/client_library.h" 21 #include "tensorflow/compiler/xla/client/computation.h" 22 #include "tensorflow/compiler/xla/client/local_client.h" 23 #include "tensorflow/compiler/xla/execution_options_util.h" 24 #include "tensorflow/compiler/xla/literal_util.h" 25 #include "tensorflow/compiler/xla/ptr_util.h" 26 #include "tensorflow/compiler/xla/shape_util.h" 27 #include "tensorflow/compiler/xla/status_macros.h" 28 #include "tensorflow/compiler/xla/statusor.h" 29 #include "tensorflow/compiler/xla/test_helpers.h" 30 #include "tensorflow/core/lib/strings/str_util.h" 31 #include "tensorflow/core/platform/logging.h" 32 #include "tensorflow/core/platform/types.h" 33 34 namespace se = ::perftools::gputools; 35 36 namespace xla { 37 namespace { 38 // Wrapper function that creates a nicer error message (than a bare 39 // ValueOrDie()) if the platform we intend to test is not available. 40 Client* GetOrCreateLocalClientOrDie(const LocalClientOptions& client_options) { 41 StatusOr<Client*> result = 42 ClientLibrary::GetOrCreateLocalClient(client_options); 43 TF_CHECK_OK(result.status()) << " could not create local client for testing"; 44 return result.ValueOrDie(); 45 } 46 } // namespace 47 48 ClientLibraryTestBase::ClientLibraryTestBase( 49 perftools::gputools::Platform* platform, 50 const LocalClientOptions& client_options) 51 : client_(GetOrCreateLocalClientOrDie(client_options)), 52 execution_options_(CreateDefaultExecutionOptions()) { 53 CHECK_EQ(platform, client_options.platform()); 54 // Disabling constant_folding so that tests (usually written using Constants) 55 // will exercise the intended code paths, instead of being constant folded. 56 // 57 // TODO(b/38354253): Constant folding is currently disabled. Change tests to 58 // use Parameters instead of Constants, and re-enable constant folding by 59 // default. 60 execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( 61 "constant_folding"); 62 } 63 64 ClientLibraryTestBase::ClientLibraryTestBase(se::Platform* platform) 65 : execution_options_(CreateDefaultExecutionOptions()) { 66 LocalClientOptions default_options; 67 default_options.set_platform(platform); 68 client_ = GetOrCreateLocalClientOrDie(default_options); 69 execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( 70 "constant_folding"); 71 } 72 73 string ClientLibraryTestBase::TestName() const { 74 return ::testing::UnitTest::GetInstance()->current_test_info()->name(); 75 } 76 77 StatusOr<std::unique_ptr<GlobalData>> ClientLibraryTestBase::Execute( 78 ComputationBuilder* builder, 79 tensorflow::gtl::ArraySlice<GlobalData*> arguments) { 80 // Build the computation, as a convenience. 81 TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); 82 return client_->Execute(computation, arguments, &execution_options_); 83 } 84 85 StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer( 86 const Computation& computation, 87 tensorflow::gtl::ArraySlice<GlobalData*> arguments, 88 const Shape* shape_with_output_layout) { 89 ExecutionOptions execution_options = execution_options_; 90 if (shape_with_output_layout != nullptr) { 91 *execution_options.mutable_shape_with_output_layout() = 92 *shape_with_output_layout; 93 } 94 return client_->ExecuteAndTransfer(computation, arguments, 95 &execution_options); 96 } 97 98 StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer( 99 ComputationBuilder* builder, 100 tensorflow::gtl::ArraySlice<GlobalData*> arguments, 101 const Shape* shape_with_output_layout) { 102 // Build the computation, as a convenience. 103 TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); 104 return ExecuteAndTransfer(computation, arguments, shape_with_output_layout); 105 } 106 107 std::unique_ptr<GlobalData> ClientLibraryTestBase::ExecuteOrDie( 108 ComputationBuilder* builder, 109 tensorflow::gtl::ArraySlice<GlobalData*> arguments) { 110 return Execute(builder, arguments).ConsumeValueOrDie(); 111 } 112 113 std::unique_ptr<Literal> ClientLibraryTestBase::ExecuteAndTransferOrDie( 114 ComputationBuilder* builder, 115 tensorflow::gtl::ArraySlice<GlobalData*> arguments) { 116 return ExecuteAndTransfer(builder, arguments).ConsumeValueOrDie(); 117 } 118 119 string ClientLibraryTestBase::ExecuteToString( 120 ComputationBuilder* builder, 121 tensorflow::gtl::ArraySlice<GlobalData*> arguments) { 122 StatusOr<Computation> computation_status = builder->Build(); 123 if (!computation_status.ok()) { 124 return computation_status.status().ToString(); 125 } 126 Computation computation = computation_status.ConsumeValueOrDie(); 127 128 auto result = 129 client_->ExecuteAndTransfer(computation, arguments, &execution_options_); 130 if (!result.ok()) { 131 return result.status().ToString(); 132 } else { 133 return result.ValueOrDie()->ToString(); 134 } 135 } 136 137 void ClientLibraryTestBase::ComputeAndCompareR1( 138 ComputationBuilder* builder, const tensorflow::core::Bitmap& expected, 139 tensorflow::gtl::ArraySlice<GlobalData*> arguments) { 140 std::unique_ptr<Literal> expected_literal = Literal::CreateR1(expected); 141 ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, 142 arguments); 143 } 144 145 void ClientLibraryTestBase::ComputeAndCompareLiteral( 146 ComputationBuilder* builder, const Literal& expected, 147 tensorflow::gtl::ArraySlice<GlobalData*> arguments, 148 const Shape* shape_with_layout) { 149 EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments, 150 shape_with_layout)); 151 } 152 153 void ClientLibraryTestBase::ComputeAndCompareLiteral( 154 ComputationBuilder* builder, const Literal& expected, 155 tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error, 156 const Shape* shape_with_layout) { 157 EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments, 158 error, shape_with_layout)); 159 } 160 161 tensorflow::Status 162 ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( 163 const xla::Computation& computation, const Literal& expected, 164 tensorflow::gtl::ArraySlice<GlobalData*> arguments, 165 const std::function<void(const Literal& actual, 166 const string& error_message)>& verify_output) { 167 // Try with no layout requirement. 168 TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments)); 169 verify_output(*actual, ""); 170 171 // Try with all output layouts. 172 std::vector<int64> minor_to_major(ShapeUtil::Rank(expected.shape())); 173 std::iota(minor_to_major.begin(), minor_to_major.end(), 0); 174 do { 175 auto layout = ShapeUtil::MakeShapeWithLayout( 176 expected.shape().element_type(), 177 AsInt64Slice(expected.shape().dimensions()), minor_to_major); 178 TF_ASSIGN_OR_RETURN(auto actual, 179 ExecuteAndTransfer(computation, arguments, &layout)); 180 verify_output(*actual, tensorflow::strings::StrCat( 181 "Test with output layout: ", 182 ShapeUtil::HumanStringWithLayout(layout))); 183 } while (std::next_permutation(minor_to_major.begin(), minor_to_major.end())); 184 return tensorflow::Status::OK(); 185 } 186 187 tensorflow::Status 188 ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( 189 const xla::Computation& computation, const Literal& expected, 190 tensorflow::gtl::ArraySlice<GlobalData*> arguments, 191 const std::function<void(const Literal& actual, 192 const string& error_message)>& verify_output, 193 const Shape* output_with_layout) { 194 std::vector<GlobalData*> arguments_with_layout; 195 std::vector<string> layout_strings; 196 // This is a recursive function. It's an std::function instead of a lambda 197 // because it needs to capture itself. The index is the index of the argument 198 // to try all layouts for. 199 std::function<tensorflow::Status(int64)> choose; 200 choose = [&, this](int64 index) -> tensorflow::Status { 201 if (index < arguments.size()) { 202 // Try out all layouts for the operand. 203 TF_ASSIGN_OR_RETURN(auto literal, 204 client_->Transfer(*arguments[index], nullptr)); 205 // Skip tuples because they don't have a rank. 206 if (ShapeUtil::IsTuple(literal->shape())) { 207 layout_strings.push_back( 208 ShapeUtil::HumanStringWithLayout(literal->shape())); 209 arguments_with_layout.push_back(arguments[index]); 210 TF_RETURN_IF_ERROR(choose(index + 1)); 211 arguments_with_layout.pop_back(); 212 layout_strings.pop_back(); 213 return tensorflow::Status::OK(); 214 } 215 216 std::vector<int64> minor_to_major(ShapeUtil::Rank(literal->shape())); 217 std::iota(minor_to_major.begin(), minor_to_major.end(), 0); 218 do { 219 auto literal_relayout = 220 literal->Relayout(LayoutUtil::MakeLayout(minor_to_major)); 221 layout_strings.push_back( 222 ShapeUtil::HumanStringWithLayout(literal_relayout->shape())); 223 TF_ASSIGN_OR_RETURN(auto data, 224 client_->TransferToServer(*literal_relayout)); 225 arguments_with_layout.push_back(data.get()); 226 TF_RETURN_IF_ERROR(choose(index + 1)); 227 arguments_with_layout.pop_back(); 228 layout_strings.pop_back(); 229 } while ( 230 std::next_permutation(minor_to_major.begin(), minor_to_major.end())); 231 return tensorflow::Status::OK(); 232 } 233 234 // Every argument has an assigned layout. 235 TF_ASSIGN_OR_RETURN( 236 auto actual, 237 ExecuteAndTransfer( 238 computation, 239 tensorflow::gtl::ArraySlice<GlobalData*>(arguments_with_layout), 240 output_with_layout)); 241 string error_message = "Test with input layouts: "; 242 for (const auto& str : layout_strings) { 243 tensorflow::strings::StrAppend(&error_message, str, " "); 244 } 245 verify_output(*actual, error_message); 246 return tensorflow::Status::OK(); 247 }; 248 249 return choose(0); 250 } 251 252 tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( 253 ComputationBuilder* builder, const Literal& expected, 254 tensorflow::gtl::ArraySlice<GlobalData*> arguments_passed_in, 255 const Shape* shape_with_layout) { 256 std::vector<GlobalData*> arguments(arguments_passed_in.begin(), 257 arguments_passed_in.end()); 258 if (!arguments_.empty()) { 259 CHECK(arguments.empty()); 260 for (const auto& argument : arguments_) { 261 arguments.push_back(argument.get()); 262 } 263 } 264 265 TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); 266 if (ShapeUtil::ElementIsFloating(expected.shape()) || 267 ShapeUtil::ElementIsComplex(expected.shape())) { 268 LOG(WARNING) << "performing exact comparison of floating point numbers"; 269 } else { 270 TF_RET_CHECK(ShapeUtil::ElementIsIntegral(expected.shape()) || 271 expected.shape().element_type() == PRED) 272 << ShapeUtil::HumanString(expected.shape()); 273 } 274 // We allow using a float expected literal for a bfloat16 output. In this 275 // case, we need to convert the expected literal to bfloat16. 276 const Literal* expected_ptr = &expected; 277 std::unique_ptr<Literal> converted_expected; 278 Shape layout_shape; 279 if (use_bfloat16_) { 280 converted_expected = LiteralTestUtil::ConvertF32ToBF16(expected); 281 expected_ptr = converted_expected.get(); 282 if (shape_with_layout != nullptr) { 283 layout_shape = *shape_with_layout; 284 ShapeUtil::ForEachMutableSubshape( 285 &layout_shape, [&](Shape* subshape, const ShapeIndex& /*index*/) { 286 if (subshape->element_type() == F32) { 287 subshape->set_element_type(BF16); 288 } 289 }); 290 shape_with_layout = &layout_shape; 291 } 292 } 293 auto expect_equal = [&](const Literal& actual, const string& error_message) { 294 LiteralTestUtil::ExpectEqual(*expected_ptr, actual, error_message); 295 }; 296 if (execution_options_.debug_options().xla_test_all_output_layouts()) { 297 return ComputeAndCompareLiteralWithAllOutputLayouts( 298 computation, *expected_ptr, arguments, expect_equal); 299 } 300 if (execution_options_.debug_options().xla_test_all_input_layouts()) { 301 return ComputeAndCompareLiteralWithAllInputLayouts( 302 computation, *expected_ptr, arguments, expect_equal, shape_with_layout); 303 } 304 TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments, 305 shape_with_layout)); 306 LiteralTestUtil::ExpectEqual(*expected_ptr, *actual); 307 return tensorflow::Status::OK(); 308 } 309 310 tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( 311 ComputationBuilder* builder, const Literal& expected, 312 tensorflow::gtl::ArraySlice<GlobalData*> arguments_passed_in, 313 ErrorSpec error, const Shape* shape_with_layout) { 314 std::vector<GlobalData*> arguments(arguments_passed_in.begin(), 315 arguments_passed_in.end()); 316 if (!arguments_.empty()) { 317 CHECK(arguments.empty()); 318 for (const auto& argument : arguments_) { 319 arguments.push_back(argument.get()); 320 } 321 } 322 323 TF_RET_CHECK(ShapeUtil::ElementIsFloating(expected.shape()) || 324 ShapeUtil::ElementIsComplex(expected.shape())); 325 TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); 326 // We allow using a float expected literal for a bfloat16 output. In this 327 // case, we need to convert the expected literal to bfloat16. 328 const Literal* expected_ptr = &expected; 329 std::unique_ptr<Literal> converted_expected; 330 Shape layout_shape; 331 if (use_bfloat16_) { 332 converted_expected = LiteralTestUtil::ConvertF32ToBF16(expected); 333 expected_ptr = converted_expected.get(); 334 if (shape_with_layout != nullptr) { 335 layout_shape = *shape_with_layout; 336 ShapeUtil::ForEachMutableSubshape( 337 &layout_shape, [&](Shape* subshape, const ShapeIndex& /*index*/) { 338 if (subshape->element_type() == F32) { 339 subshape->set_element_type(BF16); 340 } 341 }); 342 shape_with_layout = &layout_shape; 343 } 344 } 345 auto expect_near = [&](const Literal& actual, const string& error_message) { 346 LiteralTestUtil::ExpectNear(*expected_ptr, actual, error, error_message); 347 }; 348 if (execution_options_.debug_options().xla_test_all_output_layouts()) { 349 return ComputeAndCompareLiteralWithAllOutputLayouts( 350 computation, *expected_ptr, arguments, expect_near); 351 } 352 if (execution_options_.debug_options().xla_test_all_input_layouts()) { 353 return ComputeAndCompareLiteralWithAllInputLayouts( 354 computation, *expected_ptr, arguments, expect_near, shape_with_layout); 355 } 356 TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments, 357 shape_with_layout)); 358 LiteralTestUtil::ExpectNear(*expected_ptr, *actual, error); 359 return tensorflow::Status::OK(); 360 } 361 362 void ClientLibraryTestBase::ComputeAndCompareR1U8( 363 ComputationBuilder* builder, tensorflow::StringPiece expected, 364 tensorflow::gtl::ArraySlice<GlobalData*> arguments) { 365 auto actual_status = ExecuteAndTransfer(builder, arguments); 366 EXPECT_IS_OK(actual_status.status()); 367 if (!actual_status.ok()) { 368 return; 369 } 370 auto actual = actual_status.ConsumeValueOrDie(); 371 372 // Turn the expected value into a literal. 373 std::unique_ptr<Literal> expected_literal = Literal::CreateR1U8(expected); 374 375 VLOG(1) << "expected: " << expected_literal->ToString(); 376 VLOG(1) << "actual: " << actual->ToString(); 377 378 EXPECT_EQ(expected, actual->GetR1U8AsString()); 379 } 380 381 void ClientLibraryTestBase::ComputeAndCompareTuple( 382 ComputationBuilder* builder, const Literal& expected, 383 tensorflow::gtl::ArraySlice<GlobalData*> arguments) { 384 auto actual_status = ExecuteAndTransfer(builder, arguments); 385 EXPECT_IS_OK(actual_status.status()); 386 if (!actual_status.ok()) { 387 return; 388 } 389 auto actual = actual_status.ConsumeValueOrDie(); 390 LiteralTestUtil::ExpectEqual(expected, *actual); 391 } 392 393 void ClientLibraryTestBase::ComputeAndCompareTuple( 394 ComputationBuilder* builder, const Literal& expected, 395 tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) { 396 auto actual_status = ExecuteAndTransfer(builder, arguments); 397 EXPECT_IS_OK(actual_status.status()); 398 if (!actual_status.ok()) { 399 return; 400 } 401 auto actual = actual_status.ConsumeValueOrDie(); 402 LiteralTestUtil::ExpectNear(expected, *actual, error); 403 } 404 405 void ClientLibraryTestBase::ComputeAndCompare( 406 ComputationBuilder* builder, const ComputationDataHandle& operand, 407 tensorflow::gtl::ArraySlice<Literal> arguments) { 408 auto status_or_data = ComputeValueAndReference(builder, operand, arguments); 409 EXPECT_IS_OK(status_or_data); 410 if (!status_or_data.ok()) { 411 return; 412 } 413 std::unique_ptr<Literal> reference, result; 414 std::tie(reference, result) = status_or_data.ConsumeValueOrDie(); 415 LiteralTestUtil::ExpectEqual(*reference, *result); 416 } 417 418 void ClientLibraryTestBase::ComputeAndCompare( 419 ComputationBuilder* builder, const ComputationDataHandle& operand, 420 tensorflow::gtl::ArraySlice<Literal> arguments, ErrorSpec error) { 421 auto status_or_data = ComputeValueAndReference(builder, operand, arguments); 422 EXPECT_IS_OK(status_or_data); 423 if (!status_or_data.ok()) { 424 return; 425 } 426 std::unique_ptr<Literal> reference, result; 427 std::tie(reference, result) = status_or_data.ConsumeValueOrDie(); 428 LiteralTestUtil::ExpectNear(*reference, *result, error); 429 } 430 431 StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>> 432 ClientLibraryTestBase::ComputeValueAndReference( 433 ComputationBuilder* builder, const ComputationDataHandle& operand, 434 tensorflow::gtl::ArraySlice<Literal> arguments) { 435 // Transfer the arguments to the executor service. We put the unique_ptr's 436 // into a vector to keep the data alive on the service until the end of this 437 // function. 438 std::vector<std::unique_ptr<GlobalData>> argument_data; 439 for (const auto& arg : arguments) { 440 TF_ASSIGN_OR_RETURN(auto data, client_->TransferToServer(arg)); 441 argument_data.push_back(std::move(data)); 442 } 443 444 // Create raw pointers to the GlobalData for the rest of the call stack. 445 std::vector<GlobalData*> argument_data_ptr; 446 std::transform( 447 argument_data.begin(), argument_data.end(), 448 std::back_inserter(argument_data_ptr), 449 [](const std::unique_ptr<GlobalData>& data) { return data.get(); }); 450 451 TF_ASSIGN_OR_RETURN( 452 auto reference, 453 builder->ComputeConstant(operand, /*output_layout=*/nullptr, arguments)); 454 TF_ASSIGN_OR_RETURN(auto result, 455 ExecuteAndTransfer(builder, argument_data_ptr)); 456 return std::make_pair(std::move(reference), std::move(result)); 457 } 458 459 Computation ClientLibraryTestBase::CreateScalarRelu() { 460 ComputationBuilder builder(client_, "relu"); 461 auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {}); 462 auto z_value = builder.Parameter(0, shape, "z_value"); 463 auto zero = use_bfloat16_ 464 ? builder.ConstantR0<bfloat16>(static_cast<bfloat16>(0.0f)) 465 : builder.ConstantR0<float>(0.0f); 466 builder.Max(z_value, zero); 467 auto computation_status = builder.Build(); 468 TF_CHECK_OK(computation_status.status()); 469 return computation_status.ConsumeValueOrDie(); 470 } 471 472 Computation ClientLibraryTestBase::CreateScalarMax() { 473 ComputationBuilder builder(client_, "max"); 474 auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {}); 475 auto x = builder.Parameter(0, shape, "x"); 476 auto y = builder.Parameter(1, shape, "y"); 477 builder.Max(x, y); 478 auto computation_status = builder.Build(); 479 TF_CHECK_OK(computation_status.status()); 480 return computation_status.ConsumeValueOrDie(); 481 } 482 483 Computation ClientLibraryTestBase::CreateScalarReluSensitivity() { 484 ComputationBuilder builder(client_, "relu_sensitivity"); 485 auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {}); 486 auto activation = builder.Parameter(0, shape, "activation"); 487 auto backprop = builder.Parameter(1, shape, "backprop"); 488 auto zero = use_bfloat16_ 489 ? builder.ConstantR0<bfloat16>(static_cast<bfloat16>(0.0f)) 490 : builder.ConstantR0<float>(0.0f); 491 auto activation_gtz = builder.Gt(activation, zero); 492 builder.Select(activation_gtz, /*on_true=*/backprop, /*on_false=*/zero); 493 494 auto computation_status = builder.Build(); 495 TF_CHECK_OK(computation_status.status()); 496 return computation_status.ConsumeValueOrDie(); 497 } 498 499 std::unique_ptr<Array2D<float>> ClientLibraryTestBase::CreatePatternedMatrix( 500 int rows, int cols, float offset) { 501 auto array = MakeUnique<Array2D<float>>(rows, cols); 502 for (int64 row = 0; row < rows; ++row) { 503 for (int64 col = 0; col < cols; ++col) { 504 (*array)(row, col) = col + (row * 1000.0f) + offset; 505 } 506 } 507 return array; 508 } 509 510 std::unique_ptr<Array2D<float>> 511 ClientLibraryTestBase::CreatePatternedMatrixWithZeroPadding(int rows, int cols, 512 int rows_padded, 513 int cols_padded) { 514 CHECK_GE(rows_padded, rows); 515 CHECK_GE(cols_padded, cols); 516 auto array = MakeUnique<Array2D<float>>(rows_padded, cols_padded, 0.0); 517 for (int64 row = 0; row < rows; ++row) { 518 for (int64 col = 0; col < cols; ++col) { 519 (*array)(row, col) = col + (row * 1000.0f); 520 } 521 } 522 return array; 523 } 524 525 std::unique_ptr<GlobalData> 526 ClientLibraryTestBase::CreateParameterAndTransferLiteral( 527 int64 parameter_number, const Literal& literal, const string& name, 528 ComputationBuilder* builder, ComputationDataHandle* data_handle) { 529 return CreateParameterAndTransferLiteral(parameter_number, literal, name, 530 nullptr, builder, data_handle); 531 } 532 533 std::unique_ptr<GlobalData> 534 ClientLibraryTestBase::CreateParameterAndTransferLiteral( 535 int64 parameter_number, const Literal& literal, const string& name, 536 const DeviceHandle* device_handle, ComputationBuilder* builder, 537 ComputationDataHandle* data_handle) { 538 const Literal* param_literal = &literal; 539 std::unique_ptr<Literal> converted_literal; 540 if (use_bfloat16_) { 541 converted_literal = LiteralTestUtil::ConvertF32ToBF16(literal); 542 param_literal = converted_literal.get(); 543 } 544 std::unique_ptr<GlobalData> data = 545 client_->TransferToServer(*param_literal, device_handle) 546 .ConsumeValueOrDie(); 547 *data_handle = 548 builder->Parameter(parameter_number, param_literal->shape(), name); 549 return data; 550 } 551 552 ComputationDataHandle ClientLibraryTestBase::AddParam( 553 const Literal& argument, ComputationBuilder* builder) { 554 ComputationDataHandle data_handle; 555 arguments_.push_back(CreateParameterAndTransferLiteral( 556 arguments_.size(), argument, "", builder, &data_handle)); 557 return data_handle; 558 } 559 560 ComputationDataHandle ClientLibraryTestBase::CreateConstantFromLiteral( 561 const Literal& literal, ComputationBuilder* builder) { 562 return builder->ConstantLiteral( 563 use_bfloat16_ ? *LiteralTestUtil::ConvertF32ToBF16(literal) : literal); 564 } 565 566 } // namespace xla 567