Home | History | Annotate | Download | only in tests
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
#include <cmath>
#include <functional>
#include <limits>
#include <vector>

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
     27 
     28 // Tests the handling of the basic mathematics operations with F16 operands.
     29 
     30 namespace xla {
     31 namespace {
     32 
     33 class HalfTestBase : public ClientLibraryTestBase {
     34  protected:
     35   const ErrorSpec error_spec_{0.001, 0.001};
     36   // Number of elements in the input buffers.
     37   static const int kNumElements = 4;
     38 };
     39 
     40 using UnaryBuildFuncTy = std::function<void(const xla::XlaOp& src)>;
     41 
     42 struct UnaryOpTestParam {
     43   std::function<half(half)> compute_func;
     44   UnaryBuildFuncTy build_func;
     45 };
     46 
     47 class UnaryOpTest : public HalfTestBase,
     48                     public ::testing::WithParamInterface<UnaryOpTestParam> {};
     49 
     50 XLA_TEST_P(UnaryOpTest, Ops) {
     51   std::vector<half> x({half(1.4), half(-2.3), half(3.2), half(-4.1), half(9.0),
     52                        half(42.0), half(-9.0), half(-100.0)});
     53   XlaBuilder builder(TestName());
     54   XlaOp x_opnd;
     55   auto x_data = CreateR1Parameter<half>(x, /*parameter_number=*/0, "x",
     56                                         &builder, &x_opnd);
     57 
     58   std::function<half(half)> compute_func = GetParam().compute_func;
     59   std::vector<half> expected;
     60   for (int64 i = 0; i < x.size(); ++i) {
     61     expected.push_back(compute_func(x[i]));
     62   }
     63 
     64   UnaryBuildFuncTy build_func = GetParam().build_func;
     65   build_func(x_opnd);
     66 
     67   ComputeAndCompareR1<half>(&builder, expected, {x_data.get()}, error_spec_);
     68 }
     69 
     70 half sign_imp(half value) {
     71   const float x(std::move(value));
     72   return half((x < .0) ? -1 : (x > .0));
     73 }
     74 
     75 half round_imp(half value) {
     76   return half(std::round(static_cast<float>(std::move(value))));
     77 }
     78 
     79 INSTANTIATE_TEST_CASE_P(
     80     half, UnaryOpTest,
     81     ::testing::Values(
     82         UnaryOpTestParam{[](half x) { return abs(x); }, &Abs},
     83         UnaryOpTestParam{[](half x) { return round_imp(x); }, &Round},
     84         UnaryOpTestParam{[](half x) { return ceil(x); }, &Ceil},
     85         UnaryOpTestParam{[](half x) { return cos(x); }, &Cos},
     86         UnaryOpTestParam{[](half x) { return exp(x); }, &Exp},
     87         UnaryOpTestParam{[](half x) { return floor(x); }, &Floor},
     88         UnaryOpTestParam{[](half x) { return log(x); }, &Log},
     89         UnaryOpTestParam{[](half x) { return -x; }, &Neg},
     90         UnaryOpTestParam{[](half x) { return sign_imp(x); }, &Sign},
     91         UnaryOpTestParam{[](half x) { return sin(x); }, &Sin},
     92         UnaryOpTestParam{[](half x) { return tanh(x); }, &Tanh}
     93 
     94         ));
     95 
     96 struct UnaryPredTestParam {
     97   std::function<bool(half)> compute_func;
     98   UnaryBuildFuncTy build_func;
     99 };
    100 
    101 class UnaryPredTest : public HalfTestBase,
    102                       public ::testing::WithParamInterface<UnaryPredTestParam> {
    103 };
    104 
    105 XLA_TEST_P(UnaryPredTest, Ops) {
    106   std::vector<half> x({half(1.4), half(-2.3), half(3.2), half(-4.1)});
    107   XlaBuilder builder(TestName());
    108   XlaOp x_opnd;
    109   auto x_data = CreateR1Parameter<half>(x, /*parameter_number=*/0, "x",
    110                                         &builder, &x_opnd);
    111 
    112   std::function<bool(half)> compute_func = GetParam().compute_func;
    113   CHECK_EQ(kNumElements, x.size());
    114   bool expected[kNumElements];
    115   for (int64 i = 0; i < x.size(); ++i) {
    116     expected[i] = compute_func(x[i]);
    117   }
    118 
    119   UnaryBuildFuncTy build_func = GetParam().build_func;
    120   build_func(x_opnd);
    121 
    122   ComputeAndCompareR1<bool>(&builder, expected, {x_data.get()});
    123 }
    124 
    125 INSTANTIATE_TEST_CASE_P(half, UnaryPredTest,
    126                         ::testing::Values(UnaryPredTestParam{
    127                             [](half x) { return isfinite(x); }, &IsFinite}));
    128 
    129 using BinaryBuildFuncTy = std::function<void(
    130     const xla::XlaOp& x, const xla::XlaOp& y, absl::Span<const int64>)>;
    131 
    132 struct BinaryOpTestParam {
    133   std::function<half(half, half)> compute_func;
    134   BinaryBuildFuncTy build_func;
    135 };
    136 
    137 class BinaryOpTest : public HalfTestBase,
    138                      public ::testing::WithParamInterface<BinaryOpTestParam> {};
    139 
    140 XLA_TEST_P(BinaryOpTest, Ops) {
    141   std::vector<half> x({half(1.0), half(2.0), half(3.0), half(-4.0)});
    142   std::vector<half> y({half(0.4), half(-0.3), half(0.2), half(0.1)});
    143   XlaBuilder builder(TestName());
    144   XlaOp x_opnd;
    145   auto x_data = CreateR1Parameter<half>(x, /*parameter_number=*/0, "x",
    146                                         &builder, &x_opnd);
    147 
    148   XlaOp y_opnd;
    149   auto y_data = CreateR1Parameter<half>(y, /*parameter_number=*/1, "y",
    150                                         &builder, &y_opnd);
    151 
    152   std::function<half(half, half)> compute_func = GetParam().compute_func;
    153   std::vector<half> expected;
    154   for (int64 i = 0; i < x.size(); ++i) {
    155     expected.push_back(compute_func(x[i], y[i]));
    156   }
    157 
    158   BinaryBuildFuncTy build_func = GetParam().build_func;
    159   build_func(x_opnd, y_opnd, {});
    160 
    161   ComputeAndCompareR1<half>(&builder, expected, {x_data.get(), y_data.get()},
    162                             error_spec_);
    163 }
    164 
    165 half atan2_imp(half x, half y) {
    166   return half(atan2(static_cast<float>(std::move(x)),
    167                     static_cast<float>(std::move(y))));
    168 }
    169 
    170 INSTANTIATE_TEST_CASE_P(
    171     half, BinaryOpTest,
    172     ::testing::Values(
    173         BinaryOpTestParam{[](half x, half y) { return x + y; }, &Add},
    174         BinaryOpTestParam{[](half x, half y) { return atan2_imp(x, y); },
    175                           &Atan2},
    176         BinaryOpTestParam{[](half x, half y) { return x / y; }, &Div},
    177         BinaryOpTestParam{[](half x, half y) { return max(x, y); }, &Max},
    178         BinaryOpTestParam{[](half x, half y) { return min(x, y); }, &Min},
    179         BinaryOpTestParam{[](half x, half y) { return x * y; }, &Mul},
    180         BinaryOpTestParam{[](half x, half y) { return pow(x, y); }, &Pow},
    181         BinaryOpTestParam{[](half x, half y) { return x - y; }, &Sub}
    182 
    183         ));
    184 
    185 struct BinaryPredTestParam {
    186   std::function<bool(half, half)> compute_func;
    187   BinaryBuildFuncTy build_func;
    188 };
    189 
    190 class BinaryPredTest
    191     : public HalfTestBase,
    192       public ::testing::WithParamInterface<BinaryPredTestParam> {};
    193 
    194 XLA_TEST_P(BinaryPredTest, Ops) {
    195   std::vector<half> x({half(1.0), half(2.0), half(0.2), half(-4.0)});
    196   std::vector<half> y({half(0.4), half(-0.3), half(0.2), half(0.1)});
    197   XlaBuilder builder(TestName());
    198   XlaOp x_opnd;
    199   auto x_data = CreateR1Parameter<half>(x, /*parameter_number=*/0, "x",
    200                                         &builder, &x_opnd);
    201 
    202   XlaOp y_opnd;
    203   auto y_data = CreateR1Parameter<half>(y, /*parameter_number=*/1, "y",
    204                                         &builder, &y_opnd);
    205 
    206   std::function<bool(half, half)> compute_func = GetParam().compute_func;
    207   CHECK_EQ(kNumElements, x.size());
    208   bool expected[kNumElements];
    209   for (int64 i = 0; i < x.size(); ++i) {
    210     expected[i] = compute_func(x[i], y[i]);
    211   }
    212 
    213   BinaryBuildFuncTy build_func = GetParam().build_func;
    214   build_func(x_opnd, y_opnd, {});
    215 
    216   ComputeAndCompareR1<bool>(&builder, expected, {x_data.get(), y_data.get()});
    217 }
    218 
    219 INSTANTIATE_TEST_CASE_P(
    220     half, BinaryPredTest,
    221     ::testing::Values(
    222         BinaryPredTestParam{[](half x, half y) { return x == y; }, &Eq},
    223         BinaryPredTestParam{[](half x, half y) { return x != y; }, &Ne},
    224         BinaryPredTestParam{[](half x, half y) { return x >= y; }, &Ge},
    225         BinaryPredTestParam{[](half x, half y) { return x > y; }, &Gt},
    226         BinaryPredTestParam{[](half x, half y) { return x <= y; }, &Le},
    227         BinaryPredTestParam{[](half x, half y) { return x < y; }, &Lt}
    228 
    229         ));
    230 
    231 }  // namespace
    232 }  // namespace xla
    233