Home | History | Annotate | Download | only in tests
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include <memory>
     17 
     18 #include "tensorflow/compiler/xla/array2d.h"
     19 #include "tensorflow/compiler/xla/client/computation.h"
     20 #include "tensorflow/compiler/xla/client/computation_builder.h"
     21 #include "tensorflow/compiler/xla/client/global_data.h"
     22 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
     23 #include "tensorflow/compiler/xla/client/local_client.h"
     24 #include "tensorflow/compiler/xla/literal_util.h"
     25 #include "tensorflow/compiler/xla/shape_util.h"
     26 #include "tensorflow/compiler/xla/statusor.h"
     27 #include "tensorflow/compiler/xla/test.h"
     28 #include "tensorflow/compiler/xla/test_helpers.h"
     29 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
     30 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
     31 #include "tensorflow/compiler/xla/tests/test_macros.h"
     32 #include "tensorflow/compiler/xla/tests/test_utils.h"
     33 #include "tensorflow/compiler/xla/xla_data.pb.h"
     34 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
     35 #include "tensorflow/core/platform/types.h"
     36 
     37 namespace xla {
     38 namespace {
     39 
     40 class MapTest : public ClientLibraryTestBase {
     41  public:
     42   explicit MapTest(perftools::gputools::Platform* platform = nullptr)
     43       : ClientLibraryTestBase(platform) {
     44     mutable_debug_options()->add_xla_disable_hlo_passes("algsimp");
     45     mutable_debug_options()->add_xla_disable_hlo_passes("inline");
     46   }
     47 
     48   // Creates a function that adds its scalar argument with the constant 1.0.
     49   //
     50   // x {R0F32} ----> (add)
     51   //                /
     52   // 1.0f ---------/
     53   Computation CreateAdderToOne() {
     54     ComputationBuilder mapped_builder(client_, TestName());
     55     auto x = mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
     56     auto one = mapped_builder.ConstantR0<float>(1.0);
     57     auto adder_to_one = mapped_builder.Add(x, one);
     58     auto computation_status = mapped_builder.Build();
     59     TF_CHECK_OK(computation_status.status());
     60     return computation_status.ConsumeValueOrDie();
     61   }
     62 
     63   Computation CreateMax() {
     64     ComputationBuilder b(client_, TestName());
     65     auto lhs = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
     66     auto rhs = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
     67     b.Max(lhs, rhs);
     68     auto computation_status = b.Build();
     69     TF_CHECK_OK(computation_status.status());
     70     return computation_status.ConsumeValueOrDie();
     71   }
     72 
     73   // Creates a computation that accepts an F32 and returns T(1) (ignoring the
     74   // argument).
     75   template <class T>
     76   Computation CreateScalarOne() {
     77     ComputationBuilder mapped_builder(client_, "scalar_one");
     78     (void)mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
     79     mapped_builder.ConstantR0<T>(1);
     80     auto computation_status = mapped_builder.Build();
     81     TF_CHECK_OK(computation_status.status());
     82     return computation_status.ConsumeValueOrDie();
     83   }
     84 
     85   // Creates a function that multiplies its scalar argument by the constant 2.0
     86   //
     87   // x {R0F32} ----> (mul)
     88   //                /
     89   // 2.0f ---------/
     90   Computation CreateMulByTwo() {
     91     ComputationBuilder mapped_builder(client_, TestName());
     92     auto x = mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
     93     auto two = mapped_builder.ConstantR0<float>(2.0);
     94     auto mul_by_two = mapped_builder.Mul(x, two);
     95     auto computation_status = mapped_builder.Build();
     96     TF_CHECK_OK(computation_status.status());
     97     return computation_status.ConsumeValueOrDie();
     98   }
     99 
    100   // Creates a function that adds its scalar argument with the constant 1.0 and
    101   // then multiplies by the original element.
    102   //
    103   //           /------------------|
    104   //          /                   |
    105   // x {R0F32} ----> (add) ----> (mul)
    106   //                /
    107   // 1.0f ---------/
    108   Computation CreateAdderToOneTimesItself() {
    109     ComputationBuilder mapped_builder(client_, TestName());
    110     auto x = mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
    111     auto one = mapped_builder.ConstantR0<float>(1.0);
    112     auto adder_to_one = mapped_builder.Add(x, one);
    113     auto result = mapped_builder.Mul(x, adder_to_one);
    114     auto computation_status = mapped_builder.Build();
    115     TF_CHECK_OK(computation_status.status());
    116     return computation_status.ConsumeValueOrDie();
    117   }
    118 
    119   // Creates a function that takes a single parameter and calls map with
    120   // "embedded_computation" on it, and then adds "n" to the result.
    121   //
    122   // x {R0F32} -----------> (map) ----> (add)
    123   //                         /           /
    124   // embedded_computation --/       n --/
    125   Computation CreateMapPlusN(const Computation& embedded_computation, float n) {
    126     ComputationBuilder builder(client_, TestName());
    127     auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
    128     auto map = builder.Map({x}, embedded_computation, {});
    129     auto constant_n = builder.ConstantR0<float>(n);
    130     auto add = builder.Add(map, constant_n);
    131     auto computation_status = builder.Build();
    132     TF_CHECK_OK(computation_status.status());
    133     return computation_status.ConsumeValueOrDie();
    134   }
    135 
    136   // Creates a binary function with signature (F32, F32) -> Pred
    137   // defined by (x, y) -> x > y.
    138   Computation CreateGt() {
    139     ComputationBuilder b(client_, "Gt");
    140     auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
    141     auto y = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
    142     auto gt = b.Gt(x, y);
    143     auto computation_status = b.Build();
    144     TF_CHECK_OK(computation_status.status());
    145     return computation_status.ConsumeValueOrDie();
    146   }
    147 
    148   // Creates a function that adds three scalar arguments
    149   //
    150   // x {R0F32} -------|
    151   //                  |
    152   // y {R0F32} ----> (add) ---> (add)
    153   //                           /
    154   // z {R0F32} ---------------/
    155   Computation CreateTernaryAdder() {
    156     ComputationBuilder mapped_builder(client_, "TernaryAdder");
    157     auto x = mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
    158     auto y = mapped_builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
    159     auto z = mapped_builder.Parameter(2, ShapeUtil::MakeShape(F32, {}), "z");
    160     auto xy = mapped_builder.Add(x, y);
    161     auto xyz = mapped_builder.Add(xy, z);
    162     auto computation_status = mapped_builder.Build();
    163     TF_CHECK_OK(computation_status.status());
    164     return computation_status.ConsumeValueOrDie();
    165   }
    166 };
    167 
    168 TEST_F(MapTest, MapEachElemPlusOneR0) {
    169   // Applies lambda (x) (+ x 1)) to an input scalar.
    170   ComputationBuilder builder(client_, TestName());
    171   std::unique_ptr<Literal> param0_literal = Literal::CreateR0<float>(42.0);
    172   std::unique_ptr<GlobalData> param0_data =
    173       client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
    174 
    175   auto param = builder.Parameter(0, param0_literal->shape(), "param0");
    176   auto map = builder.Map({param}, CreateAdderToOne(), {});
    177 
    178   ComputeAndCompareR0<float>(&builder, 43.0, {param0_data.get()},
    179                              ErrorSpec(0.01f));
    180 }
    181 
    182 XLA_TEST_F(MapTest, MapEachElemPlusOneR1S0) {
    183   // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0.
    184   ComputationBuilder builder(client_, TestName());
    185   std::unique_ptr<Literal> param0_literal = Literal::CreateR1<float>({});
    186   std::unique_ptr<GlobalData> param0_data =
    187       client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
    188 
    189   auto param = builder.Parameter(0, param0_literal->shape(), "param0");
    190   auto map = builder.Map({param}, CreateAdderToOne(), {0});
    191 
    192   ComputeAndCompareR1<float>(&builder, {}, {param0_data.get()},
    193                              ErrorSpec(0.01f));
    194 }
    195 
    196 TEST_F(MapTest, MapEachElemPlusOneR1S4) {
    197   // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 4.
    198   ComputationBuilder builder(client_, TestName());
    199   std::unique_ptr<Literal> param0_literal =
    200       Literal::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
    201   std::unique_ptr<GlobalData> param0_data =
    202       client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
    203 
    204   auto param = builder.Parameter(0, param0_literal->shape(), "param0");
    205   auto map = builder.Map({param}, CreateAdderToOne(), {0});
    206 
    207   ComputeAndCompareR1<float>(&builder, {3.2f, 4.3f, 5.4f, 6.5f},
    208                              {param0_data.get()}, ErrorSpec(0.01f));
    209 }
    210 
    211 TEST_F(MapTest, MapEachF32ElementToS32Constant) {
    212   ComputationBuilder builder(client_, TestName());
    213   std::unique_ptr<Literal> param0_literal =
    214       Literal::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
    215   std::unique_ptr<GlobalData> param0_data =
    216       client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
    217 
    218   auto param = builder.Parameter(0, param0_literal->shape(), "param0");
    219   auto map = builder.Map({param}, CreateScalarOne<int32>(), {0});
    220 
    221   ComputeAndCompareR1<int32>(&builder, {1, 1, 1, 1}, {param0_data.get()});
    222 }
    223 
    224 TEST_F(MapTest, MapEachF32ElementToU32Constant) {
    225   ComputationBuilder builder(client_, TestName());
    226   std::unique_ptr<Literal> param0_literal =
    227       Literal::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
    228   std::unique_ptr<GlobalData> param0_data =
    229       client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
    230 
    231   auto param = builder.Parameter(0, param0_literal->shape(), "param0");
    232   auto map = builder.Map({param}, CreateScalarOne<uint32>(), {0});
    233 
    234   ComputeAndCompareR1<uint32>(&builder, {1, 1, 1, 1}, {param0_data.get()});
    235 }
    236 
    237 TEST_F(MapTest, MapEachElemLongerChainR1) {
    238   // Maps (lambda (x) (* (+ x 1) x)) onto an input R1F32 vector.
    239   ComputationBuilder builder(client_, TestName());
    240   std::unique_ptr<Literal> param0_literal =
    241       Literal::CreateR1<float>({2.6f, -5.1f, 0.1f, 0.2f, 999.0f, 255.5f});
    242   std::unique_ptr<GlobalData> param0_data =
    243       client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
    244 
    245   auto param = builder.Parameter(0, param0_literal->shape(), "param0");
    246   auto map = builder.Map({param}, CreateAdderToOneTimesItself(), {0});
    247 
    248   ComputeAndCompareR1<float>(
    249       &builder, {9.36f, 20.91f, 0.11f, 0.24f, 999000.0f, 65535.75f},
    250       {param0_data.get()}, ErrorSpec(0.01f));
    251 }
    252 
    253 XLA_TEST_F(MapTest, MapMultipleMapsR1S0) {
    254   // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0, and then
    255   // maps (lambda (x) (* x 2)) on the result.
    256   ComputationBuilder builder(client_, TestName());
    257   std::unique_ptr<Literal> param0_literal = Literal::CreateR1<float>({});
    258   std::unique_ptr<GlobalData> param0_data =
    259       client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
    260 
    261   auto param = builder.Parameter(0, param0_literal->shape(), "param0");
    262   auto map1 = builder.Map({param}, CreateAdderToOne(), {0});
    263   auto map2 = builder.Map({map1}, CreateMulByTwo(), {0});
    264 
    265   ComputeAndCompareR1<float>(&builder, {}, {param0_data.get()},
    266                              ErrorSpec(0.01f));
    267 }
    268 
    269 TEST_F(MapTest, MapMultipleMapsR1S4) {
    270   // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 4, and then
    271   // maps (lambda (x) (* x 2)) on the result.
    272   ComputationBuilder builder(client_, TestName());
    273   std::unique_ptr<Literal> param0_literal =
    274       Literal::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
    275   std::unique_ptr<GlobalData> param0_data =
    276       client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
    277 
    278   auto param = builder.Parameter(0, param0_literal->shape(), "param0");
    279   auto map1 = builder.Map({param}, CreateAdderToOne(), {0});
    280   auto map2 = builder.Map({map1}, CreateMulByTwo(), {0});
    281 
    282   ComputeAndCompareR1<float>(&builder, {6.4f, 8.6f, 10.8f, 13.0f},
    283                              {param0_data.get()}, ErrorSpec(0.01f));
    284 }
    285 
    286 TEST_F(MapTest, MapEachElemPlusOneR2) {
    287   // Maps (lambda (x) (+ x 1)) onto an input R2F32 vector.
    288   ComputationBuilder builder(client_, TestName());
    289   std::unique_ptr<Literal> param0_literal = Literal::CreateR2<float>(
    290       {{13.25f, 14.0f}, {-7.1f, -7.2f}, {-8.8f, 8.8f}});
    291   std::unique_ptr<GlobalData> param0_data =
    292       client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
    293 
    294   auto param = builder.Parameter(0, param0_literal->shape(), "param0");
    295   auto map = builder.Map({param}, CreateAdderToOne(), {0, 1});
    296 
    297   Array2D<float> expected_array(
    298       {{14.25f, 15.0f}, {-6.1f, -6.2f}, {-7.8f, 9.8f}});
    299   ComputeAndCompareR2<float>(&builder, expected_array, {param0_data.get()},
    300                              ErrorSpec(0.01f));
    301 }
    302 
    303 XLA_TEST_F(MapTest, ComplexNestedMaps) {
    304   // Constructs a complex graph of embedded computations to test the computation
    305   // lowering order. Python equivalent:
    306   //
    307   //   embed1 = lambda x: x + 1                  #  x + 1
    308   //   embed2 = lambda x: embed1(x) + 2          #  x + 3
    309   //   embed3 = lambda x: embed1(x) + 4          #  x + 5
    310   //   embed4 = lambda x: embed2(x) + embed3(x)  # 2x + 8
    311   //   embed5 = lambda x: embed2(x) + 6          #  x + 9
    312   //   result = embed5(42) + embed4(7)           # (42 + 9) + (2 * 7 + 8) = 73
    313 
    314   Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
    315 
    316   auto embed1 = CreateAdderToOne();
    317   auto embed2 = CreateMapPlusN(embed1, 2.0);
    318   auto embed3 = CreateMapPlusN(embed1, 4.0);
    319 
    320   ComputationBuilder embed4_builder(client_, "embed4");
    321   auto embed4_param = embed4_builder.Parameter(0, scalar_shape, "x");
    322   auto embed4_map_lhs = embed4_builder.Map({embed4_param}, embed2, {});
    323   auto embed4_map_rhs = embed4_builder.Map({embed4_param}, embed3, {});
    324   auto embed4_add = embed4_builder.Add(embed4_map_lhs, embed4_map_rhs);
    325   auto embed4_status = embed4_builder.Build();
    326   ASSERT_IS_OK(embed4_status.status());
    327   auto embed4 = embed4_status.ConsumeValueOrDie();
    328 
    329   auto embed5 = CreateMapPlusN(embed2, 6.0);
    330 
    331   ComputationBuilder builder(client_, TestName());
    332   auto constant_42 = builder.ConstantR0<float>(42.0);
    333   auto constant_7 = builder.ConstantR0<float>(7.0);
    334   auto map_42 = builder.Map({constant_42}, embed5, {});
    335   auto map_7 = builder.Map({constant_7}, embed4, {});
    336   builder.Add(map_42, map_7);
    337 
    338   ComputeAndCompareR0<float>(&builder, 73.0, {}, ErrorSpec(0.01f));
    339 }
    340 
    341 TEST_F(MapTest, VersionedEmbeddedComputation) {
    342   // Build a computation X, use it in a map, then add an additional operation to
    343   // computation X and use it again in a different map. Verify that the proper
    344   // versions of computation X are used in each of the maps.
    345 
    346   // Create a (embedded) computation which adds one to its parameter argument.
    347   ComputationBuilder embedded_builder(client_, "EmbeddedComputation");
    348   auto param_0 =
    349       embedded_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param0");
    350   auto constant_one = embedded_builder.ConstantR0<float>(1.0);
    351   auto adder_to_one = embedded_builder.Add(param_0, constant_one);
    352   auto computation_status = embedded_builder.Build();
    353   ASSERT_IS_OK(computation_status.status());
    354   auto embedded_computation = computation_status.ConsumeValueOrDie();
    355 
    356   ComputationBuilder builder(client_, TestName());
    357   auto constant_vector = builder.ConstantR1<float>({1.0, 2.0, 3.0, 4.0});
    358   auto map_plus_1 = builder.Map({constant_vector}, embedded_computation, {0});
    359 
    360   // Add another Add(1) operation to the existing embedded computation. This
    361   // requires using the stub interface because the ComputationBuilder does not
    362   // allow modification to the Computation objects after they have been built.
    363   BinaryOpRequest request;
    364   request.set_binop(BINOP_ADD);
    365   *request.mutable_lhs() = adder_to_one;
    366   *request.mutable_rhs() = constant_one;
    367   OpRequest op_request;
    368   *op_request.mutable_computation() = embedded_computation.handle();
    369   *op_request.mutable_binary_op_request() = request;
    370   OpResponse response;
    371   tensorflow::Status s = client_->stub()->Op(&op_request, &response);
    372   ASSERT_TRUE(s.ok());
    373 
    374   auto map_plus_2 = builder.Map({map_plus_1}, embedded_computation, {0});
    375 
    376   // The original vector has Add(1) applied to it with a map, followed by
    377   // Add(1+1) resulting in a net Add(3).
    378   ComputeAndCompareR1<float>(&builder, {4.0, 5.0, 6.0, 7.0}, {},
    379                              ErrorSpec(0.01f));
    380 }
    381 
    382 TEST_F(MapTest, MapBinaryAdder) {
    383   // Maps (lambda (x y) (+ x y)) onto two R1F32 vectors.
    384   ComputationBuilder builder(client_, TestName());
    385   std::unique_ptr<Literal> param0_literal =
    386       Literal::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
    387   std::unique_ptr<GlobalData> param0_data =
    388       client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
    389   std::unique_ptr<Literal> param1_literal =
    390       Literal::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
    391   std::unique_ptr<GlobalData> param1_data =
    392       client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
    393 
    394   auto param0 = builder.Parameter(0, param0_literal->shape(), "param0");
    395   auto param1 = builder.Parameter(1, param1_literal->shape(), "param1");
    396   auto map = builder.Map({param0, param1},
    397                          CreateScalarAddComputation(F32, &builder), {0});
    398 
    399   ComputeAndCompareR1<float>(&builder, {7.3f, 7.7, 4.3f, 0},
    400                              {param0_data.get(), param1_data.get()},
    401                              ErrorSpec(0.01f));
    402 }
    403 
    404 // Adds two rank-2 arrays with different layouts. This test exercises a path
    405 // for Map that used to fail in shape inference (b/28989438).
    406 XLA_TEST_F(MapTest, AddWithMixedLayouts) {
    407   ComputationBuilder builder(client_, TestName());
    408   std::unique_ptr<Literal> param0_literal = Literal::CreateR2WithLayout(
    409       {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({1, 0}));
    410   std::unique_ptr<GlobalData> param0_data =
    411       client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
    412 
    413   std::unique_ptr<Literal> param1_literal = Literal::CreateR2WithLayout(
    414       {{10, 20}, {30, 40}}, LayoutUtil::MakeLayout({0, 1}));
    415   std::unique_ptr<GlobalData> param1_data =
    416       client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
    417 
    418   auto param0 = builder.Parameter(0, param0_literal->shape(), "param0");
    419   auto param1 = builder.Parameter(1, param1_literal->shape(), "param1");
    420   auto map = builder.Map({param0, param1},
    421                          CreateScalarAddComputation(S32, &builder), {0, 1});
    422 
    423   Array2D<int32> expected(2, 2);
    424   expected(0, 0) = 11;
    425   expected(0, 1) = 22;
    426   expected(1, 0) = 33;
    427   expected(1, 1) = 44;
    428   ComputeAndCompareR2<int32>(&builder, expected,
    429                              {param0_data.get(), param1_data.get()});
    430 }
    431 
    432 XLA_TEST_F(MapTest, AddR3_3x0x2) {
    433   ComputationBuilder builder(client_, TestName());
    434   std::unique_ptr<Literal> param0_literal =
    435       Literal::CreateR3FromArray3D<int32>(Array3D<int32>(3, 0, 2));
    436   std::unique_ptr<GlobalData> param0_data =
    437       client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
    438 
    439   std::unique_ptr<Literal> param1_literal =
    440       Literal::CreateR3FromArray3D<int32>(Array3D<int32>(3, 0, 2));
    441   std::unique_ptr<GlobalData> param1_data =
    442       client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
    443 
    444   auto param0 = builder.Parameter(0, param0_literal->shape(), "param0");
    445   auto param1 = builder.Parameter(1, param1_literal->shape(), "param1");
    446   auto map = builder.Map({param0, param1},
    447                          CreateScalarAddComputation(S32, &builder), {0, 1, 2});
    448 
    449   ComputeAndCompareR3<int32>(&builder, Array3D<int32>(3, 0, 2),
    450                              {param0_data.get(), param1_data.get()});
    451 }
    452 
    453 TEST_F(MapTest, MapTernaryAdder) {
    454   // Maps (lambda (x y z) (+ x y z)) onto three R1F32 vectors.
    455   ComputationBuilder builder(client_, TestName());
    456   std::unique_ptr<Literal> param0_literal =
    457       Literal::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
    458   std::unique_ptr<GlobalData> param0_data =
    459       client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
    460   std::unique_ptr<Literal> param1_literal =
    461       Literal::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
    462   std::unique_ptr<GlobalData> param1_data =
    463       client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
    464   std::unique_ptr<Literal> param2_literal =
    465       Literal::CreateR1<float>({-10.0f, -100.0f, -900.0f, -400.0f});
    466   std::unique_ptr<GlobalData> param2_data =
    467       client_->TransferToServer(*param2_literal).ConsumeValueOrDie();
    468 
    469   auto param0 = builder.Parameter(0, param0_literal->shape(), "param0");
    470   auto param1 = builder.Parameter(1, param1_literal->shape(), "param1");
    471   auto param2 = builder.Parameter(2, param2_literal->shape(), "param2");
    472   auto map = builder.Map({param0, param1, param2}, CreateTernaryAdder(), {0});
    473 
    474   ComputeAndCompareR1<float>(
    475       &builder, {-2.7f, -92.3f, -895.7f, -400.0f},
    476       {param0_data.get(), param1_data.get(), param2_data.get()},
    477       ErrorSpec(0.01f));
    478 }
    479 
    480 TEST_F(MapTest, MapGt) {
    481   // Maps (x,y) -> x > y onto two R1F32 vectors.
    482   ComputationBuilder b(client_, TestName());
    483   auto gt = CreateGt();
    484   b.Map({b.ConstantR1<float>({1, 20}), b.ConstantR1<float>({10, 2})}, gt, {0});
    485   ComputeAndCompareR1<bool>(&b, {false, true}, {});
    486 }
    487 
    488 TEST_F(MapTest, NestedBinaryMap) {
    489   Computation max_with_square;
    490   {
    491     // max_with_square(x) = do max(x, x^2) via a map.
    492     ComputationBuilder b(client_, "max_with_square");
    493     auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
    494     b.Map({x, b.Mul(x, x)}, CreateMax(), {});
    495     auto computation_status = b.Build();
    496     ASSERT_IS_OK(computation_status.status());
    497     max_with_square = computation_status.ConsumeValueOrDie();
    498   }
    499   ComputationBuilder b(client_, TestName());
    500   auto input = b.ConstantR1<float>({0.1f, 0.5f, -0.5f, 1.0f, 2.0f});
    501   b.Map({input}, max_with_square, {0});
    502   ComputeAndCompareR1<float>(&b, {0.1f, 0.5f, 0.25f, 1.0f, 4.0f}, {});
    503 }
    504 
    505 TEST_F(MapTest, MapOperantionWithBuildError) {
    506   // Maps (lambda (x y) (+ x y)) onto two R1F32 vectors but uses an unsupported
    507   // type combination (F32 + U16) to test that the error is reported to the
    508   // outermost ComputationBuilder.
    509   ComputationBuilder builder(client_, TestName());
    510 
    511   auto sub_builder = builder.CreateSubBuilder("ErrorAdd");
    512   auto x = sub_builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
    513   auto y = sub_builder->Parameter(1, ShapeUtil::MakeShape(U16, {}), "y");
    514   auto adder = sub_builder->Add(x, y);
    515   auto error_add = sub_builder->BuildAndNoteError();
    516 
    517   std::unique_ptr<Literal> param0_literal =
    518       Literal::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
    519   std::unique_ptr<GlobalData> param0_data =
    520       client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
    521   std::unique_ptr<Literal> param1_literal =
    522       Literal::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
    523   std::unique_ptr<GlobalData> param1_data =
    524       client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
    525 
    526   auto param0 = builder.Parameter(0, param0_literal->shape(), "param0");
    527   auto param1 = builder.Parameter(1, param1_literal->shape(), "param1");
    528   auto map = builder.Map({param0, param1}, error_add, {0});
    529 
    530   StatusOr<Computation> computation_status = builder.Build();
    531   ASSERT_TRUE(!computation_status.ok());
    532   EXPECT_THAT(
    533       computation_status.status().ToString(),
    534       ::testing::HasSubstr("error from: ErrorAdd: binary op BINOP_ADD with "
    535                            "different element types: f32[] and u16[]"));
    536 }
    537 
// Fixture for Map tests that run all optimizations. Unlike MapTest, it does
// not disable the "inline" and "algsimp" HLO passes, so the regression tests
// that use it exercise Map together with those passes. It is a plain alias of
// ClientLibraryTestBase.
using MapTestWithFullOpt = ClientLibraryTestBase;
    541 
    542 // Regression test for b/31466798. The inliner simplifies map(param0, param1,
    543 // power) to power(param0, param1) without deleting the old subcomputation which
    544 // is the same as the new entry computation. HloSubcomputationUnification used
    545 // to have issues with such patterns and maybe invalidate the pointer to entry
    546 // computation.
    547 TEST_F(MapTestWithFullOpt, MapScalarPower) {
    548   ComputationBuilder builder(client_, TestName());
    549 
    550   auto sub_builder = builder.CreateSubBuilder("power");
    551   auto x = sub_builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
    552   auto y = sub_builder->Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
    553   sub_builder->Pow(x, y);
    554   auto power = sub_builder->BuildAndNoteError();
    555 
    556   std::unique_ptr<Literal> param0_literal = Literal::CreateR0<float>(2.0f);
    557   std::unique_ptr<Literal> param1_literal = Literal::CreateR0<float>(5.0f);
    558   std::unique_ptr<GlobalData> param0_data =
    559       client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
    560   std::unique_ptr<GlobalData> param1_data =
    561       client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
    562 
    563   auto param0 = builder.Parameter(0, param0_literal->shape(), "param0");
    564   auto param1 = builder.Parameter(1, param1_literal->shape(), "param1");
    565   builder.Map({param0, param1}, power, {});
    566 
    567   ComputeAndCompareR0<float>(&builder, 32.0f,
    568                              {param0_data.get(), param1_data.get()},
    569                              ErrorSpec(0.01f));
    570 }
    571 
    572 // Regression test for b/35786417, where the inliner would not notice the change
    573 // of parameter order inside the map.
    574 TEST_F(MapTestWithFullOpt, MapSubtractOppositeOrder) {
    575   ComputationBuilder builder(client_, TestName());
    576 
    577   auto sub_builder = builder.CreateSubBuilder("power");
    578   auto x = sub_builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
    579   auto y = sub_builder->Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
    580   sub_builder->Sub(y, x);  // note that this is y - x, not x - y
    581   auto sub_opposite = sub_builder->BuildAndNoteError();
    582 
    583   std::unique_ptr<Literal> param0_literal = Literal::CreateR0<float>(2.0f);
    584   std::unique_ptr<Literal> param1_literal = Literal::CreateR0<float>(5.0f);
    585   std::unique_ptr<GlobalData> param0_data =
    586       client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
    587   std::unique_ptr<GlobalData> param1_data =
    588       client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
    589 
    590   auto param0 = builder.Parameter(0, param0_literal->shape(), "param0");
    591   auto param1 = builder.Parameter(1, param1_literal->shape(), "param1");
    592   builder.Map({param0, param1}, sub_opposite, {});
    593 
    594   ComputeAndCompareR0<float>(
    595       &builder, 3.0f, {param0_data.get(), param1_data.get()}, ErrorSpec(0.01f));
    596 }
    597 
    598 // Regression test for b/35786417, where the inliner would CHECK-fail due to the
    599 // mul inside the map having more parameters than the map does.
    600 TEST_F(MapTestWithFullOpt, MapSquare) {
    601   ComputationBuilder builder(client_, TestName());
    602 
    603   auto sub_builder = builder.CreateSubBuilder("power");
    604   auto x = sub_builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
    605   sub_builder->Mul(x, x);
    606   auto square = sub_builder->BuildAndNoteError();
    607 
    608   std::unique_ptr<Literal> param0_literal = Literal::CreateR0<float>(10.0f);
    609   std::unique_ptr<GlobalData> param0_data =
    610       client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
    611 
    612   auto param0 = builder.Parameter(0, param0_literal->shape(), "param0");
    613   builder.Map({param0}, square, {});
    614 
    615   ComputeAndCompareR0<float>(&builder, 100.0f, {param0_data.get()},
    616                              ErrorSpec(0.01f));
    617 }
    618 
    619 }  // namespace
    620 }  // namespace xla
    621