Home | History | Annotate | Download | only in tests
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include <memory>
     17 #include <vector>
     18 
     19 #include "tensorflow/compiler/xla/client/computation_builder.h"
     20 #include "tensorflow/compiler/xla/client/global_data.h"
     21 #include "tensorflow/compiler/xla/client/local_client.h"
     22 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
     23 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
     24 #include "tensorflow/compiler/xla/tests/test_macros.h"
     25 #include "tensorflow/compiler/xla/xla_data.pb.h"
     26 #include "tensorflow/core/platform/test.h"
     27 #include "tensorflow/core/platform/types.h"
     28 
     29 namespace xla {
     30 namespace {
     31 
     32 class SelectTest : public ClientLibraryTestBase {
     33  public:
     34   ErrorSpec error_spec_{0.0001};
     35 };
     36 
     37 TEST_F(SelectTest, SelectScalarF32True) {
     38   ComputationBuilder builder(client_, TestName());
     39   auto pred = builder.ConstantR0<bool>(true);
     40   auto on_true = builder.ConstantR0<float>(123.0f);
     41   auto on_false = builder.ConstantR0<float>(42.0f);
     42   auto result = builder.Select(pred, on_true, on_false);
     43 
     44   ComputeAndCompareR0<float>(&builder, 123.0f, {}, error_spec_);
     45 }
     46 
     47 TEST_F(SelectTest, SelectScalarS32True) {
     48   ComputationBuilder builder(client_, TestName());
     49   auto pred = builder.ConstantR0<bool>(true);
     50   auto on_true = builder.ConstantR0<int32>(-42);
     51   auto on_false = builder.ConstantR0<int32>(42);
     52   auto result = builder.Select(pred, on_true, on_false);
     53 
     54   ComputeAndCompareR0<int32>(&builder, -42, {});
     55 }
     56 
     57 TEST_F(SelectTest, SelectScalarF32False) {
     58   ComputationBuilder builder(client_, TestName());
     59   auto pred = builder.ConstantR0<bool>(false);
     60   auto on_true = builder.ConstantR0<float>(123.0f);
     61   auto on_false = builder.ConstantR0<float>(42.0f);
     62   auto result = builder.Select(pred, on_true, on_false);
     63 
     64   ComputeAndCompareR0<float>(&builder, 42.0f, {}, error_spec_);
     65 }
     66 
     67 XLA_TEST_F(SelectTest, SelectR1S0F32WithConstantR1S0PRED) {
     68   ComputationBuilder builder(client_, TestName());
     69   auto pred = builder.ConstantR1<bool>({});
     70   auto on_true = builder.ConstantR1<float>({});
     71   auto on_false = builder.ConstantR1<float>({});
     72   auto select = builder.Select(pred, on_true, on_false);
     73 
     74   ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
     75 }
     76 
     77 TEST_F(SelectTest, SelectR1F32WithConstantR1PRED) {
     78   ComputationBuilder builder(client_, TestName());
     79   auto pred = builder.ConstantR1<bool>({false, true, false, true, false});
     80   auto on_true = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
     81   auto on_false = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
     82   auto select = builder.Select(pred, on_true, on_false);
     83 
     84   ComputeAndCompareR1<float>(&builder, {10.0f, 25.5f, 1.0f, -10.0f, -6.0f}, {},
     85                              error_spec_);
     86 }
     87 
     88 XLA_TEST_F(SelectTest, SelectR1S0F32WithCmpR1S0S32s) {
     89   // Similar to SelectR1S0F32WithConstantR1S0PRED, except that the pred vector
     90   // is not a constant, but rather the result of comparing two other vectors.
     91   ComputationBuilder builder(client_, TestName());
     92   auto v1 = builder.ConstantR1<int32>({});
     93   auto v2 = builder.ConstantR1<int32>({});
     94   auto cmp = builder.Eq(v1, v2);
     95   auto on_true = builder.ConstantR1<float>({});
     96   auto on_false = builder.ConstantR1<float>({});
     97   auto select = builder.Select(cmp, on_true, on_false);
     98 
     99   ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
    100 }
    101 
    102 TEST_F(SelectTest, SelectR1F32WithCmpR1S32s) {
    103   // Similar to SelectR1F32WithConstantR1PRED, except that the pred vector is
    104   // not a constant, but rather the result of comparing two other vectors.
    105   ComputationBuilder builder(client_, TestName());
    106   auto v1 = builder.ConstantR1<int32>({1, 2, 3, 4, 5});
    107   auto v2 = builder.ConstantR1<int32>({9, 2, 9, 4, 9});
    108   auto cmp = builder.Eq(v1, v2);
    109   auto on_true = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
    110   auto on_false = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
    111   auto select = builder.Select(cmp, on_true, on_false);
    112 
    113   ComputeAndCompareR1<float>(&builder, {10.0f, 25.5f, 1.0f, -10.0f, -6.0f}, {},
    114                              error_spec_);
    115 }
    116 
    117 TEST_F(SelectTest, SelectR1F32WithCmpR1F32s) {
    118   // Similar to SelectR1F32WithCmpR1S32s, except "gt"-comparing two R1F32s.
    119   ComputationBuilder builder(client_, TestName());
    120   auto v1 = builder.ConstantR1<float>({1.0f, 2.0f, 3.0f, 4.0f, 5.0f});
    121   auto v2 = builder.ConstantR1<float>({-1.0f, -2.0f, 13.0f, 14.0f, 4.4f});
    122   auto cmp = builder.Gt(v1, v2);
    123   auto on_true = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
    124   auto on_false = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
    125   auto select = builder.Select(cmp, on_true, on_false);
    126 
    127   ComputeAndCompareR1<float>(&builder, {-2.5f, 25.5f, 1.0f, 10.0f, 6.0f}, {},
    128                              error_spec_);
    129 }
    130 
    131 TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsSmall) {
    132   // Selects among two R1F32s, which come from parameters. v1 and v2 are
    133   // compared, and selection between them happens based on a gt-comparison mask.
    134   ComputationBuilder builder(client_, TestName());
    135 
    136   ComputationDataHandle v1, v2;
    137   std::unique_ptr<GlobalData> param0_data = CreateR1Parameter<float>(
    138       {41.0f, 2.0f, 3.0f, 84.0f}, /*parameter_number=*/0, /*name=*/"v1",
    139       /*builder=*/&builder, /*data_handle=*/&v1);
    140   std::unique_ptr<GlobalData> param1_data = CreateR1Parameter<float>(
    141       {21.0f, 22.0f, 23.0f, 24.0f}, /*parameter_number=*/1, /*name=*/"v2",
    142       /*builder=*/&builder, /*data_handle=*/&v2);
    143 
    144   auto cmp = builder.Gt(v1, v2);
    145   auto select = builder.Select(cmp, v1, v2);
    146   ComputeAndCompareR1<float>(&builder, {41.0f, 22.0f, 23.0f, 84.0f},
    147                              {param0_data.get(), param1_data.get()},
    148                              error_spec_);
    149 }
    150 
    151 TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsLarge) {
    152   // Similar to SelectR1F32WithCmpR1F32sFromParamsSmall, except that the
    153   // data size passed in and out is large.
    154   ComputationBuilder builder(client_, TestName());
    155 
    156   // Number of floats in the data passed into and out of the computation.
    157   constexpr int datalen = 15 * 1000;
    158 
    159   // The inputs are initialized with a special pattern where in the first third
    160   // of the data v1[i] > v2[i] and elsewhere it's vice versa.
    161   std::vector<float> v1vec;
    162   std::vector<float> v2vec;
    163   std::vector<float> expected_vec;
    164   for (int i = 0; i < datalen; ++i) {
    165     float smaller = i;
    166     float larger = i * 2;
    167     if (i < datalen / 3) {
    168       v1vec.push_back(larger);
    169       v2vec.push_back(smaller);
    170     } else {
    171       v1vec.push_back(smaller);
    172       v2vec.push_back(larger);
    173     }
    174     expected_vec.push_back(larger);
    175   }
    176 
    177   ComputationDataHandle v1, v2;
    178   std::unique_ptr<GlobalData> param0_data =
    179       CreateR1Parameter<float>(v1vec, /*parameter_number=*/0, /*name=*/"v1",
    180                                /*builder=*/&builder, /*data_handle=*/&v1);
    181   std::unique_ptr<GlobalData> param1_data =
    182       CreateR1Parameter<float>(v2vec, /*parameter_number=*/1, /*name=*/"v2",
    183                                /*builder=*/&builder, /*data_handle=*/&v2);
    184 
    185   auto cmp = builder.Gt(v1, v2);
    186   auto select = builder.Select(cmp, v1, v2);
    187   ComputeAndCompareR1<float>(&builder, expected_vec,
    188                              {param0_data.get(), param1_data.get()},
    189                              error_spec_);
    190 }
    191 
    192 TEST_F(SelectTest, SelectR1F32WithCmpR1S32ToScalar) {
    193   // "gt"-compares a R1S32 with a S32 scalar, and uses the resulting R1PRED to
    194   // select between two R1F32s.
    195   ComputationBuilder builder(client_, TestName());
    196   auto v = builder.ConstantR1<int32>({1, -1, 2, -2});
    197   auto s = builder.ConstantR0<int32>(0);
    198   auto cmp = builder.Gt(v, s);
    199 
    200   auto on_true = builder.ConstantR1<float>({11.0f, 22.0f, 33.0f, 44.0f});
    201   auto on_false =
    202       builder.ConstantR1<float>({-111.0f, -222.0f, -333.0f, -444.0f});
    203   auto select = builder.Select(cmp, on_true, on_false);
    204 
    205   ComputeAndCompareR1<float>(&builder, {11.0f, -222.0f, 33.0f, -444.0f}, {},
    206                              error_spec_);
    207 }
    208 
    209 TEST_F(SelectTest, SelectR1F32WithCmpR1F32ToScalar) {
    210   // "gt"-compares a R1F32 with a F32 scalar, and uses the resulting R1PRED to
    211   // select between two R1F32s.
    212   ComputationBuilder builder(client_, TestName());
    213   auto v = builder.ConstantR1<float>({1.0f, 2.0f, 3.0f, 4.0f});
    214   auto s = builder.ConstantR0<float>(2.5f);
    215   auto cmp = builder.Gt(v, s);
    216 
    217   auto on_true = builder.ConstantR1<float>({11.0f, 22.0f, 33.0f, 44.0f});
    218   auto on_false =
    219       builder.ConstantR1<float>({-111.0f, -222.0f, -333.0f, -444.0f});
    220   auto select = builder.Select(cmp, on_true, on_false);
    221 
    222   ComputeAndCompareR1<float>(&builder, {-111.0f, -222.0f, 33.0f, 44.0f}, {},
    223                              error_spec_);
    224 }
    225 
    226 XLA_TEST_F(SelectTest, SelectR1S0F32WithScalarPredicate) {
    227   for (bool which : {false, true}) {
    228     ComputationBuilder builder(client_, TestName());
    229     auto pred = builder.ConstantR0<bool>(which);
    230     auto on_true = builder.ConstantR1<float>({});
    231     auto on_false = builder.ConstantR1<float>({});
    232     auto select = builder.Select(pred, on_true, on_false);
    233 
    234     ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
    235   }
    236 }
    237 
    238 TEST_F(SelectTest, SelectR1F32WithScalarPredicateTrue) {
    239   ComputationBuilder builder(client_, TestName());
    240   auto pred = builder.ConstantR0<bool>(true);
    241   auto on_true = builder.ConstantR1<float>({-2.5f, 25.5f});
    242   auto on_false = builder.ConstantR1<float>({10.0f, 5.0f});
    243   auto select = builder.Select(pred, on_true, on_false);
    244 
    245   ComputeAndCompareR1<float>(&builder, {-2.5f, 25.5f}, {}, error_spec_);
    246 }
    247 
    248 TEST_F(SelectTest, SelectR1F32WithScalarPredicateFalse) {
    249   ComputationBuilder builder(client_, TestName());
    250   auto pred = builder.ConstantR0<bool>(false);
    251   auto on_true = builder.ConstantR1<float>({-2.5f, 25.5f});
    252   auto on_false = builder.ConstantR1<float>({10.0f, 5.0f});
    253   auto select = builder.Select(pred, on_true, on_false);
    254 
    255   ComputeAndCompareR1<float>(&builder, {10.0f, 5.0f}, {}, error_spec_);
    256 }
    257 }  // namespace
    258 }  // namespace xla
    259