1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_ 17 #define TENSORFLOW_COMPILER_XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_ 18 19 #include <memory> 20 #include <string> 21 #include <type_traits> 22 #include <vector> 23 24 #include "tensorflow/compiler/xla/array2d.h" 25 #include "tensorflow/compiler/xla/array3d.h" 26 #include "tensorflow/compiler/xla/array4d.h" 27 #include "tensorflow/compiler/xla/client/client_library.h" 28 #include "tensorflow/compiler/xla/client/computation.h" 29 #include "tensorflow/compiler/xla/client/computation_builder.h" 30 #include "tensorflow/compiler/xla/client/global_data.h" 31 #include "tensorflow/compiler/xla/literal_util.h" 32 #include "tensorflow/compiler/xla/ptr_util.h" 33 #include "tensorflow/compiler/xla/statusor.h" 34 #include "tensorflow/compiler/xla/tests/literal_test_util.h" 35 #include "tensorflow/compiler/xla/tests/test_utils.h" 36 #include "tensorflow/compiler/xla/xla_data.pb.h" 37 #include "tensorflow/core/lib/core/bitmap.h" 38 #include "tensorflow/core/lib/core/stringpiece.h" 39 #include "tensorflow/core/lib/gtl/array_slice.h" 40 #include "tensorflow/core/platform/stream_executor_no_cuda.h" 41 #include "tensorflow/core/platform/test.h" 42 #include "tensorflow/core/platform/types.h" 43 44 namespace xla { 45 46 // Sets the use_bfloat16 on a container of test cases according to the values in 47 // use_bfloat16_params. Generates one set of test cases for each values in 48 // use_bfloat16_params with that value. Returns the result. 49 template <typename TestCase> 50 std::vector<TestCase> ExpandUseBfloat16( 51 tensorflow::gtl::ArraySlice<bool> use_bfloat16_params, 52 tensorflow::gtl::ArraySlice<TestCase> specs) { 53 std::vector<TestCase> expanded; 54 for (bool use_bfloat16 : use_bfloat16_params) { 55 for (const auto& spec : specs) { 56 expanded.push_back(spec); 57 expanded.back().use_bfloat16 = use_bfloat16; 58 } 59 } 60 return expanded; 61 } 62 63 // A client library test establishes an in-process XLA client connection. 64 class ClientLibraryTestBase : public ::testing::Test { 65 protected: 66 explicit ClientLibraryTestBase( 67 perftools::gputools::Platform* platform = nullptr); 68 69 // Creates a new ClientLibraryTestBase with custom client options. 70 ClientLibraryTestBase(perftools::gputools::Platform* platform, 71 const LocalClientOptions& client_options); 72 73 // Returns the name of the test currently being run. 74 string TestName() const; 75 76 void SetFastMathDisabled(bool disabled) { 77 execution_options_.mutable_debug_options()->set_xla_enable_fast_math( 78 !disabled); 79 } 80 81 void SetSeed(uint64 seed) { execution_options_.set_seed(seed); } 82 83 // Provides mutable access to the execution DebugOptions field; this lets 84 // tests tweak the options that will be used to compile/run the graph. 85 DebugOptions* mutable_debug_options() { 86 return execution_options_.mutable_debug_options(); 87 } 88 89 // TODO(b/25566808): Add helper that populates a literal from a testdata file. 90 91 // Convenience methods for building and running a computation with the member 92 // execution options. Modify execution_options_ in your test if you want to 93 // customize the options. 94 StatusOr<std::unique_ptr<GlobalData>> Execute( 95 ComputationBuilder* builder, 96 tensorflow::gtl::ArraySlice<GlobalData*> arguments); 97 StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer( 98 ComputationBuilder* builder, 99 tensorflow::gtl::ArraySlice<GlobalData*> arguments, 100 const Shape* shape_with_output_layout = nullptr); 101 StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer( 102 const Computation& computation, 103 tensorflow::gtl::ArraySlice<GlobalData*> arguments, 104 const Shape* shape_with_output_layout = nullptr); 105 106 // Convenience OrDie variants of above methods. 107 std::unique_ptr<GlobalData> ExecuteOrDie( 108 ComputationBuilder* builder, 109 tensorflow::gtl::ArraySlice<GlobalData*> arguments); 110 std::unique_ptr<Literal> ExecuteAndTransferOrDie( 111 ComputationBuilder* builder, 112 tensorflow::gtl::ArraySlice<GlobalData*> arguments); 113 114 // Run a computation and return its value as a string. If an error 115 // occurs, then instead return the error as a string. 116 string ExecuteToString(ComputationBuilder* builder, 117 tensorflow::gtl::ArraySlice<GlobalData*> arguments); 118 119 // Convenience methods for building and running a computation, transferring 120 // the result, and comparing it to the expected value(s). Methods are 121 // templated on the native host type which maps to specific XLA types (See 122 // ComputationBuilder for details). For each rank, two forms are provided: one 123 // for floating point types with an ErrorSpec parameter, and one for integral 124 // types without the ErrorSpec parameter. 125 template <typename NativeT> 126 void ComputeAndCompareR0(ComputationBuilder* builder, NativeT expected, 127 tensorflow::gtl::ArraySlice<GlobalData*> arguments); 128 template <typename NativeT> 129 void ComputeAndCompareR0(ComputationBuilder* builder, NativeT expected, 130 tensorflow::gtl::ArraySlice<GlobalData*> arguments, 131 ErrorSpec error); 132 133 template <typename NativeT> 134 void ComputeAndCompareR1(ComputationBuilder* builder, 135 tensorflow::gtl::ArraySlice<NativeT> expected, 136 tensorflow::gtl::ArraySlice<GlobalData*> arguments); 137 template <typename NativeT> 138 void ComputeAndCompareR1(ComputationBuilder* builder, 139 tensorflow::gtl::ArraySlice<NativeT> expected, 140 tensorflow::gtl::ArraySlice<GlobalData*> arguments, 141 ErrorSpec error); 142 143 // As above, but uses a bitmap to hold the predicate vector to avoid 144 // deficiencies of vector<bool>. 145 void ComputeAndCompareR1(ComputationBuilder* builder, 146 const tensorflow::core::Bitmap& expected, 147 tensorflow::gtl::ArraySlice<GlobalData*> arguments); 148 149 template <typename NativeT> 150 void ComputeAndCompareR2(ComputationBuilder* builder, 151 const Array2D<NativeT>& expected, 152 tensorflow::gtl::ArraySlice<GlobalData*> arguments); 153 template <typename NativeT> 154 void ComputeAndCompareR2(ComputationBuilder* builder, 155 const Array2D<NativeT>& expected, 156 tensorflow::gtl::ArraySlice<GlobalData*> arguments, 157 ErrorSpec error); 158 159 template <typename NativeT> 160 void ComputeAndCompareR3(ComputationBuilder* builder, 161 const Array3D<NativeT>& expected, 162 tensorflow::gtl::ArraySlice<GlobalData*> arguments); 163 template <typename NativeT> 164 void ComputeAndCompareR3(ComputationBuilder* builder, 165 const Array3D<NativeT>& expected, 166 tensorflow::gtl::ArraySlice<GlobalData*> arguments, 167 ErrorSpec error); 168 169 template <typename NativeT> 170 void ComputeAndCompareR4(ComputationBuilder* builder, 171 const Array4D<NativeT>& expected, 172 tensorflow::gtl::ArraySlice<GlobalData*> arguments); 173 template <typename NativeT> 174 void ComputeAndCompareR4(ComputationBuilder* builder, 175 const Array4D<NativeT>& expected, 176 tensorflow::gtl::ArraySlice<GlobalData*> arguments, 177 ErrorSpec error); 178 179 // Build and run the computation and compare the result with the given 180 // literal. shape_with_layout indicates the result layout to request when 181 // calling Execute. 182 void ComputeAndCompareLiteral( 183 ComputationBuilder* builder, const Literal& expected, 184 tensorflow::gtl::ArraySlice<GlobalData*> arguments, 185 const Shape* shape_with_layout = nullptr); 186 void ComputeAndCompareLiteral( 187 ComputationBuilder* builder, const Literal& expected, 188 tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error, 189 const Shape* shape_with_layout = nullptr); 190 191 // ComputeAndCompare variant which returns an error status. 192 tensorflow::Status ComputeAndCompareLiteralWithStatus( 193 ComputationBuilder* builder, const Literal& expected, 194 tensorflow::gtl::ArraySlice<GlobalData*> arguments, 195 const Shape* shape_with_layout = nullptr); 196 tensorflow::Status ComputeAndCompareLiteralWithStatus( 197 ComputationBuilder* builder, const Literal& expected, 198 tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error, 199 const Shape* shape_with_layout = nullptr); 200 201 // Compare the result of the computation to a strings. In XLA strings are 202 // represented using rank-1 U8 shapes. 203 void ComputeAndCompareR1U8( 204 ComputationBuilder* builder, tensorflow::StringPiece expected, 205 tensorflow::gtl::ArraySlice<GlobalData*> arguments); 206 207 // Convenience method for running a built computation, transferring the 208 // result, and comparing it to the expected tuple literal. 209 void ComputeAndCompareTuple( 210 ComputationBuilder* builder, const Literal& expected, 211 tensorflow::gtl::ArraySlice<GlobalData*> arguments); 212 void ComputeAndCompareTuple( 213 ComputationBuilder* builder, const Literal& expected, 214 tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error); 215 216 // Convenience method for running a built computation and comparing the result 217 // with the HloEvaluator. 218 void ComputeAndCompare(ComputationBuilder* builder, 219 const ComputationDataHandle& operand, 220 tensorflow::gtl::ArraySlice<Literal> arguments); 221 void ComputeAndCompare(ComputationBuilder* builder, 222 const ComputationDataHandle& operand, 223 tensorflow::gtl::ArraySlice<Literal> arguments, 224 ErrorSpec error); 225 226 // Create scalar operations for use in reductions. 227 Computation CreateScalarRelu(); 228 Computation CreateScalarMax(); 229 Computation CreateScalarReluSensitivity(); 230 231 // Special case convenience functions for creating filled arrays. 232 233 // Creates an array of pseudorandom values lying between the given minimum and 234 // maximum values. 235 template <typename NativeT> 236 std::vector<NativeT> CreatePseudorandomR1(const int width, NativeT min_value, 237 NativeT max_value, uint32 seed); 238 template <typename NativeT> 239 std::unique_ptr<Array2D<NativeT>> CreatePseudorandomR2(const int rows, 240 const int cols, 241 NativeT min_value, 242 NativeT max_value, 243 uint32 seed); 244 245 // Creates a (rows x cols) array filled in the following form: 246 // 247 // [ 0 1 ... cols-1] 248 // [ 1,000 1,001 ... 1000.0 + cols-1] 249 // [ ... ... ... ...] 250 // [(rows-1)*1000.0 ... ... (rows-1)*1000.0 + cols-1] 251 // 252 // If provided, offset is added uniformly to every element (e.g. an offset of 253 // 64 would cause 0 in the above to be 64, 1 to be 65, 1000 to be 1064, etc.) 254 std::unique_ptr<Array2D<float>> CreatePatternedMatrix(const int rows, 255 const int cols, 256 float offset = 0.0); 257 258 // Creates a (rows x cols) array as above, padded out to 259 // (rows_padded x cols_padded) with zeroes. Requires rows_padded >= rows 260 // and cols_padded > cols. 261 std::unique_ptr<Array2D<float>> CreatePatternedMatrixWithZeroPadding( 262 const int rows, const int cols, const int rows_padded, 263 const int cols_padded); 264 265 // Creates a parameter instruction, transfers the literal for the parameter to 266 // server, then stores into "data_handle" the global handle for that 267 // parameter. When the use_bfloat16 flag is set but the literal has F32 268 // elements, the literal will be converted to BF16 before being transferred. 269 std::unique_ptr<GlobalData> CreateParameterAndTransferLiteral( 270 int64 parameter_number, const Literal& literal, const string& name, 271 ComputationBuilder* builder, ComputationDataHandle* data_handle); 272 273 // As above, but the caller can specify the device that the literal is 274 // transferred to. If device_handle is nullptr, the literal will be 275 // transferred to the default device. 276 std::unique_ptr<GlobalData> CreateParameterAndTransferLiteral( 277 int64 parameter_number, const Literal& literal, const string& name, 278 const DeviceHandle* device_handle, ComputationBuilder* builder, 279 ComputationDataHandle* data_handle); 280 281 // Creates a parameter instruction and sets the value that will be passed to 282 // the computation as specified. This function must be used for all parameters 283 // or none and no parameters must be passed when invoking the computation if 284 // using this mechanism. If using this mechanism, then each parameter must be 285 // set exactly once. The first added parameter gets index 0, then 1 and so on. 286 ComputationDataHandle AddParam(const Literal& argument, 287 ComputationBuilder* builder); 288 289 template <class T> 290 ComputationDataHandle AddParam(const Array<T>& argument, 291 ComputationBuilder* builder) { 292 return AddParam(*Literal::CreateFromArray(argument), builder); 293 } 294 295 // Creates a constant instruction with the given literal. When the 296 // use_bfloat16 flag is set but the literal has F32 elements, the elements 297 // will be converted to BF16s. 298 ComputationDataHandle CreateConstantFromLiteral(const Literal& literal, 299 ComputationBuilder* builder); 300 301 // Creates a constant instruction with the given array. When the use_bfloat16 302 // flag is set but the array has float elements, the elements will be 303 // converted to bfloat16s. 304 template <typename NativeT> 305 ComputationDataHandle CreateConstantFromArray(const Array<NativeT>& array, 306 ComputationBuilder* builder) { 307 return CreateConstantFromLiteral(*Literal::CreateFromArray(array), builder); 308 } 309 310 // Same as CreateConstantFromArray, but for scalars. 311 template <typename NativeT> 312 ComputationDataHandle CreateConstantFromScalar(NativeT value, 313 ComputationBuilder* builder) { 314 return CreateConstantFromLiteral(*Literal::CreateR0<NativeT>(value), 315 builder); 316 } 317 318 // Creates a parameter instruction that wraps a given value and then stores 319 // into "data_handle" the global handle for that parameter. 320 // 321 // "parameter_number" is the parameter number. 322 // "name" is the name of the parameter instruction. 323 // 324 // When the use_bfloat16 flag is set but NativeT is float, the data will be 325 // converted to bfloat16. 326 template <typename NativeT> 327 std::unique_ptr<GlobalData> CreateR0Parameter( 328 NativeT value, int64 parameter_number, const string& name, 329 ComputationBuilder* builder, ComputationDataHandle* data_handle); 330 331 // Creates a parameter instruction that wraps the given values and then stores 332 // into "data_handle" the global handle for that parameter. 333 // 334 // "parameter_number" is the parameter number. 335 // "name" is the name of the parameter instruction. 336 // 337 // When the use_bfloat16 flag is set but NativeT is float, the data will be 338 // converted to bfloat16. 339 template <typename NativeT> 340 std::unique_ptr<GlobalData> CreateR1Parameter( 341 tensorflow::gtl::ArraySlice<NativeT> values, int64 parameter_number, 342 const string& name, ComputationBuilder* builder, 343 ComputationDataHandle* data_handle); 344 345 // Creates a parameter instruction that wraps the given constant array 346 // "array_2d" and then stores to "data_handle" the global handle for that 347 // parameter. 348 // 349 // "parameter_number" is the parameter number. 350 // "name" is the name of the parameter instruction. 351 // 352 // When the use_bfloat16 flag is set but NativeT is float, the data will be 353 // converted to bfloat16. 354 template <typename NativeT> 355 std::unique_ptr<GlobalData> CreateR2Parameter( 356 const Array2D<NativeT>& array_2d, int64 parameter_number, 357 const string& name, ComputationBuilder* builder, 358 ComputationDataHandle* data_handle); 359 360 // Creates a parameter instruction that wraps the given constant array 361 // "array_3d" and then stores to "data_handle" the global handle for that 362 // parameter. 363 // 364 // "parameter_number" is the parameter number. 365 // "name" is the name of the parameter instruction. 366 // 367 // When the use_bfloat16 flag is set but NativeT is float, the data will be 368 // converted to bfloat16. 369 template <typename NativeT> 370 std::unique_ptr<GlobalData> CreateR3Parameter( 371 const Array3D<NativeT>& array_3d, int64 parameter_number, 372 const string& name, ComputationBuilder* builder, 373 ComputationDataHandle* data_handle); 374 375 // Getter and setter for the use_bfloat16 flag, which indicates whether to run 376 // tests with all float-type input/output converted to bfloat16. 377 bool use_bfloat16() const { return use_bfloat16_; } 378 void set_use_bfloat16(bool value) { use_bfloat16_ = value; } 379 380 // The float type used in this test, BF16 or F32 according to use_bfloat16. 381 PrimitiveType FloatType() const { return use_bfloat16_ ? BF16 : F32; } 382 383 Client* client_; 384 ExecutionOptions execution_options_; 385 386 private: 387 // Build and run the computation with all permutations of output layouts. 388 tensorflow::Status ComputeAndCompareLiteralWithAllOutputLayouts( 389 const xla::Computation& computation, const Literal& expected, 390 tensorflow::gtl::ArraySlice<GlobalData*> arguments, 391 const std::function<void(const Literal& actual, 392 const string& error_message)>& verify_output); 393 // Build and run the computation with all permutations of layouts of all input 394 // arguments. 395 tensorflow::Status ComputeAndCompareLiteralWithAllInputLayouts( 396 const xla::Computation& computation, const Literal& expected, 397 tensorflow::gtl::ArraySlice<GlobalData*> arguments, 398 const std::function<void(const Literal& actual, 399 const string& error_message)>& verify_output, 400 const Shape* output_with_layout = nullptr); 401 402 // Executes the computation and calculates the expected reference value using 403 // the HloEvaluator. Returns two literal in the order of (expected, actual). 404 StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>> 405 ComputeValueAndReference(ComputationBuilder* builder, 406 const ComputationDataHandle& operand, 407 tensorflow::gtl::ArraySlice<Literal> arguments); 408 409 // Whether to run tests with all float-type input/output converted to 410 // bfloat16. 411 bool use_bfloat16_ = false; 412 413 // Arguments to be passed to the computation when it runs. 414 std::vector<std::unique_ptr<GlobalData>> arguments_; 415 }; 416 417 template <typename NativeT> 418 void ClientLibraryTestBase::ComputeAndCompareR0( 419 ComputationBuilder* builder, NativeT expected, 420 tensorflow::gtl::ArraySlice<GlobalData*> arguments) { 421 std::unique_ptr<Literal> expected_literal = 422 Literal::CreateR0<NativeT>(expected); 423 ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, 424 arguments); 425 } 426 427 template <typename NativeT> 428 void ClientLibraryTestBase::ComputeAndCompareR0( 429 ComputationBuilder* builder, NativeT expected, 430 tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) { 431 static_assert(std::is_same<NativeT, float>::value || 432 std::is_same<NativeT, double>::value || 433 std::is_same<NativeT, bfloat16>::value || 434 std::is_same<NativeT, half>::value || 435 std::is_same<NativeT, complex64>::value, 436 "Float or complex type required when specifying an ErrorSpec"); 437 std::unique_ptr<Literal> expected_literal = 438 Literal::CreateR0<NativeT>(expected); 439 ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, 440 arguments, error); 441 } 442 443 template <typename NativeT> 444 void ClientLibraryTestBase::ComputeAndCompareR1( 445 ComputationBuilder* builder, tensorflow::gtl::ArraySlice<NativeT> expected, 446 tensorflow::gtl::ArraySlice<GlobalData*> arguments) { 447 std::unique_ptr<Literal> expected_literal = 448 Literal::CreateR1<NativeT>(expected); 449 ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, 450 arguments); 451 } 452 453 template <typename NativeT> 454 void ClientLibraryTestBase::ComputeAndCompareR1( 455 ComputationBuilder* builder, tensorflow::gtl::ArraySlice<NativeT> expected, 456 tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) { 457 static_assert(std::is_same<NativeT, float>::value || 458 std::is_same<NativeT, double>::value || 459 std::is_same<NativeT, bfloat16>::value || 460 std::is_same<NativeT, half>::value || 461 std::is_same<NativeT, complex64>::value, 462 "Float or complex type required when specifying an ErrorSpec"); 463 std::unique_ptr<Literal> expected_literal = 464 Literal::CreateR1<NativeT>(expected); 465 ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, 466 arguments, error); 467 } 468 469 template <typename NativeT> 470 void ClientLibraryTestBase::ComputeAndCompareR2( 471 ComputationBuilder* builder, const Array2D<NativeT>& expected, 472 tensorflow::gtl::ArraySlice<GlobalData*> arguments) { 473 std::unique_ptr<Literal> expected_literal = 474 Literal::CreateR2FromArray2D<NativeT>(expected); 475 ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, 476 arguments); 477 } 478 479 template <typename NativeT> 480 void ClientLibraryTestBase::ComputeAndCompareR2( 481 ComputationBuilder* builder, const Array2D<NativeT>& expected, 482 tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) { 483 static_assert(std::is_same<NativeT, float>::value || 484 std::is_same<NativeT, double>::value || 485 std::is_same<NativeT, bfloat16>::value || 486 std::is_same<NativeT, half>::value || 487 std::is_same<NativeT, complex64>::value, 488 "Float or complex type required when specifying an ErrorSpec"); 489 std::unique_ptr<Literal> expected_literal = 490 Literal::CreateR2FromArray2D<NativeT>(expected); 491 ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, 492 arguments, error); 493 } 494 495 template <typename NativeT> 496 void ClientLibraryTestBase::ComputeAndCompareR3( 497 ComputationBuilder* builder, const Array3D<NativeT>& expected, 498 tensorflow::gtl::ArraySlice<GlobalData*> arguments) { 499 std::unique_ptr<Literal> expected_literal = 500 Literal::CreateR3FromArray3D<NativeT>(expected); 501 ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, 502 arguments); 503 } 504 505 template <typename NativeT> 506 void ClientLibraryTestBase::ComputeAndCompareR3( 507 ComputationBuilder* builder, const Array3D<NativeT>& expected, 508 tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) { 509 static_assert(std::is_same<NativeT, float>::value || 510 std::is_same<NativeT, double>::value || 511 std::is_same<NativeT, bfloat16>::value || 512 std::is_same<NativeT, half>::value || 513 std::is_same<NativeT, complex64>::value, 514 "Float or complex type required when specifying an ErrorSpec"); 515 std::unique_ptr<Literal> expected_literal = 516 Literal::CreateR3FromArray3D<NativeT>(expected); 517 ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, 518 arguments, error); 519 } 520 521 template <typename NativeT> 522 void ClientLibraryTestBase::ComputeAndCompareR4( 523 ComputationBuilder* builder, const Array4D<NativeT>& expected, 524 tensorflow::gtl::ArraySlice<GlobalData*> arguments) { 525 std::unique_ptr<Literal> expected_literal = 526 Literal::CreateR4FromArray4D<NativeT>(expected); 527 ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, 528 arguments); 529 } 530 531 template <typename NativeT> 532 void ClientLibraryTestBase::ComputeAndCompareR4( 533 ComputationBuilder* builder, const Array4D<NativeT>& expected, 534 tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) { 535 static_assert(std::is_same<NativeT, float>::value || 536 std::is_same<NativeT, double>::value || 537 std::is_same<NativeT, bfloat16>::value || 538 std::is_same<NativeT, half>::value || 539 std::is_same<NativeT, complex64>::value, 540 "Float or complex type required when specifying an ErrorSpec"); 541 std::unique_ptr<Literal> expected_literal = 542 Literal::CreateR4FromArray4D<NativeT>(expected); 543 ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, 544 arguments, error); 545 } 546 547 template <typename NativeT> 548 std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR0Parameter( 549 NativeT value, int64 parameter_number, const string& name, 550 ComputationBuilder* builder, ComputationDataHandle* data_handle) { 551 std::unique_ptr<Literal> literal = Literal::CreateR0(value); 552 if (use_bfloat16_ && literal->shape().element_type() == F32) { 553 literal = LiteralTestUtil::ConvertF32ToBF16(*literal); 554 } 555 std::unique_ptr<GlobalData> data = 556 client_->TransferToServer(*literal).ConsumeValueOrDie(); 557 *data_handle = builder->Parameter(parameter_number, literal->shape(), name); 558 return data; 559 } 560 561 template <typename NativeT> 562 std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR1Parameter( 563 tensorflow::gtl::ArraySlice<NativeT> values, int64 parameter_number, 564 const string& name, ComputationBuilder* builder, 565 ComputationDataHandle* data_handle) { 566 std::unique_ptr<Literal> literal = Literal::CreateR1(values); 567 if (use_bfloat16_ && literal->shape().element_type() == F32) { 568 literal = LiteralTestUtil::ConvertF32ToBF16(*literal); 569 } 570 std::unique_ptr<GlobalData> data = 571 client_->TransferToServer(*literal).ConsumeValueOrDie(); 572 *data_handle = builder->Parameter(parameter_number, literal->shape(), name); 573 return data; 574 } 575 576 template <typename NativeT> 577 std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR2Parameter( 578 const Array2D<NativeT>& array_2d, int64 parameter_number, 579 const string& name, ComputationBuilder* builder, 580 ComputationDataHandle* data_handle) { 581 std::unique_ptr<Literal> literal = Literal::CreateR2FromArray2D(array_2d); 582 if (use_bfloat16_ && literal->shape().element_type() == F32) { 583 literal = LiteralTestUtil::ConvertF32ToBF16(*literal); 584 } 585 std::unique_ptr<GlobalData> data = 586 client_->TransferToServer(*literal).ConsumeValueOrDie(); 587 *data_handle = builder->Parameter(parameter_number, literal->shape(), name); 588 return data; 589 } 590 591 template <typename NativeT> 592 std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR3Parameter( 593 const Array3D<NativeT>& array_3d, int64 parameter_number, 594 const string& name, ComputationBuilder* builder, 595 ComputationDataHandle* data_handle) { 596 std::unique_ptr<Literal> literal = Literal::CreateR3FromArray3D(array_3d); 597 if (use_bfloat16_ && literal->shape().element_type() == F32) { 598 literal = LiteralTestUtil::ConvertF32ToBF16(*literal); 599 } 600 std::unique_ptr<GlobalData> data = 601 client_->TransferToServer(*literal).ConsumeValueOrDie(); 602 *data_handle = builder->Parameter(parameter_number, literal->shape(), name); 603 return data; 604 } 605 606 template <typename NativeT> 607 std::vector<NativeT> ClientLibraryTestBase::CreatePseudorandomR1( 608 const int width, NativeT min_value, NativeT max_value, uint32 seed) { 609 std::vector<NativeT> result(width); 610 PseudorandomGenerator<NativeT> generator(min_value, max_value, seed); 611 for (int i = 0; i < width; ++i) { 612 result[i] = generator.get(); 613 } 614 return result; 615 } 616 617 template <typename NativeT> 618 std::unique_ptr<Array2D<NativeT>> ClientLibraryTestBase::CreatePseudorandomR2( 619 const int rows, const int cols, NativeT min_value, NativeT max_value, 620 uint32 seed) { 621 auto result = MakeUnique<Array2D<NativeT>>(rows, cols); 622 PseudorandomGenerator<NativeT> generator(min_value, max_value, seed); 623 for (int y = 0; y < rows; ++y) { 624 for (int x = 0; x < cols; ++x) { 625 (*result)(y, x) = generator.get(); 626 } 627 } 628 return result; 629 } 630 631 } // namespace xla 632 633 #endif // TENSORFLOW_COMPILER_XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_ 634