1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <cmath> 17 #include <vector> 18 19 #include "tensorflow/compiler/xla/client/xla_builder.h" 20 #include "tensorflow/compiler/xla/literal.h" 21 #include "tensorflow/compiler/xla/statusor.h" 22 #include "tensorflow/compiler/xla/test.h" 23 #include "tensorflow/compiler/xla/test_helpers.h" 24 #include "tensorflow/compiler/xla/tests/client_library_test_base.h" 25 #include "tensorflow/compiler/xla/tests/test_macros.h" 26 #include "tensorflow/compiler/xla/tests/test_utils.h" 27 28 // Tests the handling of the basic mathematics operations with F16 operands. 29 30 namespace xla { 31 namespace { 32 33 class HalfTestBase : public ClientLibraryTestBase { 34 protected: 35 const ErrorSpec error_spec_{0.001, 0.001}; 36 // Number of elements in the input buffers. 37 static const int kNumElements = 4; 38 }; 39 40 using UnaryBuildFuncTy = std::function<void(const xla::XlaOp& src)>; 41 42 struct UnaryOpTestParam { 43 std::function<half(half)> compute_func; 44 UnaryBuildFuncTy build_func; 45 }; 46 47 class UnaryOpTest : public HalfTestBase, 48 public ::testing::WithParamInterface<UnaryOpTestParam> {}; 49 50 XLA_TEST_P(UnaryOpTest, Ops) { 51 std::vector<half> x({half(1.4), half(-2.3), half(3.2), half(-4.1), half(9.0), 52 half(42.0), half(-9.0), half(-100.0)}); 53 XlaBuilder builder(TestName()); 54 XlaOp x_opnd; 55 auto x_data = CreateR1Parameter<half>(x, /*parameter_number=*/0, "x", 56 &builder, &x_opnd); 57 58 std::function<half(half)> compute_func = GetParam().compute_func; 59 std::vector<half> expected; 60 for (int64 i = 0; i < x.size(); ++i) { 61 expected.push_back(compute_func(x[i])); 62 } 63 64 UnaryBuildFuncTy build_func = GetParam().build_func; 65 build_func(x_opnd); 66 67 ComputeAndCompareR1<half>(&builder, expected, {x_data.get()}, error_spec_); 68 } 69 70 half sign_imp(half value) { 71 const float x(std::move(value)); 72 return half((x < .0) ? -1 : (x > .0)); 73 } 74 75 half round_imp(half value) { 76 return half(std::round(static_cast<float>(std::move(value)))); 77 } 78 79 INSTANTIATE_TEST_CASE_P( 80 half, UnaryOpTest, 81 ::testing::Values( 82 UnaryOpTestParam{[](half x) { return abs(x); }, &Abs}, 83 UnaryOpTestParam{[](half x) { return round_imp(x); }, &Round}, 84 UnaryOpTestParam{[](half x) { return ceil(x); }, &Ceil}, 85 UnaryOpTestParam{[](half x) { return cos(x); }, &Cos}, 86 UnaryOpTestParam{[](half x) { return exp(x); }, &Exp}, 87 UnaryOpTestParam{[](half x) { return floor(x); }, &Floor}, 88 UnaryOpTestParam{[](half x) { return log(x); }, &Log}, 89 UnaryOpTestParam{[](half x) { return -x; }, &Neg}, 90 UnaryOpTestParam{[](half x) { return sign_imp(x); }, &Sign}, 91 UnaryOpTestParam{[](half x) { return sin(x); }, &Sin}, 92 UnaryOpTestParam{[](half x) { return tanh(x); }, &Tanh} 93 94 )); 95 96 struct UnaryPredTestParam { 97 std::function<bool(half)> compute_func; 98 UnaryBuildFuncTy build_func; 99 }; 100 101 class UnaryPredTest : public HalfTestBase, 102 public ::testing::WithParamInterface<UnaryPredTestParam> { 103 }; 104 105 XLA_TEST_P(UnaryPredTest, Ops) { 106 std::vector<half> x({half(1.4), half(-2.3), half(3.2), half(-4.1)}); 107 XlaBuilder builder(TestName()); 108 XlaOp x_opnd; 109 auto x_data = CreateR1Parameter<half>(x, /*parameter_number=*/0, "x", 110 &builder, &x_opnd); 111 112 std::function<bool(half)> compute_func = GetParam().compute_func; 113 CHECK_EQ(kNumElements, x.size()); 114 bool expected[kNumElements]; 115 for (int64 i = 0; i < x.size(); ++i) { 116 expected[i] = compute_func(x[i]); 117 } 118 119 UnaryBuildFuncTy build_func = GetParam().build_func; 120 build_func(x_opnd); 121 122 ComputeAndCompareR1<bool>(&builder, expected, {x_data.get()}); 123 } 124 125 INSTANTIATE_TEST_CASE_P(half, UnaryPredTest, 126 ::testing::Values(UnaryPredTestParam{ 127 [](half x) { return isfinite(x); }, &IsFinite})); 128 129 using BinaryBuildFuncTy = std::function<void( 130 const xla::XlaOp& x, const xla::XlaOp& y, absl::Span<const int64>)>; 131 132 struct BinaryOpTestParam { 133 std::function<half(half, half)> compute_func; 134 BinaryBuildFuncTy build_func; 135 }; 136 137 class BinaryOpTest : public HalfTestBase, 138 public ::testing::WithParamInterface<BinaryOpTestParam> {}; 139 140 XLA_TEST_P(BinaryOpTest, Ops) { 141 std::vector<half> x({half(1.0), half(2.0), half(3.0), half(-4.0)}); 142 std::vector<half> y({half(0.4), half(-0.3), half(0.2), half(0.1)}); 143 XlaBuilder builder(TestName()); 144 XlaOp x_opnd; 145 auto x_data = CreateR1Parameter<half>(x, /*parameter_number=*/0, "x", 146 &builder, &x_opnd); 147 148 XlaOp y_opnd; 149 auto y_data = CreateR1Parameter<half>(y, /*parameter_number=*/1, "y", 150 &builder, &y_opnd); 151 152 std::function<half(half, half)> compute_func = GetParam().compute_func; 153 std::vector<half> expected; 154 for (int64 i = 0; i < x.size(); ++i) { 155 expected.push_back(compute_func(x[i], y[i])); 156 } 157 158 BinaryBuildFuncTy build_func = GetParam().build_func; 159 build_func(x_opnd, y_opnd, {}); 160 161 ComputeAndCompareR1<half>(&builder, expected, {x_data.get(), y_data.get()}, 162 error_spec_); 163 } 164 165 half atan2_imp(half x, half y) { 166 return half(atan2(static_cast<float>(std::move(x)), 167 static_cast<float>(std::move(y)))); 168 } 169 170 INSTANTIATE_TEST_CASE_P( 171 half, BinaryOpTest, 172 ::testing::Values( 173 BinaryOpTestParam{[](half x, half y) { return x + y; }, &Add}, 174 BinaryOpTestParam{[](half x, half y) { return atan2_imp(x, y); }, 175 &Atan2}, 176 BinaryOpTestParam{[](half x, half y) { return x / y; }, &Div}, 177 BinaryOpTestParam{[](half x, half y) { return max(x, y); }, &Max}, 178 BinaryOpTestParam{[](half x, half y) { return min(x, y); }, &Min}, 179 BinaryOpTestParam{[](half x, half y) { return x * y; }, &Mul}, 180 BinaryOpTestParam{[](half x, half y) { return pow(x, y); }, &Pow}, 181 BinaryOpTestParam{[](half x, half y) { return x - y; }, &Sub} 182 183 )); 184 185 struct BinaryPredTestParam { 186 std::function<bool(half, half)> compute_func; 187 BinaryBuildFuncTy build_func; 188 }; 189 190 class BinaryPredTest 191 : public HalfTestBase, 192 public ::testing::WithParamInterface<BinaryPredTestParam> {}; 193 194 XLA_TEST_P(BinaryPredTest, Ops) { 195 std::vector<half> x({half(1.0), half(2.0), half(0.2), half(-4.0)}); 196 std::vector<half> y({half(0.4), half(-0.3), half(0.2), half(0.1)}); 197 XlaBuilder builder(TestName()); 198 XlaOp x_opnd; 199 auto x_data = CreateR1Parameter<half>(x, /*parameter_number=*/0, "x", 200 &builder, &x_opnd); 201 202 XlaOp y_opnd; 203 auto y_data = CreateR1Parameter<half>(y, /*parameter_number=*/1, "y", 204 &builder, &y_opnd); 205 206 std::function<bool(half, half)> compute_func = GetParam().compute_func; 207 CHECK_EQ(kNumElements, x.size()); 208 bool expected[kNumElements]; 209 for (int64 i = 0; i < x.size(); ++i) { 210 expected[i] = compute_func(x[i], y[i]); 211 } 212 213 BinaryBuildFuncTy build_func = GetParam().build_func; 214 build_func(x_opnd, y_opnd, {}); 215 216 ComputeAndCompareR1<bool>(&builder, expected, {x_data.get(), y_data.get()}); 217 } 218 219 INSTANTIATE_TEST_CASE_P( 220 half, BinaryPredTest, 221 ::testing::Values( 222 BinaryPredTestParam{[](half x, half y) { return x == y; }, &Eq}, 223 BinaryPredTestParam{[](half x, half y) { return x != y; }, &Ne}, 224 BinaryPredTestParam{[](half x, half y) { return x >= y; }, &Ge}, 225 BinaryPredTestParam{[](half x, half y) { return x > y; }, &Gt}, 226 BinaryPredTestParam{[](half x, half y) { return x <= y; }, &Le}, 227 BinaryPredTestParam{[](half x, half y) { return x < y; }, &Lt} 228 229 )); 230 231 } // namespace 232 } // namespace xla 233