1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <memory> 17 18 #include "tensorflow/compiler/xla/array2d.h" 19 #include "tensorflow/compiler/xla/client/computation.h" 20 #include "tensorflow/compiler/xla/client/computation_builder.h" 21 #include "tensorflow/compiler/xla/client/global_data.h" 22 #include "tensorflow/compiler/xla/client/lib/arithmetic.h" 23 #include "tensorflow/compiler/xla/client/local_client.h" 24 #include "tensorflow/compiler/xla/literal_util.h" 25 #include "tensorflow/compiler/xla/shape_util.h" 26 #include "tensorflow/compiler/xla/statusor.h" 27 #include "tensorflow/compiler/xla/test.h" 28 #include "tensorflow/compiler/xla/test_helpers.h" 29 #include "tensorflow/compiler/xla/tests/client_library_test_base.h" 30 #include "tensorflow/compiler/xla/tests/literal_test_util.h" 31 #include "tensorflow/compiler/xla/tests/test_macros.h" 32 #include "tensorflow/compiler/xla/tests/test_utils.h" 33 #include "tensorflow/compiler/xla/xla_data.pb.h" 34 #include "tensorflow/core/platform/stream_executor_no_cuda.h" 35 #include "tensorflow/core/platform/types.h" 36 37 namespace xla { 38 namespace { 39 40 class MapTest : public ClientLibraryTestBase { 41 public: 42 explicit MapTest(perftools::gputools::Platform* platform = nullptr) 43 : ClientLibraryTestBase(platform) { 44 mutable_debug_options()->add_xla_disable_hlo_passes("algsimp"); 45 mutable_debug_options()->add_xla_disable_hlo_passes("inline"); 46 } 47 48 // Creates a function that adds its scalar argument with the constant 1.0. 49 // 50 // x {R0F32} ----> (add) 51 // / 52 // 1.0f ---------/ 53 Computation CreateAdderToOne() { 54 ComputationBuilder mapped_builder(client_, TestName()); 55 auto x = mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); 56 auto one = mapped_builder.ConstantR0<float>(1.0); 57 auto adder_to_one = mapped_builder.Add(x, one); 58 auto computation_status = mapped_builder.Build(); 59 TF_CHECK_OK(computation_status.status()); 60 return computation_status.ConsumeValueOrDie(); 61 } 62 63 Computation CreateMax() { 64 ComputationBuilder b(client_, TestName()); 65 auto lhs = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); 66 auto rhs = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); 67 b.Max(lhs, rhs); 68 auto computation_status = b.Build(); 69 TF_CHECK_OK(computation_status.status()); 70 return computation_status.ConsumeValueOrDie(); 71 } 72 73 // Creates a computation that accepts an F32 and returns T(1) (ignoring the 74 // argument). 75 template <class T> 76 Computation CreateScalarOne() { 77 ComputationBuilder mapped_builder(client_, "scalar_one"); 78 (void)mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); 79 mapped_builder.ConstantR0<T>(1); 80 auto computation_status = mapped_builder.Build(); 81 TF_CHECK_OK(computation_status.status()); 82 return computation_status.ConsumeValueOrDie(); 83 } 84 85 // Creates a function that multiplies its scalar argument by the constant 2.0 86 // 87 // x {R0F32} ----> (mul) 88 // / 89 // 2.0f ---------/ 90 Computation CreateMulByTwo() { 91 ComputationBuilder mapped_builder(client_, TestName()); 92 auto x = mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); 93 auto two = mapped_builder.ConstantR0<float>(2.0); 94 auto mul_by_two = mapped_builder.Mul(x, two); 95 auto computation_status = mapped_builder.Build(); 96 TF_CHECK_OK(computation_status.status()); 97 return computation_status.ConsumeValueOrDie(); 98 } 99 100 // Creates a function that adds its scalar argument with the constant 1.0 and 101 // then multiplies by the original element. 102 // 103 // /------------------| 104 // / | 105 // x {R0F32} ----> (add) ----> (mul) 106 // / 107 // 1.0f ---------/ 108 Computation CreateAdderToOneTimesItself() { 109 ComputationBuilder mapped_builder(client_, TestName()); 110 auto x = mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); 111 auto one = mapped_builder.ConstantR0<float>(1.0); 112 auto adder_to_one = mapped_builder.Add(x, one); 113 auto result = mapped_builder.Mul(x, adder_to_one); 114 auto computation_status = mapped_builder.Build(); 115 TF_CHECK_OK(computation_status.status()); 116 return computation_status.ConsumeValueOrDie(); 117 } 118 119 // Creates a function that takes a single parameter and calls map with 120 // "embedded_computation" on it, and then adds "n" to the result. 121 // 122 // x {R0F32} -----------> (map) ----> (add) 123 // / / 124 // embedded_computation --/ n --/ 125 Computation CreateMapPlusN(const Computation& embedded_computation, float n) { 126 ComputationBuilder builder(client_, TestName()); 127 auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); 128 auto map = builder.Map({x}, embedded_computation, {}); 129 auto constant_n = builder.ConstantR0<float>(n); 130 auto add = builder.Add(map, constant_n); 131 auto computation_status = builder.Build(); 132 TF_CHECK_OK(computation_status.status()); 133 return computation_status.ConsumeValueOrDie(); 134 } 135 136 // Creates a binary function with signature (F32, F32) -> Pred 137 // defined by (x, y) -> x > y. 138 Computation CreateGt() { 139 ComputationBuilder b(client_, "Gt"); 140 auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); 141 auto y = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); 142 auto gt = b.Gt(x, y); 143 auto computation_status = b.Build(); 144 TF_CHECK_OK(computation_status.status()); 145 return computation_status.ConsumeValueOrDie(); 146 } 147 148 // Creates a function that adds three scalar arguments 149 // 150 // x {R0F32} -------| 151 // | 152 // y {R0F32} ----> (add) ---> (add) 153 // / 154 // z {R0F32} ---------------/ 155 Computation CreateTernaryAdder() { 156 ComputationBuilder mapped_builder(client_, "TernaryAdder"); 157 auto x = mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); 158 auto y = mapped_builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); 159 auto z = mapped_builder.Parameter(2, ShapeUtil::MakeShape(F32, {}), "z"); 160 auto xy = mapped_builder.Add(x, y); 161 auto xyz = mapped_builder.Add(xy, z); 162 auto computation_status = mapped_builder.Build(); 163 TF_CHECK_OK(computation_status.status()); 164 return computation_status.ConsumeValueOrDie(); 165 } 166 }; 167 168 TEST_F(MapTest, MapEachElemPlusOneR0) { 169 // Applies lambda (x) (+ x 1)) to an input scalar. 170 ComputationBuilder builder(client_, TestName()); 171 std::unique_ptr<Literal> param0_literal = Literal::CreateR0<float>(42.0); 172 std::unique_ptr<GlobalData> param0_data = 173 client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); 174 175 auto param = builder.Parameter(0, param0_literal->shape(), "param0"); 176 auto map = builder.Map({param}, CreateAdderToOne(), {}); 177 178 ComputeAndCompareR0<float>(&builder, 43.0, {param0_data.get()}, 179 ErrorSpec(0.01f)); 180 } 181 182 XLA_TEST_F(MapTest, MapEachElemPlusOneR1S0) { 183 // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0. 184 ComputationBuilder builder(client_, TestName()); 185 std::unique_ptr<Literal> param0_literal = Literal::CreateR1<float>({}); 186 std::unique_ptr<GlobalData> param0_data = 187 client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); 188 189 auto param = builder.Parameter(0, param0_literal->shape(), "param0"); 190 auto map = builder.Map({param}, CreateAdderToOne(), {0}); 191 192 ComputeAndCompareR1<float>(&builder, {}, {param0_data.get()}, 193 ErrorSpec(0.01f)); 194 } 195 196 TEST_F(MapTest, MapEachElemPlusOneR1S4) { 197 // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 4. 198 ComputationBuilder builder(client_, TestName()); 199 std::unique_ptr<Literal> param0_literal = 200 Literal::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f}); 201 std::unique_ptr<GlobalData> param0_data = 202 client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); 203 204 auto param = builder.Parameter(0, param0_literal->shape(), "param0"); 205 auto map = builder.Map({param}, CreateAdderToOne(), {0}); 206 207 ComputeAndCompareR1<float>(&builder, {3.2f, 4.3f, 5.4f, 6.5f}, 208 {param0_data.get()}, ErrorSpec(0.01f)); 209 } 210 211 TEST_F(MapTest, MapEachF32ElementToS32Constant) { 212 ComputationBuilder builder(client_, TestName()); 213 std::unique_ptr<Literal> param0_literal = 214 Literal::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f}); 215 std::unique_ptr<GlobalData> param0_data = 216 client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); 217 218 auto param = builder.Parameter(0, param0_literal->shape(), "param0"); 219 auto map = builder.Map({param}, CreateScalarOne<int32>(), {0}); 220 221 ComputeAndCompareR1<int32>(&builder, {1, 1, 1, 1}, {param0_data.get()}); 222 } 223 224 TEST_F(MapTest, MapEachF32ElementToU32Constant) { 225 ComputationBuilder builder(client_, TestName()); 226 std::unique_ptr<Literal> param0_literal = 227 Literal::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f}); 228 std::unique_ptr<GlobalData> param0_data = 229 client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); 230 231 auto param = builder.Parameter(0, param0_literal->shape(), "param0"); 232 auto map = builder.Map({param}, CreateScalarOne<uint32>(), {0}); 233 234 ComputeAndCompareR1<uint32>(&builder, {1, 1, 1, 1}, {param0_data.get()}); 235 } 236 237 TEST_F(MapTest, MapEachElemLongerChainR1) { 238 // Maps (lambda (x) (* (+ x 1) x)) onto an input R1F32 vector. 239 ComputationBuilder builder(client_, TestName()); 240 std::unique_ptr<Literal> param0_literal = 241 Literal::CreateR1<float>({2.6f, -5.1f, 0.1f, 0.2f, 999.0f, 255.5f}); 242 std::unique_ptr<GlobalData> param0_data = 243 client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); 244 245 auto param = builder.Parameter(0, param0_literal->shape(), "param0"); 246 auto map = builder.Map({param}, CreateAdderToOneTimesItself(), {0}); 247 248 ComputeAndCompareR1<float>( 249 &builder, {9.36f, 20.91f, 0.11f, 0.24f, 999000.0f, 65535.75f}, 250 {param0_data.get()}, ErrorSpec(0.01f)); 251 } 252 253 XLA_TEST_F(MapTest, MapMultipleMapsR1S0) { 254 // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0, and then 255 // maps (lambda (x) (* x 2)) on the result. 256 ComputationBuilder builder(client_, TestName()); 257 std::unique_ptr<Literal> param0_literal = Literal::CreateR1<float>({}); 258 std::unique_ptr<GlobalData> param0_data = 259 client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); 260 261 auto param = builder.Parameter(0, param0_literal->shape(), "param0"); 262 auto map1 = builder.Map({param}, CreateAdderToOne(), {0}); 263 auto map2 = builder.Map({map1}, CreateMulByTwo(), {0}); 264 265 ComputeAndCompareR1<float>(&builder, {}, {param0_data.get()}, 266 ErrorSpec(0.01f)); 267 } 268 269 TEST_F(MapTest, MapMultipleMapsR1S4) { 270 // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 4, and then 271 // maps (lambda (x) (* x 2)) on the result. 272 ComputationBuilder builder(client_, TestName()); 273 std::unique_ptr<Literal> param0_literal = 274 Literal::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f}); 275 std::unique_ptr<GlobalData> param0_data = 276 client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); 277 278 auto param = builder.Parameter(0, param0_literal->shape(), "param0"); 279 auto map1 = builder.Map({param}, CreateAdderToOne(), {0}); 280 auto map2 = builder.Map({map1}, CreateMulByTwo(), {0}); 281 282 ComputeAndCompareR1<float>(&builder, {6.4f, 8.6f, 10.8f, 13.0f}, 283 {param0_data.get()}, ErrorSpec(0.01f)); 284 } 285 286 TEST_F(MapTest, MapEachElemPlusOneR2) { 287 // Maps (lambda (x) (+ x 1)) onto an input R2F32 vector. 288 ComputationBuilder builder(client_, TestName()); 289 std::unique_ptr<Literal> param0_literal = Literal::CreateR2<float>( 290 {{13.25f, 14.0f}, {-7.1f, -7.2f}, {-8.8f, 8.8f}}); 291 std::unique_ptr<GlobalData> param0_data = 292 client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); 293 294 auto param = builder.Parameter(0, param0_literal->shape(), "param0"); 295 auto map = builder.Map({param}, CreateAdderToOne(), {0, 1}); 296 297 Array2D<float> expected_array( 298 {{14.25f, 15.0f}, {-6.1f, -6.2f}, {-7.8f, 9.8f}}); 299 ComputeAndCompareR2<float>(&builder, expected_array, {param0_data.get()}, 300 ErrorSpec(0.01f)); 301 } 302 303 XLA_TEST_F(MapTest, ComplexNestedMaps) { 304 // Constructs a complex graph of embedded computations to test the computation 305 // lowering order. Python equivalent: 306 // 307 // embed1 = lambda x: x + 1 # x + 1 308 // embed2 = lambda x: embed1(x) + 2 # x + 3 309 // embed3 = lambda x: embed1(x) + 4 # x + 5 310 // embed4 = lambda x: embed2(x) + embed3(x) # 2x + 8 311 // embed5 = lambda x: embed2(x) + 6 # x + 9 312 // result = embed5(42) + embed4(7) # (42 + 9) + (2 * 7 + 8) = 73 313 314 Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); 315 316 auto embed1 = CreateAdderToOne(); 317 auto embed2 = CreateMapPlusN(embed1, 2.0); 318 auto embed3 = CreateMapPlusN(embed1, 4.0); 319 320 ComputationBuilder embed4_builder(client_, "embed4"); 321 auto embed4_param = embed4_builder.Parameter(0, scalar_shape, "x"); 322 auto embed4_map_lhs = embed4_builder.Map({embed4_param}, embed2, {}); 323 auto embed4_map_rhs = embed4_builder.Map({embed4_param}, embed3, {}); 324 auto embed4_add = embed4_builder.Add(embed4_map_lhs, embed4_map_rhs); 325 auto embed4_status = embed4_builder.Build(); 326 ASSERT_IS_OK(embed4_status.status()); 327 auto embed4 = embed4_status.ConsumeValueOrDie(); 328 329 auto embed5 = CreateMapPlusN(embed2, 6.0); 330 331 ComputationBuilder builder(client_, TestName()); 332 auto constant_42 = builder.ConstantR0<float>(42.0); 333 auto constant_7 = builder.ConstantR0<float>(7.0); 334 auto map_42 = builder.Map({constant_42}, embed5, {}); 335 auto map_7 = builder.Map({constant_7}, embed4, {}); 336 builder.Add(map_42, map_7); 337 338 ComputeAndCompareR0<float>(&builder, 73.0, {}, ErrorSpec(0.01f)); 339 } 340 341 TEST_F(MapTest, VersionedEmbeddedComputation) { 342 // Build a computation X, use it in a map, then add an additional operation to 343 // computation X and use it again in a different map. Verify that the proper 344 // versions of computation X are used in each of the maps. 345 346 // Create a (embedded) computation which adds one to its parameter argument. 347 ComputationBuilder embedded_builder(client_, "EmbeddedComputation"); 348 auto param_0 = 349 embedded_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param0"); 350 auto constant_one = embedded_builder.ConstantR0<float>(1.0); 351 auto adder_to_one = embedded_builder.Add(param_0, constant_one); 352 auto computation_status = embedded_builder.Build(); 353 ASSERT_IS_OK(computation_status.status()); 354 auto embedded_computation = computation_status.ConsumeValueOrDie(); 355 356 ComputationBuilder builder(client_, TestName()); 357 auto constant_vector = builder.ConstantR1<float>({1.0, 2.0, 3.0, 4.0}); 358 auto map_plus_1 = builder.Map({constant_vector}, embedded_computation, {0}); 359 360 // Add another Add(1) operation to the existing embedded computation. This 361 // requires using the stub interface because the ComputationBuilder does not 362 // allow modification to the Computation objects after they have been built. 363 BinaryOpRequest request; 364 request.set_binop(BINOP_ADD); 365 *request.mutable_lhs() = adder_to_one; 366 *request.mutable_rhs() = constant_one; 367 OpRequest op_request; 368 *op_request.mutable_computation() = embedded_computation.handle(); 369 *op_request.mutable_binary_op_request() = request; 370 OpResponse response; 371 tensorflow::Status s = client_->stub()->Op(&op_request, &response); 372 ASSERT_TRUE(s.ok()); 373 374 auto map_plus_2 = builder.Map({map_plus_1}, embedded_computation, {0}); 375 376 // The original vector has Add(1) applied to it with a map, followed by 377 // Add(1+1) resulting in a net Add(3). 378 ComputeAndCompareR1<float>(&builder, {4.0, 5.0, 6.0, 7.0}, {}, 379 ErrorSpec(0.01f)); 380 } 381 382 TEST_F(MapTest, MapBinaryAdder) { 383 // Maps (lambda (x y) (+ x y)) onto two R1F32 vectors. 384 ComputationBuilder builder(client_, TestName()); 385 std::unique_ptr<Literal> param0_literal = 386 Literal::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f}); 387 std::unique_ptr<GlobalData> param0_data = 388 client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); 389 std::unique_ptr<Literal> param1_literal = 390 Literal::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f}); 391 std::unique_ptr<GlobalData> param1_data = 392 client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); 393 394 auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); 395 auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); 396 auto map = builder.Map({param0, param1}, 397 CreateScalarAddComputation(F32, &builder), {0}); 398 399 ComputeAndCompareR1<float>(&builder, {7.3f, 7.7, 4.3f, 0}, 400 {param0_data.get(), param1_data.get()}, 401 ErrorSpec(0.01f)); 402 } 403 404 // Adds two rank-2 arrays with different layouts. This test exercises a path 405 // for Map that used to fail in shape inference (b/28989438). 406 XLA_TEST_F(MapTest, AddWithMixedLayouts) { 407 ComputationBuilder builder(client_, TestName()); 408 std::unique_ptr<Literal> param0_literal = Literal::CreateR2WithLayout( 409 {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({1, 0})); 410 std::unique_ptr<GlobalData> param0_data = 411 client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); 412 413 std::unique_ptr<Literal> param1_literal = Literal::CreateR2WithLayout( 414 {{10, 20}, {30, 40}}, LayoutUtil::MakeLayout({0, 1})); 415 std::unique_ptr<GlobalData> param1_data = 416 client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); 417 418 auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); 419 auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); 420 auto map = builder.Map({param0, param1}, 421 CreateScalarAddComputation(S32, &builder), {0, 1}); 422 423 Array2D<int32> expected(2, 2); 424 expected(0, 0) = 11; 425 expected(0, 1) = 22; 426 expected(1, 0) = 33; 427 expected(1, 1) = 44; 428 ComputeAndCompareR2<int32>(&builder, expected, 429 {param0_data.get(), param1_data.get()}); 430 } 431 432 XLA_TEST_F(MapTest, AddR3_3x0x2) { 433 ComputationBuilder builder(client_, TestName()); 434 std::unique_ptr<Literal> param0_literal = 435 Literal::CreateR3FromArray3D<int32>(Array3D<int32>(3, 0, 2)); 436 std::unique_ptr<GlobalData> param0_data = 437 client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); 438 439 std::unique_ptr<Literal> param1_literal = 440 Literal::CreateR3FromArray3D<int32>(Array3D<int32>(3, 0, 2)); 441 std::unique_ptr<GlobalData> param1_data = 442 client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); 443 444 auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); 445 auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); 446 auto map = builder.Map({param0, param1}, 447 CreateScalarAddComputation(S32, &builder), {0, 1, 2}); 448 449 ComputeAndCompareR3<int32>(&builder, Array3D<int32>(3, 0, 2), 450 {param0_data.get(), param1_data.get()}); 451 } 452 453 TEST_F(MapTest, MapTernaryAdder) { 454 // Maps (lambda (x y z) (+ x y z)) onto three R1F32 vectors. 455 ComputationBuilder builder(client_, TestName()); 456 std::unique_ptr<Literal> param0_literal = 457 Literal::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f}); 458 std::unique_ptr<GlobalData> param0_data = 459 client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); 460 std::unique_ptr<Literal> param1_literal = 461 Literal::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f}); 462 std::unique_ptr<GlobalData> param1_data = 463 client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); 464 std::unique_ptr<Literal> param2_literal = 465 Literal::CreateR1<float>({-10.0f, -100.0f, -900.0f, -400.0f}); 466 std::unique_ptr<GlobalData> param2_data = 467 client_->TransferToServer(*param2_literal).ConsumeValueOrDie(); 468 469 auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); 470 auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); 471 auto param2 = builder.Parameter(2, param2_literal->shape(), "param2"); 472 auto map = builder.Map({param0, param1, param2}, CreateTernaryAdder(), {0}); 473 474 ComputeAndCompareR1<float>( 475 &builder, {-2.7f, -92.3f, -895.7f, -400.0f}, 476 {param0_data.get(), param1_data.get(), param2_data.get()}, 477 ErrorSpec(0.01f)); 478 } 479 480 TEST_F(MapTest, MapGt) { 481 // Maps (x,y) -> x > y onto two R1F32 vectors. 482 ComputationBuilder b(client_, TestName()); 483 auto gt = CreateGt(); 484 b.Map({b.ConstantR1<float>({1, 20}), b.ConstantR1<float>({10, 2})}, gt, {0}); 485 ComputeAndCompareR1<bool>(&b, {false, true}, {}); 486 } 487 488 TEST_F(MapTest, NestedBinaryMap) { 489 Computation max_with_square; 490 { 491 // max_with_square(x) = do max(x, x^2) via a map. 492 ComputationBuilder b(client_, "max_with_square"); 493 auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); 494 b.Map({x, b.Mul(x, x)}, CreateMax(), {}); 495 auto computation_status = b.Build(); 496 ASSERT_IS_OK(computation_status.status()); 497 max_with_square = computation_status.ConsumeValueOrDie(); 498 } 499 ComputationBuilder b(client_, TestName()); 500 auto input = b.ConstantR1<float>({0.1f, 0.5f, -0.5f, 1.0f, 2.0f}); 501 b.Map({input}, max_with_square, {0}); 502 ComputeAndCompareR1<float>(&b, {0.1f, 0.5f, 0.25f, 1.0f, 4.0f}, {}); 503 } 504 505 TEST_F(MapTest, MapOperantionWithBuildError) { 506 // Maps (lambda (x y) (+ x y)) onto two R1F32 vectors but uses an unsupported 507 // type combination (F32 + U16) to test that the error is reported to the 508 // outermost ComputationBuilder. 509 ComputationBuilder builder(client_, TestName()); 510 511 auto sub_builder = builder.CreateSubBuilder("ErrorAdd"); 512 auto x = sub_builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); 513 auto y = sub_builder->Parameter(1, ShapeUtil::MakeShape(U16, {}), "y"); 514 auto adder = sub_builder->Add(x, y); 515 auto error_add = sub_builder->BuildAndNoteError(); 516 517 std::unique_ptr<Literal> param0_literal = 518 Literal::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f}); 519 std::unique_ptr<GlobalData> param0_data = 520 client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); 521 std::unique_ptr<Literal> param1_literal = 522 Literal::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f}); 523 std::unique_ptr<GlobalData> param1_data = 524 client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); 525 526 auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); 527 auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); 528 auto map = builder.Map({param0, param1}, error_add, {0}); 529 530 StatusOr<Computation> computation_status = builder.Build(); 531 ASSERT_TRUE(!computation_status.ok()); 532 EXPECT_THAT( 533 computation_status.status().ToString(), 534 ::testing::HasSubstr("error from: ErrorAdd: binary op BINOP_ADD with " 535 "different element types: f32[] and u16[]")); 536 } 537 538 // MapTest disables inline and algsimp. MapTestWithFullOpt runs all 539 // optimizations. 540 using MapTestWithFullOpt = ClientLibraryTestBase; 541 542 // Regression test for b/31466798. The inliner simplifies map(param0, param1, 543 // power) to power(param0, param1) without deleting the old subcomputation which 544 // is the same as the new entry computation. HloSubcomputationUnification used 545 // to have issues with such patterns and maybe invalidate the pointer to entry 546 // computation. 547 TEST_F(MapTestWithFullOpt, MapScalarPower) { 548 ComputationBuilder builder(client_, TestName()); 549 550 auto sub_builder = builder.CreateSubBuilder("power"); 551 auto x = sub_builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); 552 auto y = sub_builder->Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); 553 sub_builder->Pow(x, y); 554 auto power = sub_builder->BuildAndNoteError(); 555 556 std::unique_ptr<Literal> param0_literal = Literal::CreateR0<float>(2.0f); 557 std::unique_ptr<Literal> param1_literal = Literal::CreateR0<float>(5.0f); 558 std::unique_ptr<GlobalData> param0_data = 559 client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); 560 std::unique_ptr<GlobalData> param1_data = 561 client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); 562 563 auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); 564 auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); 565 builder.Map({param0, param1}, power, {}); 566 567 ComputeAndCompareR0<float>(&builder, 32.0f, 568 {param0_data.get(), param1_data.get()}, 569 ErrorSpec(0.01f)); 570 } 571 572 // Regression test for b/35786417, where the inliner would not notice the change 573 // of parameter order inside the map. 574 TEST_F(MapTestWithFullOpt, MapSubtractOppositeOrder) { 575 ComputationBuilder builder(client_, TestName()); 576 577 auto sub_builder = builder.CreateSubBuilder("power"); 578 auto x = sub_builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); 579 auto y = sub_builder->Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); 580 sub_builder->Sub(y, x); // note that this is y - x, not x - y 581 auto sub_opposite = sub_builder->BuildAndNoteError(); 582 583 std::unique_ptr<Literal> param0_literal = Literal::CreateR0<float>(2.0f); 584 std::unique_ptr<Literal> param1_literal = Literal::CreateR0<float>(5.0f); 585 std::unique_ptr<GlobalData> param0_data = 586 client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); 587 std::unique_ptr<GlobalData> param1_data = 588 client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); 589 590 auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); 591 auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); 592 builder.Map({param0, param1}, sub_opposite, {}); 593 594 ComputeAndCompareR0<float>( 595 &builder, 3.0f, {param0_data.get(), param1_data.get()}, ErrorSpec(0.01f)); 596 } 597 598 // Regression test for b/35786417, where the inliner would CHECK-fail due to the 599 // mul inside the map having more parameters than the map does. 600 TEST_F(MapTestWithFullOpt, MapSquare) { 601 ComputationBuilder builder(client_, TestName()); 602 603 auto sub_builder = builder.CreateSubBuilder("power"); 604 auto x = sub_builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); 605 sub_builder->Mul(x, x); 606 auto square = sub_builder->BuildAndNoteError(); 607 608 std::unique_ptr<Literal> param0_literal = Literal::CreateR0<float>(10.0f); 609 std::unique_ptr<GlobalData> param0_data = 610 client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); 611 612 auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); 613 builder.Map({param0}, square, {}); 614 615 ComputeAndCompareR0<float>(&builder, 100.0f, {param0_data.get()}, 616 ErrorSpec(0.01f)); 617 } 618 619 } // namespace 620 } // namespace xla 621