Home | History | Annotate | Download | only in tests
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
#include <cmath>
#include <limits>
#include <memory>
#include <vector>
     18 
     19 #include "tensorflow/compiler/xla/client/computation_builder.h"
     20 #include "tensorflow/compiler/xla/client/global_data.h"
     21 #include "tensorflow/compiler/xla/client/local_client.h"
     22 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
     23 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
     24 #include "tensorflow/compiler/xla/tests/test_macros.h"
     25 #include "tensorflow/compiler/xla/xla_data.pb.h"
     26 #include "tensorflow/core/platform/test.h"
     27 #include "tensorflow/core/platform/types.h"
     28 
     29 namespace xla {
     30 namespace {
     31 
     32 class UnaryOpTest : public ClientLibraryTestBase {
     33  protected:
     34   template <typename T>
     35   T inf() {
     36     return std::numeric_limits<T>::infinity();
     37   }
     38   template <typename T>
     39   void AbsSize0TestHelper() {
     40     ComputationBuilder builder(client_, TestName());
     41     auto arg = builder.ConstantR1<T>({});
     42     auto abs = builder.Abs(arg);
     43 
     44     if (primitive_util::NativeToPrimitiveType<T>() == C64) {
     45       ComputeAndCompareR1<float>(&builder, {}, {});
     46     } else {
     47       ComputeAndCompareR1<T>(&builder, {}, {});
     48     }
     49   }
     50 
     51   template <typename T>
     52   void AbsTestHelper() {
     53     ComputationBuilder builder(client_, TestName());
     54     auto arg = builder.ConstantR1<T>({-2, 25, 0, -123, inf<T>(), -inf<T>()});
     55     auto abs = builder.Abs(arg);
     56 
     57     ComputeAndCompareR1<T>(&builder, {2, 25, 0, 123, inf<T>(), inf<T>()}, {});
     58   }
     59 
     60   template <typename T>
     61   void SignTestHelper() {
     62     ComputationBuilder builder(client_, TestName());
     63     auto arg = builder.ConstantR1<T>(
     64         {-2, 25, 0, static_cast<T>(-0.0), -123, inf<T>(), -inf<T>()});
     65     auto sign = builder.Sign(arg);
     66 
     67     ComputeAndCompareR1<T>(&builder, {-1, 1, 0, 0, -1, 1, -1}, {});
     68   }
     69 
     70   template <typename T>
     71   void SignAbsTestHelper() {
     72     ComputationBuilder builder(client_, TestName());
     73     auto arg = builder.ConstantR1<T>({-2, 25, 0, -123});
     74     auto sign = builder.Sign(arg);
     75     auto abs = builder.Abs(arg);
     76     builder.Sub(builder.Mul(sign, abs), arg);
     77 
     78     ComputeAndCompareR1<T>(&builder, {0, 0, 0, 0}, {});
     79   }
     80 };
     81 
     82 template <>
     83 int UnaryOpTest::inf<int>() {
     84   return 2147483647;
     85 }
     86 
     87 template <>
     88 void UnaryOpTest::AbsTestHelper<complex64>() {
     89   ComputationBuilder builder(client_, TestName());
     90   auto arg = builder.ConstantR1<complex64>({{-2, 0},
     91                                             {0, 25},
     92                                             {0, 0},
     93                                             {-0.3f, 0.4f},
     94                                             {0, inf<float>()},
     95                                             {-inf<float>(), 0}});
     96   auto abs = builder.Abs(arg);
     97 
     98   std::unique_ptr<Literal> expected =
     99       Literal::CreateR1<float>({2, 25, 0, 0.5, inf<float>(), inf<float>()});
    100   ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f));
    101 }
    102 
    103 template <>
    104 void UnaryOpTest::SignTestHelper<complex64>() {
    105   ComputationBuilder builder(client_, TestName());
    106   auto arg = builder.ConstantR1<complex64>(
    107       {{-2, 0}, {0, 25}, {0, 0}, {static_cast<float>(-0.0), 0}, {-1, 1}});
    108   auto sign = builder.Sign(arg);
    109 
    110   std::unique_ptr<Literal> expected = Literal::CreateR1<complex64>(
    111       {{-1, 0}, {0, 1}, {0, 0}, {0, 0}, {-std::sqrt(0.5f), std::sqrt(0.5f)}});
    112   ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f));
    113 }
    114 
    115 template <>
    116 void UnaryOpTest::SignAbsTestHelper<complex64>() {
    117   ComputationBuilder builder(client_, TestName());
    118   auto arg =
    119       builder.ConstantR1<complex64>({{-2, 0}, {0, 25}, {0, 0}, {-0.4, 0.3}});
    120   auto sign = builder.Sign(arg);
    121   auto abs = builder.Abs(arg);
    122   builder.Sub(builder.Mul(sign, builder.ConvertElementType(abs, C64)), arg);
    123 
    124   std::unique_ptr<Literal> expected =
    125       Literal::CreateR1<complex64>({0, 0, 0, 0});
    126   ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f));
    127 }
    128 
    129 XLA_TEST_F(UnaryOpTest, AbsTestR1Size0) {
    130   AbsSize0TestHelper<int>();
    131   AbsSize0TestHelper<float>();
    132   AbsSize0TestHelper<complex64>();
    133 }
    134 
    135 XLA_TEST_F(UnaryOpTest, AbsTestR1) {
    136   AbsTestHelper<int>();
    137   AbsTestHelper<float>();
    138   AbsTestHelper<complex64>();
    139 }
    140 
    141 XLA_TEST_F(UnaryOpTest, AbsTestR0) {
    142   ComputationBuilder builder(client_, TestName());
    143   auto argi = builder.ConstantR0<int>(-5);
    144   auto absi = builder.Abs(argi);
    145   auto argf = builder.ConstantR0<float>(-3.0f);
    146   auto absf = builder.Abs(argf);
    147   auto argf0 = builder.ConstantR0<float>(-0.0f);
    148   auto absf0 = builder.Abs(argf0);
    149   auto argc = builder.ConstantR0<complex64>({-0.3f, 0.4f});
    150   auto absc = builder.Abs(argc);
    151   builder.Add(builder.Add(absc, absf0),
    152               builder.Add(absf, builder.ConvertElementType(absi, F32)));
    153 
    154   ComputeAndCompareR0<float>(&builder, 8.5f, {});
    155 }
    156 
    157 XLA_TEST_F(UnaryOpTest, SignTestR0) {
    158   ComputationBuilder builder(client_, TestName());
    159   auto argi = builder.ConstantR0<int>(-5);
    160   auto sgni = builder.Sign(argi);  // -1
    161   auto argf = builder.ConstantR0<float>(-4.0f);
    162   auto sgnf = builder.Sign(argf);  // -1
    163   auto argf0 = builder.ConstantR0<float>(-0.0f);
    164   auto sgnf0 = builder.Sign(argf0);  // 0
    165   auto argc = builder.ConstantR0<complex64>({-.3, .4});
    166   auto sgnc = builder.Sign(argc);  // (-.6, .8)
    167   builder.Add(sgnc, builder.ConvertElementType(
    168                         builder.Add(builder.Add(sgnf0, sgnf),
    169                                     builder.ConvertElementType(sgni, F32)),
    170                         C64));
    171 
    172   std::unique_ptr<Literal> expected =
    173       Literal::CreateR0<complex64>({-2.6f, 0.8f});
    174   ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f));
    175 }
    176 
    177 XLA_TEST_F(UnaryOpTest, SignTestR1) {
    178   SignTestHelper<int>();
    179   SignTestHelper<float>();
    180   SignTestHelper<complex64>();
    181 }
    182 
    183 XLA_TEST_F(UnaryOpTest, SignAbsTestR1) {
    184   SignAbsTestHelper<int>();
    185   SignAbsTestHelper<float>();
    186   SignAbsTestHelper<complex64>();
    187 }
    188 
    189 XLA_TEST_F(UnaryOpTest, UnsignedAbsTestR1) {
    190   ComputationBuilder builder(client_, TestName());
    191   auto arg = builder.ConstantR1<unsigned int>(
    192       {2, 25, 0, 123, std::numeric_limits<unsigned int>::max()});
    193   auto abs = builder.Abs(arg);
    194 
    195   ComputeAndCompareR1<unsigned int>(
    196       &builder, {2, 25, 0, 123, std::numeric_limits<unsigned int>::max()}, {});
    197 }
    198 
    199 XLA_TEST_F(UnaryOpTest, UnsignedSignTestR1) {
    200   ComputationBuilder builder(client_, TestName());
    201   auto arg = builder.ConstantR1<unsigned int>(
    202       {2, 25, 0, 123, std::numeric_limits<unsigned int>::max()});
    203   auto sign = builder.Sign(arg);
    204 
    205   ComputeAndCompareR1<unsigned int>(&builder, {1, 1, 0, 1, 1}, {});
    206 }
    207 
    208 XLA_TEST_F(UnaryOpTest, SignAbsTestR2) {
    209   ComputationBuilder builder(client_, TestName());
    210   auto arg = builder.ConstantR2<float>({{1.0, -2.0}, {-3.0, 4.0}});
    211   auto sign = builder.Sign(arg);
    212   auto abs = builder.Abs(arg);
    213   builder.Sub(builder.Mul(sign, abs), arg);
    214 
    215   ComputeAndCompareR2<float>(&builder, {{0, 0}, {0, 0}}, {});
    216 }
    217 
    218 XLA_TEST_F(UnaryOpTest, ConvertElementTypePredToS32) {
    219   ComputationBuilder builder(client_, TestName());
    220   auto lhs = builder.ConstantR1<int32>({0, 1});
    221   auto rhs = builder.ConstantR1<int32>({1, 1});
    222   builder.ConvertElementType(builder.Eq(lhs, rhs), S32);
    223 
    224   ComputeAndCompareR1<int32>(&builder, {0, 1}, {});
    225 }
    226 
    227 XLA_TEST_F(UnaryOpTest, ConvertElementTypePredToF32) {
    228   ComputationBuilder builder(client_, TestName());
    229   auto lhs = builder.ConstantR1<int32>({0, 1});
    230   auto rhs = builder.ConstantR1<int32>({1, 1});
    231   builder.ConvertElementType(builder.Eq(lhs, rhs), F32);
    232 
    233   ComputeAndCompareR1<float>(&builder, {0.0, 1.0}, {});
    234 }
    235 
    236 }  // namespace
    237 }  // namespace xla
    238