Home | History | Annotate | Download | only in tests
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_
     17 #define TENSORFLOW_COMPILER_XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_
     18 
#include <functional>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/bitmap.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
     43 
     44 namespace xla {
     45 
     46 // Sets the use_bfloat16 on a container of test cases according to the values in
     47 // use_bfloat16_params. Generates one set of test cases for each values in
     48 // use_bfloat16_params with that value. Returns the result.
     49 template <typename TestCase>
     50 std::vector<TestCase> ExpandUseBfloat16(
     51     tensorflow::gtl::ArraySlice<bool> use_bfloat16_params,
     52     tensorflow::gtl::ArraySlice<TestCase> specs) {
     53   std::vector<TestCase> expanded;
     54   for (bool use_bfloat16 : use_bfloat16_params) {
     55     for (const auto& spec : specs) {
     56       expanded.push_back(spec);
     57       expanded.back().use_bfloat16 = use_bfloat16;
     58     }
     59   }
     60   return expanded;
     61 }
     62 
     63 // A client library test establishes an in-process XLA client connection.
     64 class ClientLibraryTestBase : public ::testing::Test {
     65  protected:
     66   explicit ClientLibraryTestBase(
     67       perftools::gputools::Platform* platform = nullptr);
     68 
     69   // Creates a new ClientLibraryTestBase with custom client options.
     70   ClientLibraryTestBase(perftools::gputools::Platform* platform,
     71                         const LocalClientOptions& client_options);
     72 
     73   // Returns the name of the test currently being run.
     74   string TestName() const;
     75 
     76   void SetFastMathDisabled(bool disabled) {
     77     execution_options_.mutable_debug_options()->set_xla_enable_fast_math(
     78         !disabled);
     79   }
     80 
     81   void SetSeed(uint64 seed) { execution_options_.set_seed(seed); }
     82 
     83   // Provides mutable access to the execution DebugOptions field; this lets
     84   // tests tweak the options that will be used to compile/run the graph.
     85   DebugOptions* mutable_debug_options() {
     86     return execution_options_.mutable_debug_options();
     87   }
     88 
     89   // TODO(b/25566808): Add helper that populates a literal from a testdata file.
     90 
     91   // Convenience methods for building and running a computation with the member
     92   // execution options. Modify execution_options_ in your test if you want to
     93   // customize the options.
     94   StatusOr<std::unique_ptr<GlobalData>> Execute(
     95       ComputationBuilder* builder,
     96       tensorflow::gtl::ArraySlice<GlobalData*> arguments);
     97   StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer(
     98       ComputationBuilder* builder,
     99       tensorflow::gtl::ArraySlice<GlobalData*> arguments,
    100       const Shape* shape_with_output_layout = nullptr);
    101   StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer(
    102       const Computation& computation,
    103       tensorflow::gtl::ArraySlice<GlobalData*> arguments,
    104       const Shape* shape_with_output_layout = nullptr);
    105 
    106   // Convenience OrDie variants of above methods.
    107   std::unique_ptr<GlobalData> ExecuteOrDie(
    108       ComputationBuilder* builder,
    109       tensorflow::gtl::ArraySlice<GlobalData*> arguments);
    110   std::unique_ptr<Literal> ExecuteAndTransferOrDie(
    111       ComputationBuilder* builder,
    112       tensorflow::gtl::ArraySlice<GlobalData*> arguments);
    113 
    114   // Run a computation and return its value as a string. If an error
    115   // occurs, then instead return the error as a string.
    116   string ExecuteToString(ComputationBuilder* builder,
    117                          tensorflow::gtl::ArraySlice<GlobalData*> arguments);
    118 
    119   // Convenience methods for building and running a computation, transferring
    120   // the result, and comparing it to the expected value(s). Methods are
    121   // templated on the native host type which maps to specific XLA types (See
    122   // ComputationBuilder for details). For each rank, two forms are provided: one
    123   // for floating point types with an ErrorSpec parameter, and one for integral
    124   // types without the ErrorSpec parameter.
    125   template <typename NativeT>
    126   void ComputeAndCompareR0(ComputationBuilder* builder, NativeT expected,
    127                            tensorflow::gtl::ArraySlice<GlobalData*> arguments);
    128   template <typename NativeT>
    129   void ComputeAndCompareR0(ComputationBuilder* builder, NativeT expected,
    130                            tensorflow::gtl::ArraySlice<GlobalData*> arguments,
    131                            ErrorSpec error);
    132 
    133   template <typename NativeT>
    134   void ComputeAndCompareR1(ComputationBuilder* builder,
    135                            tensorflow::gtl::ArraySlice<NativeT> expected,
    136                            tensorflow::gtl::ArraySlice<GlobalData*> arguments);
    137   template <typename NativeT>
    138   void ComputeAndCompareR1(ComputationBuilder* builder,
    139                            tensorflow::gtl::ArraySlice<NativeT> expected,
    140                            tensorflow::gtl::ArraySlice<GlobalData*> arguments,
    141                            ErrorSpec error);
    142 
    143   // As above, but uses a bitmap to hold the predicate vector to avoid
    144   // deficiencies of vector<bool>.
    145   void ComputeAndCompareR1(ComputationBuilder* builder,
    146                            const tensorflow::core::Bitmap& expected,
    147                            tensorflow::gtl::ArraySlice<GlobalData*> arguments);
    148 
    149   template <typename NativeT>
    150   void ComputeAndCompareR2(ComputationBuilder* builder,
    151                            const Array2D<NativeT>& expected,
    152                            tensorflow::gtl::ArraySlice<GlobalData*> arguments);
    153   template <typename NativeT>
    154   void ComputeAndCompareR2(ComputationBuilder* builder,
    155                            const Array2D<NativeT>& expected,
    156                            tensorflow::gtl::ArraySlice<GlobalData*> arguments,
    157                            ErrorSpec error);
    158 
    159   template <typename NativeT>
    160   void ComputeAndCompareR3(ComputationBuilder* builder,
    161                            const Array3D<NativeT>& expected,
    162                            tensorflow::gtl::ArraySlice<GlobalData*> arguments);
    163   template <typename NativeT>
    164   void ComputeAndCompareR3(ComputationBuilder* builder,
    165                            const Array3D<NativeT>& expected,
    166                            tensorflow::gtl::ArraySlice<GlobalData*> arguments,
    167                            ErrorSpec error);
    168 
    169   template <typename NativeT>
    170   void ComputeAndCompareR4(ComputationBuilder* builder,
    171                            const Array4D<NativeT>& expected,
    172                            tensorflow::gtl::ArraySlice<GlobalData*> arguments);
    173   template <typename NativeT>
    174   void ComputeAndCompareR4(ComputationBuilder* builder,
    175                            const Array4D<NativeT>& expected,
    176                            tensorflow::gtl::ArraySlice<GlobalData*> arguments,
    177                            ErrorSpec error);
    178 
    179   // Build and run the computation and compare the result with the given
    180   // literal. shape_with_layout indicates the result layout to request when
    181   // calling Execute.
    182   void ComputeAndCompareLiteral(
    183       ComputationBuilder* builder, const Literal& expected,
    184       tensorflow::gtl::ArraySlice<GlobalData*> arguments,
    185       const Shape* shape_with_layout = nullptr);
    186   void ComputeAndCompareLiteral(
    187       ComputationBuilder* builder, const Literal& expected,
    188       tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error,
    189       const Shape* shape_with_layout = nullptr);
    190 
    191   // ComputeAndCompare variant which returns an error status.
    192   tensorflow::Status ComputeAndCompareLiteralWithStatus(
    193       ComputationBuilder* builder, const Literal& expected,
    194       tensorflow::gtl::ArraySlice<GlobalData*> arguments,
    195       const Shape* shape_with_layout = nullptr);
    196   tensorflow::Status ComputeAndCompareLiteralWithStatus(
    197       ComputationBuilder* builder, const Literal& expected,
    198       tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error,
    199       const Shape* shape_with_layout = nullptr);
    200 
    201   // Compare the result of the computation to a strings. In XLA strings are
    202   // represented using rank-1 U8 shapes.
    203   void ComputeAndCompareR1U8(
    204       ComputationBuilder* builder, tensorflow::StringPiece expected,
    205       tensorflow::gtl::ArraySlice<GlobalData*> arguments);
    206 
    207   // Convenience method for running a built computation, transferring the
    208   // result, and comparing it to the expected tuple literal.
    209   void ComputeAndCompareTuple(
    210       ComputationBuilder* builder, const Literal& expected,
    211       tensorflow::gtl::ArraySlice<GlobalData*> arguments);
    212   void ComputeAndCompareTuple(
    213       ComputationBuilder* builder, const Literal& expected,
    214       tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error);
    215 
    216   // Convenience method for running a built computation and comparing the result
    217   // with the HloEvaluator.
    218   void ComputeAndCompare(ComputationBuilder* builder,
    219                          const ComputationDataHandle& operand,
    220                          tensorflow::gtl::ArraySlice<Literal> arguments);
    221   void ComputeAndCompare(ComputationBuilder* builder,
    222                          const ComputationDataHandle& operand,
    223                          tensorflow::gtl::ArraySlice<Literal> arguments,
    224                          ErrorSpec error);
    225 
    226   // Create scalar operations for use in reductions.
    227   Computation CreateScalarRelu();
    228   Computation CreateScalarMax();
    229   Computation CreateScalarReluSensitivity();
    230 
    231   // Special case convenience functions for creating filled arrays.
    232 
    233   // Creates an array of pseudorandom values lying between the given minimum and
    234   // maximum values.
    235   template <typename NativeT>
    236   std::vector<NativeT> CreatePseudorandomR1(const int width, NativeT min_value,
    237                                             NativeT max_value, uint32 seed);
    238   template <typename NativeT>
    239   std::unique_ptr<Array2D<NativeT>> CreatePseudorandomR2(const int rows,
    240                                                          const int cols,
    241                                                          NativeT min_value,
    242                                                          NativeT max_value,
    243                                                          uint32 seed);
    244 
    245   // Creates a (rows x cols) array filled in the following form:
    246   //
    247   //  [      0              1 ...                   cols-1]
    248   //  [  1,000          1,001 ...          1000.0 + cols-1]
    249   //  [    ...            ... ...                      ...]
    250   //  [(rows-1)*1000.0    ... ... (rows-1)*1000.0 + cols-1]
    251   //
    252   // If provided, offset is added uniformly to every element (e.g. an offset of
    253   // 64 would cause 0 in the above to be 64, 1 to be 65, 1000 to be 1064, etc.)
    254   std::unique_ptr<Array2D<float>> CreatePatternedMatrix(const int rows,
    255                                                         const int cols,
    256                                                         float offset = 0.0);
    257 
    258   // Creates a (rows x cols) array as above, padded out to
    259   // (rows_padded x cols_padded) with zeroes.  Requires rows_padded >= rows
    260   // and cols_padded > cols.
    261   std::unique_ptr<Array2D<float>> CreatePatternedMatrixWithZeroPadding(
    262       const int rows, const int cols, const int rows_padded,
    263       const int cols_padded);
    264 
    265   // Creates a parameter instruction, transfers the literal for the parameter to
    266   // server, then stores into "data_handle" the global handle for that
    267   // parameter. When the use_bfloat16 flag is set but the literal has F32
    268   // elements, the literal will be converted to BF16 before being transferred.
    269   std::unique_ptr<GlobalData> CreateParameterAndTransferLiteral(
    270       int64 parameter_number, const Literal& literal, const string& name,
    271       ComputationBuilder* builder, ComputationDataHandle* data_handle);
    272 
    273   // As above, but the caller can specify the device that the literal is
    274   // transferred to. If device_handle is nullptr, the literal will be
    275   // transferred to the default device.
    276   std::unique_ptr<GlobalData> CreateParameterAndTransferLiteral(
    277       int64 parameter_number, const Literal& literal, const string& name,
    278       const DeviceHandle* device_handle, ComputationBuilder* builder,
    279       ComputationDataHandle* data_handle);
    280 
    281   // Creates a parameter instruction and sets the value that will be passed to
    282   // the computation as specified. This function must be used for all parameters
    283   // or none and no parameters must be passed when invoking the computation if
    284   // using this mechanism. If using this mechanism, then each parameter must be
    285   // set exactly once. The first added parameter gets index 0, then 1 and so on.
    286   ComputationDataHandle AddParam(const Literal& argument,
    287                                  ComputationBuilder* builder);
    288 
    289   template <class T>
    290   ComputationDataHandle AddParam(const Array<T>& argument,
    291                                  ComputationBuilder* builder) {
    292     return AddParam(*Literal::CreateFromArray(argument), builder);
    293   }
    294 
    295   // Creates a constant instruction with the given literal. When the
    296   // use_bfloat16 flag is set but the literal has F32 elements, the elements
    297   // will be converted to BF16s.
    298   ComputationDataHandle CreateConstantFromLiteral(const Literal& literal,
    299                                                   ComputationBuilder* builder);
    300 
    301   // Creates a constant instruction with the given array. When the use_bfloat16
    302   // flag is set but the array has float elements, the elements will be
    303   // converted to bfloat16s.
    304   template <typename NativeT>
    305   ComputationDataHandle CreateConstantFromArray(const Array<NativeT>& array,
    306                                                 ComputationBuilder* builder) {
    307     return CreateConstantFromLiteral(*Literal::CreateFromArray(array), builder);
    308   }
    309 
    310   // Same as CreateConstantFromArray, but for scalars.
    311   template <typename NativeT>
    312   ComputationDataHandle CreateConstantFromScalar(NativeT value,
    313                                                  ComputationBuilder* builder) {
    314     return CreateConstantFromLiteral(*Literal::CreateR0<NativeT>(value),
    315                                      builder);
    316   }
    317 
    318   // Creates a parameter instruction that wraps a given value and then stores
    319   // into "data_handle" the global handle for that parameter.
    320   //
    321   // "parameter_number" is the parameter number.
    322   // "name" is the name of the parameter instruction.
    323   //
    324   // When the use_bfloat16 flag is set but NativeT is float, the data will be
    325   // converted to bfloat16.
    326   template <typename NativeT>
    327   std::unique_ptr<GlobalData> CreateR0Parameter(
    328       NativeT value, int64 parameter_number, const string& name,
    329       ComputationBuilder* builder, ComputationDataHandle* data_handle);
    330 
    331   // Creates a parameter instruction that wraps the given values and then stores
    332   // into "data_handle" the global handle for that parameter.
    333   //
    334   // "parameter_number" is the parameter number.
    335   // "name" is the name of the parameter instruction.
    336   //
    337   // When the use_bfloat16 flag is set but NativeT is float, the data will be
    338   // converted to bfloat16.
    339   template <typename NativeT>
    340   std::unique_ptr<GlobalData> CreateR1Parameter(
    341       tensorflow::gtl::ArraySlice<NativeT> values, int64 parameter_number,
    342       const string& name, ComputationBuilder* builder,
    343       ComputationDataHandle* data_handle);
    344 
    345   // Creates a parameter instruction that wraps the given constant array
    346   // "array_2d" and then stores to "data_handle" the global handle for that
    347   // parameter.
    348   //
    349   // "parameter_number" is the parameter number.
    350   // "name" is the name of the parameter instruction.
    351   //
    352   // When the use_bfloat16 flag is set but NativeT is float, the data will be
    353   // converted to bfloat16.
    354   template <typename NativeT>
    355   std::unique_ptr<GlobalData> CreateR2Parameter(
    356       const Array2D<NativeT>& array_2d, int64 parameter_number,
    357       const string& name, ComputationBuilder* builder,
    358       ComputationDataHandle* data_handle);
    359 
    360   // Creates a parameter instruction that wraps the given constant array
    361   // "array_3d" and then stores to "data_handle" the global handle for that
    362   // parameter.
    363   //
    364   // "parameter_number" is the parameter number.
    365   // "name" is the name of the parameter instruction.
    366   //
    367   // When the use_bfloat16 flag is set but NativeT is float, the data will be
    368   // converted to bfloat16.
    369   template <typename NativeT>
    370   std::unique_ptr<GlobalData> CreateR3Parameter(
    371       const Array3D<NativeT>& array_3d, int64 parameter_number,
    372       const string& name, ComputationBuilder* builder,
    373       ComputationDataHandle* data_handle);
    374 
    375   // Getter and setter for the use_bfloat16 flag, which indicates whether to run
    376   // tests with all float-type input/output converted to bfloat16.
    377   bool use_bfloat16() const { return use_bfloat16_; }
    378   void set_use_bfloat16(bool value) { use_bfloat16_ = value; }
    379 
    380   // The float type used in this test, BF16 or F32 according to use_bfloat16.
    381   PrimitiveType FloatType() const { return use_bfloat16_ ? BF16 : F32; }
    382 
    383   Client* client_;
    384   ExecutionOptions execution_options_;
    385 
    386  private:
    387   // Build and run the computation with all permutations of output layouts.
    388   tensorflow::Status ComputeAndCompareLiteralWithAllOutputLayouts(
    389       const xla::Computation& computation, const Literal& expected,
    390       tensorflow::gtl::ArraySlice<GlobalData*> arguments,
    391       const std::function<void(const Literal& actual,
    392                                const string& error_message)>& verify_output);
    393   // Build and run the computation with all permutations of layouts of all input
    394   // arguments.
    395   tensorflow::Status ComputeAndCompareLiteralWithAllInputLayouts(
    396       const xla::Computation& computation, const Literal& expected,
    397       tensorflow::gtl::ArraySlice<GlobalData*> arguments,
    398       const std::function<void(const Literal& actual,
    399                                const string& error_message)>& verify_output,
    400       const Shape* output_with_layout = nullptr);
    401 
    402   // Executes the computation and calculates the expected reference value using
    403   // the HloEvaluator. Returns two literal in the order of (expected, actual).
    404   StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>>
    405   ComputeValueAndReference(ComputationBuilder* builder,
    406                            const ComputationDataHandle& operand,
    407                            tensorflow::gtl::ArraySlice<Literal> arguments);
    408 
    409   // Whether to run tests with all float-type input/output converted to
    410   // bfloat16.
    411   bool use_bfloat16_ = false;
    412 
    413   // Arguments to be passed to the computation when it runs.
    414   std::vector<std::unique_ptr<GlobalData>> arguments_;
    415 };
    416 
    417 template <typename NativeT>
    418 void ClientLibraryTestBase::ComputeAndCompareR0(
    419     ComputationBuilder* builder, NativeT expected,
    420     tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
    421   std::unique_ptr<Literal> expected_literal =
    422       Literal::CreateR0<NativeT>(expected);
    423   ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
    424                                                   arguments);
    425 }
    426 
    427 template <typename NativeT>
    428 void ClientLibraryTestBase::ComputeAndCompareR0(
    429     ComputationBuilder* builder, NativeT expected,
    430     tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
    431   static_assert(std::is_same<NativeT, float>::value ||
    432                     std::is_same<NativeT, double>::value ||
    433                     std::is_same<NativeT, bfloat16>::value ||
    434                     std::is_same<NativeT, half>::value ||
    435                     std::is_same<NativeT, complex64>::value,
    436                 "Float or complex type required when specifying an ErrorSpec");
    437   std::unique_ptr<Literal> expected_literal =
    438       Literal::CreateR0<NativeT>(expected);
    439   ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
    440                                                   arguments, error);
    441 }
    442 
    443 template <typename NativeT>
    444 void ClientLibraryTestBase::ComputeAndCompareR1(
    445     ComputationBuilder* builder, tensorflow::gtl::ArraySlice<NativeT> expected,
    446     tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
    447   std::unique_ptr<Literal> expected_literal =
    448       Literal::CreateR1<NativeT>(expected);
    449   ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
    450                                                   arguments);
    451 }
    452 
    453 template <typename NativeT>
    454 void ClientLibraryTestBase::ComputeAndCompareR1(
    455     ComputationBuilder* builder, tensorflow::gtl::ArraySlice<NativeT> expected,
    456     tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
    457   static_assert(std::is_same<NativeT, float>::value ||
    458                     std::is_same<NativeT, double>::value ||
    459                     std::is_same<NativeT, bfloat16>::value ||
    460                     std::is_same<NativeT, half>::value ||
    461                     std::is_same<NativeT, complex64>::value,
    462                 "Float or complex type required when specifying an ErrorSpec");
    463   std::unique_ptr<Literal> expected_literal =
    464       Literal::CreateR1<NativeT>(expected);
    465   ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
    466                                                   arguments, error);
    467 }
    468 
    469 template <typename NativeT>
    470 void ClientLibraryTestBase::ComputeAndCompareR2(
    471     ComputationBuilder* builder, const Array2D<NativeT>& expected,
    472     tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
    473   std::unique_ptr<Literal> expected_literal =
    474       Literal::CreateR2FromArray2D<NativeT>(expected);
    475   ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
    476                                                   arguments);
    477 }
    478 
    479 template <typename NativeT>
    480 void ClientLibraryTestBase::ComputeAndCompareR2(
    481     ComputationBuilder* builder, const Array2D<NativeT>& expected,
    482     tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
    483   static_assert(std::is_same<NativeT, float>::value ||
    484                     std::is_same<NativeT, double>::value ||
    485                     std::is_same<NativeT, bfloat16>::value ||
    486                     std::is_same<NativeT, half>::value ||
    487                     std::is_same<NativeT, complex64>::value,
    488                 "Float or complex type required when specifying an ErrorSpec");
    489   std::unique_ptr<Literal> expected_literal =
    490       Literal::CreateR2FromArray2D<NativeT>(expected);
    491   ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
    492                                                   arguments, error);
    493 }
    494 
    495 template <typename NativeT>
    496 void ClientLibraryTestBase::ComputeAndCompareR3(
    497     ComputationBuilder* builder, const Array3D<NativeT>& expected,
    498     tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
    499   std::unique_ptr<Literal> expected_literal =
    500       Literal::CreateR3FromArray3D<NativeT>(expected);
    501   ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
    502                                                   arguments);
    503 }
    504 
    505 template <typename NativeT>
    506 void ClientLibraryTestBase::ComputeAndCompareR3(
    507     ComputationBuilder* builder, const Array3D<NativeT>& expected,
    508     tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
    509   static_assert(std::is_same<NativeT, float>::value ||
    510                     std::is_same<NativeT, double>::value ||
    511                     std::is_same<NativeT, bfloat16>::value ||
    512                     std::is_same<NativeT, half>::value ||
    513                     std::is_same<NativeT, complex64>::value,
    514                 "Float or complex type required when specifying an ErrorSpec");
    515   std::unique_ptr<Literal> expected_literal =
    516       Literal::CreateR3FromArray3D<NativeT>(expected);
    517   ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
    518                                                   arguments, error);
    519 }
    520 
    521 template <typename NativeT>
    522 void ClientLibraryTestBase::ComputeAndCompareR4(
    523     ComputationBuilder* builder, const Array4D<NativeT>& expected,
    524     tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
    525   std::unique_ptr<Literal> expected_literal =
    526       Literal::CreateR4FromArray4D<NativeT>(expected);
    527   ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
    528                                                   arguments);
    529 }
    530 
    531 template <typename NativeT>
    532 void ClientLibraryTestBase::ComputeAndCompareR4(
    533     ComputationBuilder* builder, const Array4D<NativeT>& expected,
    534     tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
    535   static_assert(std::is_same<NativeT, float>::value ||
    536                     std::is_same<NativeT, double>::value ||
    537                     std::is_same<NativeT, bfloat16>::value ||
    538                     std::is_same<NativeT, half>::value ||
    539                     std::is_same<NativeT, complex64>::value,
    540                 "Float or complex type required when specifying an ErrorSpec");
    541   std::unique_ptr<Literal> expected_literal =
    542       Literal::CreateR4FromArray4D<NativeT>(expected);
    543   ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
    544                                                   arguments, error);
    545 }
    546 
    547 template <typename NativeT>
    548 std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR0Parameter(
    549     NativeT value, int64 parameter_number, const string& name,
    550     ComputationBuilder* builder, ComputationDataHandle* data_handle) {
    551   std::unique_ptr<Literal> literal = Literal::CreateR0(value);
    552   if (use_bfloat16_ && literal->shape().element_type() == F32) {
    553     literal = LiteralTestUtil::ConvertF32ToBF16(*literal);
    554   }
    555   std::unique_ptr<GlobalData> data =
    556       client_->TransferToServer(*literal).ConsumeValueOrDie();
    557   *data_handle = builder->Parameter(parameter_number, literal->shape(), name);
    558   return data;
    559 }
    560 
    561 template <typename NativeT>
    562 std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR1Parameter(
    563     tensorflow::gtl::ArraySlice<NativeT> values, int64 parameter_number,
    564     const string& name, ComputationBuilder* builder,
    565     ComputationDataHandle* data_handle) {
    566   std::unique_ptr<Literal> literal = Literal::CreateR1(values);
    567   if (use_bfloat16_ && literal->shape().element_type() == F32) {
    568     literal = LiteralTestUtil::ConvertF32ToBF16(*literal);
    569   }
    570   std::unique_ptr<GlobalData> data =
    571       client_->TransferToServer(*literal).ConsumeValueOrDie();
    572   *data_handle = builder->Parameter(parameter_number, literal->shape(), name);
    573   return data;
    574 }
    575 
    576 template <typename NativeT>
    577 std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR2Parameter(
    578     const Array2D<NativeT>& array_2d, int64 parameter_number,
    579     const string& name, ComputationBuilder* builder,
    580     ComputationDataHandle* data_handle) {
    581   std::unique_ptr<Literal> literal = Literal::CreateR2FromArray2D(array_2d);
    582   if (use_bfloat16_ && literal->shape().element_type() == F32) {
    583     literal = LiteralTestUtil::ConvertF32ToBF16(*literal);
    584   }
    585   std::unique_ptr<GlobalData> data =
    586       client_->TransferToServer(*literal).ConsumeValueOrDie();
    587   *data_handle = builder->Parameter(parameter_number, literal->shape(), name);
    588   return data;
    589 }
    590 
    591 template <typename NativeT>
    592 std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR3Parameter(
    593     const Array3D<NativeT>& array_3d, int64 parameter_number,
    594     const string& name, ComputationBuilder* builder,
    595     ComputationDataHandle* data_handle) {
    596   std::unique_ptr<Literal> literal = Literal::CreateR3FromArray3D(array_3d);
    597   if (use_bfloat16_ && literal->shape().element_type() == F32) {
    598     literal = LiteralTestUtil::ConvertF32ToBF16(*literal);
    599   }
    600   std::unique_ptr<GlobalData> data =
    601       client_->TransferToServer(*literal).ConsumeValueOrDie();
    602   *data_handle = builder->Parameter(parameter_number, literal->shape(), name);
    603   return data;
    604 }
    605 
    606 template <typename NativeT>
    607 std::vector<NativeT> ClientLibraryTestBase::CreatePseudorandomR1(
    608     const int width, NativeT min_value, NativeT max_value, uint32 seed) {
    609   std::vector<NativeT> result(width);
    610   PseudorandomGenerator<NativeT> generator(min_value, max_value, seed);
    611   for (int i = 0; i < width; ++i) {
    612     result[i] = generator.get();
    613   }
    614   return result;
    615 }
    616 
    617 template <typename NativeT>
    618 std::unique_ptr<Array2D<NativeT>> ClientLibraryTestBase::CreatePseudorandomR2(
    619     const int rows, const int cols, NativeT min_value, NativeT max_value,
    620     uint32 seed) {
    621   auto result = MakeUnique<Array2D<NativeT>>(rows, cols);
    622   PseudorandomGenerator<NativeT> generator(min_value, max_value, seed);
    623   for (int y = 0; y < rows; ++y) {
    624     for (int x = 0; x < cols; ++x) {
    625       (*result)(y, x) = generator.get();
    626     }
    627   }
    628   return result;
    629 }
    630 
    631 }  // namespace xla
    632 
    633 #endif  // TENSORFLOW_COMPILER_XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_
    634