Home | History | Annotate | Download | only in tests
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
     17 
     18 #include <string>
     19 
     20 #include "tensorflow/compiler/xla/client/client_library.h"
     21 #include "tensorflow/compiler/xla/client/computation.h"
     22 #include "tensorflow/compiler/xla/client/local_client.h"
     23 #include "tensorflow/compiler/xla/execution_options_util.h"
     24 #include "tensorflow/compiler/xla/literal_util.h"
     25 #include "tensorflow/compiler/xla/ptr_util.h"
     26 #include "tensorflow/compiler/xla/shape_util.h"
     27 #include "tensorflow/compiler/xla/status_macros.h"
     28 #include "tensorflow/compiler/xla/statusor.h"
     29 #include "tensorflow/compiler/xla/test_helpers.h"
     30 #include "tensorflow/core/lib/strings/str_util.h"
     31 #include "tensorflow/core/platform/logging.h"
     32 #include "tensorflow/core/platform/types.h"
     33 
     34 namespace se = ::perftools::gputools;
     35 
     36 namespace xla {
     37 namespace {
     38 // Wrapper function that creates a nicer error message (than a bare
     39 // ValueOrDie()) if the platform we intend to test is not available.
     40 Client* GetOrCreateLocalClientOrDie(const LocalClientOptions& client_options) {
     41   StatusOr<Client*> result =
     42       ClientLibrary::GetOrCreateLocalClient(client_options);
     43   TF_CHECK_OK(result.status()) << " could not create local client for testing";
     44   return result.ValueOrDie();
     45 }
     46 }  // namespace
     47 
     48 ClientLibraryTestBase::ClientLibraryTestBase(
     49     perftools::gputools::Platform* platform,
     50     const LocalClientOptions& client_options)
     51     : client_(GetOrCreateLocalClientOrDie(client_options)),
     52       execution_options_(CreateDefaultExecutionOptions()) {
     53   CHECK_EQ(platform, client_options.platform());
     54   // Disabling constant_folding so that tests (usually written using Constants)
     55   // will exercise the intended code paths, instead of being constant folded.
     56   //
     57   // TODO(b/38354253): Constant folding is currently disabled. Change tests to
     58   // use Parameters instead of Constants, and re-enable constant folding by
     59   // default.
     60   execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(
     61       "constant_folding");
     62 }
     63 
     64 ClientLibraryTestBase::ClientLibraryTestBase(se::Platform* platform)
     65     : execution_options_(CreateDefaultExecutionOptions()) {
     66   LocalClientOptions default_options;
     67   default_options.set_platform(platform);
     68   client_ = GetOrCreateLocalClientOrDie(default_options);
     69   execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(
     70       "constant_folding");
     71 }
     72 
     73 string ClientLibraryTestBase::TestName() const {
     74   return ::testing::UnitTest::GetInstance()->current_test_info()->name();
     75 }
     76 
     77 StatusOr<std::unique_ptr<GlobalData>> ClientLibraryTestBase::Execute(
     78     ComputationBuilder* builder,
     79     tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
     80   // Build the computation, as a convenience.
     81   TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
     82   return client_->Execute(computation, arguments, &execution_options_);
     83 }
     84 
     85 StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
     86     const Computation& computation,
     87     tensorflow::gtl::ArraySlice<GlobalData*> arguments,
     88     const Shape* shape_with_output_layout) {
     89   ExecutionOptions execution_options = execution_options_;
     90   if (shape_with_output_layout != nullptr) {
     91     *execution_options.mutable_shape_with_output_layout() =
     92         *shape_with_output_layout;
     93   }
     94   return client_->ExecuteAndTransfer(computation, arguments,
     95                                      &execution_options);
     96 }
     97 
     98 StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
     99     ComputationBuilder* builder,
    100     tensorflow::gtl::ArraySlice<GlobalData*> arguments,
    101     const Shape* shape_with_output_layout) {
    102   // Build the computation, as a convenience.
    103   TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
    104   return ExecuteAndTransfer(computation, arguments, shape_with_output_layout);
    105 }
    106 
    107 std::unique_ptr<GlobalData> ClientLibraryTestBase::ExecuteOrDie(
    108     ComputationBuilder* builder,
    109     tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
    110   return Execute(builder, arguments).ConsumeValueOrDie();
    111 }
    112 
    113 std::unique_ptr<Literal> ClientLibraryTestBase::ExecuteAndTransferOrDie(
    114     ComputationBuilder* builder,
    115     tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
    116   return ExecuteAndTransfer(builder, arguments).ConsumeValueOrDie();
    117 }
    118 
    119 string ClientLibraryTestBase::ExecuteToString(
    120     ComputationBuilder* builder,
    121     tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
    122   StatusOr<Computation> computation_status = builder->Build();
    123   if (!computation_status.ok()) {
    124     return computation_status.status().ToString();
    125   }
    126   Computation computation = computation_status.ConsumeValueOrDie();
    127 
    128   auto result =
    129       client_->ExecuteAndTransfer(computation, arguments, &execution_options_);
    130   if (!result.ok()) {
    131     return result.status().ToString();
    132   } else {
    133     return result.ValueOrDie()->ToString();
    134   }
    135 }
    136 
    137 void ClientLibraryTestBase::ComputeAndCompareR1(
    138     ComputationBuilder* builder, const tensorflow::core::Bitmap& expected,
    139     tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
    140   std::unique_ptr<Literal> expected_literal = Literal::CreateR1(expected);
    141   ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
    142                                                   arguments);
    143 }
    144 
    145 void ClientLibraryTestBase::ComputeAndCompareLiteral(
    146     ComputationBuilder* builder, const Literal& expected,
    147     tensorflow::gtl::ArraySlice<GlobalData*> arguments,
    148     const Shape* shape_with_layout) {
    149   EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments,
    150                                                   shape_with_layout));
    151 }
    152 
    153 void ClientLibraryTestBase::ComputeAndCompareLiteral(
    154     ComputationBuilder* builder, const Literal& expected,
    155     tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error,
    156     const Shape* shape_with_layout) {
    157   EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments,
    158                                                   error, shape_with_layout));
    159 }
    160 
    161 tensorflow::Status
    162 ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts(
    163     const xla::Computation& computation, const Literal& expected,
    164     tensorflow::gtl::ArraySlice<GlobalData*> arguments,
    165     const std::function<void(const Literal& actual,
    166                              const string& error_message)>& verify_output) {
    167   // Try with no layout requirement.
    168   TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments));
    169   verify_output(*actual, "");
    170 
    171   // Try with all output layouts.
    172   std::vector<int64> minor_to_major(ShapeUtil::Rank(expected.shape()));
    173   std::iota(minor_to_major.begin(), minor_to_major.end(), 0);
    174   do {
    175     auto layout = ShapeUtil::MakeShapeWithLayout(
    176         expected.shape().element_type(),
    177         AsInt64Slice(expected.shape().dimensions()), minor_to_major);
    178     TF_ASSIGN_OR_RETURN(auto actual,
    179                         ExecuteAndTransfer(computation, arguments, &layout));
    180     verify_output(*actual, tensorflow::strings::StrCat(
    181                                "Test with output layout: ",
    182                                ShapeUtil::HumanStringWithLayout(layout)));
    183   } while (std::next_permutation(minor_to_major.begin(), minor_to_major.end()));
    184   return tensorflow::Status::OK();
    185 }
    186 
    187 tensorflow::Status
    188 ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts(
    189     const xla::Computation& computation, const Literal& expected,
    190     tensorflow::gtl::ArraySlice<GlobalData*> arguments,
    191     const std::function<void(const Literal& actual,
    192                              const string& error_message)>& verify_output,
    193     const Shape* output_with_layout) {
    194   std::vector<GlobalData*> arguments_with_layout;
    195   std::vector<string> layout_strings;
    196   // This is a recursive function. It's an std::function instead of a lambda
    197   // because it needs to capture itself. The index is the index of the argument
    198   // to try all layouts for.
    199   std::function<tensorflow::Status(int64)> choose;
    200   choose = [&, this](int64 index) -> tensorflow::Status {
    201     if (index < arguments.size()) {
    202       // Try out all layouts for the operand.
    203       TF_ASSIGN_OR_RETURN(auto literal,
    204                           client_->Transfer(*arguments[index], nullptr));
    205       // Skip tuples because they don't have a rank.
    206       if (ShapeUtil::IsTuple(literal->shape())) {
    207         layout_strings.push_back(
    208             ShapeUtil::HumanStringWithLayout(literal->shape()));
    209         arguments_with_layout.push_back(arguments[index]);
    210         TF_RETURN_IF_ERROR(choose(index + 1));
    211         arguments_with_layout.pop_back();
    212         layout_strings.pop_back();
    213         return tensorflow::Status::OK();
    214       }
    215 
    216       std::vector<int64> minor_to_major(ShapeUtil::Rank(literal->shape()));
    217       std::iota(minor_to_major.begin(), minor_to_major.end(), 0);
    218       do {
    219         auto literal_relayout =
    220             literal->Relayout(LayoutUtil::MakeLayout(minor_to_major));
    221         layout_strings.push_back(
    222             ShapeUtil::HumanStringWithLayout(literal_relayout->shape()));
    223         TF_ASSIGN_OR_RETURN(auto data,
    224                             client_->TransferToServer(*literal_relayout));
    225         arguments_with_layout.push_back(data.get());
    226         TF_RETURN_IF_ERROR(choose(index + 1));
    227         arguments_with_layout.pop_back();
    228         layout_strings.pop_back();
    229       } while (
    230           std::next_permutation(minor_to_major.begin(), minor_to_major.end()));
    231       return tensorflow::Status::OK();
    232     }
    233 
    234     // Every argument has an assigned layout.
    235     TF_ASSIGN_OR_RETURN(
    236         auto actual,
    237         ExecuteAndTransfer(
    238             computation,
    239             tensorflow::gtl::ArraySlice<GlobalData*>(arguments_with_layout),
    240             output_with_layout));
    241     string error_message = "Test with input layouts: ";
    242     for (const auto& str : layout_strings) {
    243       tensorflow::strings::StrAppend(&error_message, str, " ");
    244     }
    245     verify_output(*actual, error_message);
    246     return tensorflow::Status::OK();
    247   };
    248 
    249   return choose(0);
    250 }
    251 
// Builds and runs the computation, then compares the result against
// `expected` for EXACT equality. Honors the xla_test_all_output_layouts /
// xla_test_all_input_layouts debug flags by sweeping layout permutations.
tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
    ComputationBuilder* builder, const Literal& expected,
    tensorflow::gtl::ArraySlice<GlobalData*> arguments_passed_in,
    const Shape* shape_with_layout) {
  std::vector<GlobalData*> arguments(arguments_passed_in.begin(),
                                     arguments_passed_in.end());
  // Arguments registered via AddParam() (stored in arguments_) take over;
  // mixing them with explicitly passed arguments is a caller error (CHECK).
  if (!arguments_.empty()) {
    CHECK(arguments.empty());
    for (const auto& argument : arguments_) {
      arguments.push_back(argument.get());
    }
  }

  TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
  // Exact comparison is almost never what you want for floating-point or
  // complex results; warn rather than fail. Otherwise require an integral or
  // PRED element type.
  if (ShapeUtil::ElementIsFloating(expected.shape()) ||
      ShapeUtil::ElementIsComplex(expected.shape())) {
    LOG(WARNING) << "performing exact comparison of floating point numbers";
  } else {
    TF_RET_CHECK(ShapeUtil::ElementIsIntegral(expected.shape()) ||
                 expected.shape().element_type() == PRED)
        << ShapeUtil::HumanString(expected.shape());
  }
  // We allow using a float expected literal for a bfloat16 output. In this
  // case, we need to convert the expected literal to bfloat16.
  const Literal* expected_ptr = &expected;
  std::unique_ptr<Literal> converted_expected;
  Shape layout_shape;
  if (use_bfloat16_) {
    converted_expected = LiteralTestUtil::ConvertF32ToBF16(expected);
    expected_ptr = converted_expected.get();
    if (shape_with_layout != nullptr) {
      // Rewrite F32 subshapes of the requested output layout to BF16 so the
      // layout still matches the converted output; layout_shape (declared
      // above) keeps the rewritten copy alive for the rest of this function.
      layout_shape = *shape_with_layout;
      ShapeUtil::ForEachMutableSubshape(
          &layout_shape, [&](Shape* subshape, const ShapeIndex& /*index*/) {
            if (subshape->element_type() == F32) {
              subshape->set_element_type(BF16);
            }
          });
      shape_with_layout = &layout_shape;
    }
  }
  auto expect_equal = [&](const Literal& actual, const string& error_message) {
    LiteralTestUtil::ExpectEqual(*expected_ptr, actual, error_message);
  };
  // Layout-sweep modes, selected via debug flags.
  if (execution_options_.debug_options().xla_test_all_output_layouts()) {
    return ComputeAndCompareLiteralWithAllOutputLayouts(
        computation, *expected_ptr, arguments, expect_equal);
  }
  if (execution_options_.debug_options().xla_test_all_input_layouts()) {
    return ComputeAndCompareLiteralWithAllInputLayouts(
        computation, *expected_ptr, arguments, expect_equal, shape_with_layout);
  }
  // Default path: single execution, exact comparison.
  TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments,
                                                      shape_with_layout));
  LiteralTestUtil::ExpectEqual(*expected_ptr, *actual);
  return tensorflow::Status::OK();
}
    309 
// Builds and runs the computation, then compares the result against
// `expected` within `error` tolerance. Honors the layout-sweep debug flags
// like the exact-comparison overload above.
tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
    ComputationBuilder* builder, const Literal& expected,
    tensorflow::gtl::ArraySlice<GlobalData*> arguments_passed_in,
    ErrorSpec error, const Shape* shape_with_layout) {
  std::vector<GlobalData*> arguments(arguments_passed_in.begin(),
                                     arguments_passed_in.end());
  // Arguments registered via AddParam() (stored in arguments_) take over;
  // mixing them with explicitly passed arguments is a caller error (CHECK).
  if (!arguments_.empty()) {
    CHECK(arguments.empty());
    for (const auto& argument : arguments_) {
      arguments.push_back(argument.get());
    }
  }

  // Approximate comparison only makes sense for floating-point or complex
  // element types.
  TF_RET_CHECK(ShapeUtil::ElementIsFloating(expected.shape()) ||
               ShapeUtil::ElementIsComplex(expected.shape()));
  TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
  // We allow using a float expected literal for a bfloat16 output. In this
  // case, we need to convert the expected literal to bfloat16.
  const Literal* expected_ptr = &expected;
  std::unique_ptr<Literal> converted_expected;
  Shape layout_shape;
  if (use_bfloat16_) {
    converted_expected = LiteralTestUtil::ConvertF32ToBF16(expected);
    expected_ptr = converted_expected.get();
    if (shape_with_layout != nullptr) {
      // Rewrite F32 subshapes of the requested output layout to BF16 so the
      // layout still matches the converted output; layout_shape (declared
      // above) keeps the rewritten copy alive for the rest of this function.
      layout_shape = *shape_with_layout;
      ShapeUtil::ForEachMutableSubshape(
          &layout_shape, [&](Shape* subshape, const ShapeIndex& /*index*/) {
            if (subshape->element_type() == F32) {
              subshape->set_element_type(BF16);
            }
          });
      shape_with_layout = &layout_shape;
    }
  }
  auto expect_near = [&](const Literal& actual, const string& error_message) {
    LiteralTestUtil::ExpectNear(*expected_ptr, actual, error, error_message);
  };
  // Layout-sweep modes, selected via debug flags.
  if (execution_options_.debug_options().xla_test_all_output_layouts()) {
    return ComputeAndCompareLiteralWithAllOutputLayouts(
        computation, *expected_ptr, arguments, expect_near);
  }
  if (execution_options_.debug_options().xla_test_all_input_layouts()) {
    return ComputeAndCompareLiteralWithAllInputLayouts(
        computation, *expected_ptr, arguments, expect_near, shape_with_layout);
  }
  // Default path: single execution, approximate comparison.
  TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments,
                                                      shape_with_layout));
  LiteralTestUtil::ExpectNear(*expected_ptr, *actual, error);
  return tensorflow::Status::OK();
}
    361 
    362 void ClientLibraryTestBase::ComputeAndCompareR1U8(
    363     ComputationBuilder* builder, tensorflow::StringPiece expected,
    364     tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
    365   auto actual_status = ExecuteAndTransfer(builder, arguments);
    366   EXPECT_IS_OK(actual_status.status());
    367   if (!actual_status.ok()) {
    368     return;
    369   }
    370   auto actual = actual_status.ConsumeValueOrDie();
    371 
    372   // Turn the expected value into a literal.
    373   std::unique_ptr<Literal> expected_literal = Literal::CreateR1U8(expected);
    374 
    375   VLOG(1) << "expected: " << expected_literal->ToString();
    376   VLOG(1) << "actual:   " << actual->ToString();
    377 
    378   EXPECT_EQ(expected, actual->GetR1U8AsString());
    379 }
    380 
    381 void ClientLibraryTestBase::ComputeAndCompareTuple(
    382     ComputationBuilder* builder, const Literal& expected,
    383     tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
    384   auto actual_status = ExecuteAndTransfer(builder, arguments);
    385   EXPECT_IS_OK(actual_status.status());
    386   if (!actual_status.ok()) {
    387     return;
    388   }
    389   auto actual = actual_status.ConsumeValueOrDie();
    390   LiteralTestUtil::ExpectEqual(expected, *actual);
    391 }
    392 
    393 void ClientLibraryTestBase::ComputeAndCompareTuple(
    394     ComputationBuilder* builder, const Literal& expected,
    395     tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
    396   auto actual_status = ExecuteAndTransfer(builder, arguments);
    397   EXPECT_IS_OK(actual_status.status());
    398   if (!actual_status.ok()) {
    399     return;
    400   }
    401   auto actual = actual_status.ConsumeValueOrDie();
    402   LiteralTestUtil::ExpectNear(expected, *actual, error);
    403 }
    404 
    405 void ClientLibraryTestBase::ComputeAndCompare(
    406     ComputationBuilder* builder, const ComputationDataHandle& operand,
    407     tensorflow::gtl::ArraySlice<Literal> arguments) {
    408   auto status_or_data = ComputeValueAndReference(builder, operand, arguments);
    409   EXPECT_IS_OK(status_or_data);
    410   if (!status_or_data.ok()) {
    411     return;
    412   }
    413   std::unique_ptr<Literal> reference, result;
    414   std::tie(reference, result) = status_or_data.ConsumeValueOrDie();
    415   LiteralTestUtil::ExpectEqual(*reference, *result);
    416 }
    417 
    418 void ClientLibraryTestBase::ComputeAndCompare(
    419     ComputationBuilder* builder, const ComputationDataHandle& operand,
    420     tensorflow::gtl::ArraySlice<Literal> arguments, ErrorSpec error) {
    421   auto status_or_data = ComputeValueAndReference(builder, operand, arguments);
    422   EXPECT_IS_OK(status_or_data);
    423   if (!status_or_data.ok()) {
    424     return;
    425   }
    426   std::unique_ptr<Literal> reference, result;
    427   std::tie(reference, result) = status_or_data.ConsumeValueOrDie();
    428   LiteralTestUtil::ExpectNear(*reference, *result, error);
    429 }
    430 
    431 StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>>
    432 ClientLibraryTestBase::ComputeValueAndReference(
    433     ComputationBuilder* builder, const ComputationDataHandle& operand,
    434     tensorflow::gtl::ArraySlice<Literal> arguments) {
    435   // Transfer the arguments to the executor service. We put the unique_ptr's
    436   // into a vector to keep the data alive on the service until the end of this
    437   // function.
    438   std::vector<std::unique_ptr<GlobalData>> argument_data;
    439   for (const auto& arg : arguments) {
    440     TF_ASSIGN_OR_RETURN(auto data, client_->TransferToServer(arg));
    441     argument_data.push_back(std::move(data));
    442   }
    443 
    444   // Create raw pointers to the GlobalData for the rest of the call stack.
    445   std::vector<GlobalData*> argument_data_ptr;
    446   std::transform(
    447       argument_data.begin(), argument_data.end(),
    448       std::back_inserter(argument_data_ptr),
    449       [](const std::unique_ptr<GlobalData>& data) { return data.get(); });
    450 
    451   TF_ASSIGN_OR_RETURN(
    452       auto reference,
    453       builder->ComputeConstant(operand, /*output_layout=*/nullptr, arguments));
    454   TF_ASSIGN_OR_RETURN(auto result,
    455                       ExecuteAndTransfer(builder, argument_data_ptr));
    456   return std::make_pair(std::move(reference), std::move(result));
    457 }
    458 
    459 Computation ClientLibraryTestBase::CreateScalarRelu() {
    460   ComputationBuilder builder(client_, "relu");
    461   auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {});
    462   auto z_value = builder.Parameter(0, shape, "z_value");
    463   auto zero = use_bfloat16_
    464                   ? builder.ConstantR0<bfloat16>(static_cast<bfloat16>(0.0f))
    465                   : builder.ConstantR0<float>(0.0f);
    466   builder.Max(z_value, zero);
    467   auto computation_status = builder.Build();
    468   TF_CHECK_OK(computation_status.status());
    469   return computation_status.ConsumeValueOrDie();
    470 }
    471 
    472 Computation ClientLibraryTestBase::CreateScalarMax() {
    473   ComputationBuilder builder(client_, "max");
    474   auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {});
    475   auto x = builder.Parameter(0, shape, "x");
    476   auto y = builder.Parameter(1, shape, "y");
    477   builder.Max(x, y);
    478   auto computation_status = builder.Build();
    479   TF_CHECK_OK(computation_status.status());
    480   return computation_status.ConsumeValueOrDie();
    481 }
    482 
    483 Computation ClientLibraryTestBase::CreateScalarReluSensitivity() {
    484   ComputationBuilder builder(client_, "relu_sensitivity");
    485   auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {});
    486   auto activation = builder.Parameter(0, shape, "activation");
    487   auto backprop = builder.Parameter(1, shape, "backprop");
    488   auto zero = use_bfloat16_
    489                   ? builder.ConstantR0<bfloat16>(static_cast<bfloat16>(0.0f))
    490                   : builder.ConstantR0<float>(0.0f);
    491   auto activation_gtz = builder.Gt(activation, zero);
    492   builder.Select(activation_gtz, /*on_true=*/backprop, /*on_false=*/zero);
    493 
    494   auto computation_status = builder.Build();
    495   TF_CHECK_OK(computation_status.status());
    496   return computation_status.ConsumeValueOrDie();
    497 }
    498 
    499 std::unique_ptr<Array2D<float>> ClientLibraryTestBase::CreatePatternedMatrix(
    500     int rows, int cols, float offset) {
    501   auto array = MakeUnique<Array2D<float>>(rows, cols);
    502   for (int64 row = 0; row < rows; ++row) {
    503     for (int64 col = 0; col < cols; ++col) {
    504       (*array)(row, col) = col + (row * 1000.0f) + offset;
    505     }
    506   }
    507   return array;
    508 }
    509 
    510 std::unique_ptr<Array2D<float>>
    511 ClientLibraryTestBase::CreatePatternedMatrixWithZeroPadding(int rows, int cols,
    512                                                             int rows_padded,
    513                                                             int cols_padded) {
    514   CHECK_GE(rows_padded, rows);
    515   CHECK_GE(cols_padded, cols);
    516   auto array = MakeUnique<Array2D<float>>(rows_padded, cols_padded, 0.0);
    517   for (int64 row = 0; row < rows; ++row) {
    518     for (int64 col = 0; col < cols; ++col) {
    519       (*array)(row, col) = col + (row * 1000.0f);
    520     }
    521   }
    522   return array;
    523 }
    524 
    525 std::unique_ptr<GlobalData>
    526 ClientLibraryTestBase::CreateParameterAndTransferLiteral(
    527     int64 parameter_number, const Literal& literal, const string& name,
    528     ComputationBuilder* builder, ComputationDataHandle* data_handle) {
    529   return CreateParameterAndTransferLiteral(parameter_number, literal, name,
    530                                            nullptr, builder, data_handle);
    531 }
    532 
    533 std::unique_ptr<GlobalData>
    534 ClientLibraryTestBase::CreateParameterAndTransferLiteral(
    535     int64 parameter_number, const Literal& literal, const string& name,
    536     const DeviceHandle* device_handle, ComputationBuilder* builder,
    537     ComputationDataHandle* data_handle) {
    538   const Literal* param_literal = &literal;
    539   std::unique_ptr<Literal> converted_literal;
    540   if (use_bfloat16_) {
    541     converted_literal = LiteralTestUtil::ConvertF32ToBF16(literal);
    542     param_literal = converted_literal.get();
    543   }
    544   std::unique_ptr<GlobalData> data =
    545       client_->TransferToServer(*param_literal, device_handle)
    546           .ConsumeValueOrDie();
    547   *data_handle =
    548       builder->Parameter(parameter_number, param_literal->shape(), name);
    549   return data;
    550 }
    551 
    552 ComputationDataHandle ClientLibraryTestBase::AddParam(
    553     const Literal& argument, ComputationBuilder* builder) {
    554   ComputationDataHandle data_handle;
    555   arguments_.push_back(CreateParameterAndTransferLiteral(
    556       arguments_.size(), argument, "", builder, &data_handle));
    557   return data_handle;
    558 }
    559 
    560 ComputationDataHandle ClientLibraryTestBase::CreateConstantFromLiteral(
    561     const Literal& literal, ComputationBuilder* builder) {
    562   return builder->ConstantLiteral(
    563       use_bfloat16_ ? *LiteralTestUtil::ConvertF32ToBF16(literal) : literal);
    564 }
    565 
    566 }  // namespace xla
    567