1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <memory> 17 #include <vector> 18 19 #include "tensorflow/compiler/xla/client/computation_builder.h" 20 #include "tensorflow/compiler/xla/client/global_data.h" 21 #include "tensorflow/compiler/xla/client/local_client.h" 22 #include "tensorflow/compiler/xla/tests/client_library_test_base.h" 23 #include "tensorflow/compiler/xla/tests/literal_test_util.h" 24 #include "tensorflow/compiler/xla/tests/test_macros.h" 25 #include "tensorflow/compiler/xla/xla_data.pb.h" 26 #include "tensorflow/core/platform/test.h" 27 #include "tensorflow/core/platform/types.h" 28 29 namespace xla { 30 namespace { 31 32 class UnaryOpTest : public ClientLibraryTestBase { 33 protected: 34 template <typename T> 35 T inf() { 36 return std::numeric_limits<T>::infinity(); 37 } 38 template <typename T> 39 void AbsSize0TestHelper() { 40 ComputationBuilder builder(client_, TestName()); 41 auto arg = builder.ConstantR1<T>({}); 42 auto abs = builder.Abs(arg); 43 44 if (primitive_util::NativeToPrimitiveType<T>() == C64) { 45 ComputeAndCompareR1<float>(&builder, {}, {}); 46 } else { 47 ComputeAndCompareR1<T>(&builder, {}, {}); 48 } 49 } 50 51 template <typename T> 52 void AbsTestHelper() { 53 ComputationBuilder builder(client_, TestName()); 54 auto arg = builder.ConstantR1<T>({-2, 25, 0, -123, inf<T>(), -inf<T>()}); 55 auto abs = builder.Abs(arg); 56 57 ComputeAndCompareR1<T>(&builder, {2, 25, 0, 123, inf<T>(), inf<T>()}, {}); 58 } 59 60 template <typename T> 61 void SignTestHelper() { 62 ComputationBuilder builder(client_, TestName()); 63 auto arg = builder.ConstantR1<T>( 64 {-2, 25, 0, static_cast<T>(-0.0), -123, inf<T>(), -inf<T>()}); 65 auto sign = builder.Sign(arg); 66 67 ComputeAndCompareR1<T>(&builder, {-1, 1, 0, 0, -1, 1, -1}, {}); 68 } 69 70 template <typename T> 71 void SignAbsTestHelper() { 72 ComputationBuilder builder(client_, TestName()); 73 auto arg = builder.ConstantR1<T>({-2, 25, 0, -123}); 74 auto sign = builder.Sign(arg); 75 auto abs = builder.Abs(arg); 76 builder.Sub(builder.Mul(sign, abs), arg); 77 78 ComputeAndCompareR1<T>(&builder, {0, 0, 0, 0}, {}); 79 } 80 }; 81 82 template <> 83 int UnaryOpTest::inf<int>() { 84 return 2147483647; 85 } 86 87 template <> 88 void UnaryOpTest::AbsTestHelper<complex64>() { 89 ComputationBuilder builder(client_, TestName()); 90 auto arg = builder.ConstantR1<complex64>({{-2, 0}, 91 {0, 25}, 92 {0, 0}, 93 {-0.3f, 0.4f}, 94 {0, inf<float>()}, 95 {-inf<float>(), 0}}); 96 auto abs = builder.Abs(arg); 97 98 std::unique_ptr<Literal> expected = 99 Literal::CreateR1<float>({2, 25, 0, 0.5, inf<float>(), inf<float>()}); 100 ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f)); 101 } 102 103 template <> 104 void UnaryOpTest::SignTestHelper<complex64>() { 105 ComputationBuilder builder(client_, TestName()); 106 auto arg = builder.ConstantR1<complex64>( 107 {{-2, 0}, {0, 25}, {0, 0}, {static_cast<float>(-0.0), 0}, {-1, 1}}); 108 auto sign = builder.Sign(arg); 109 110 std::unique_ptr<Literal> expected = Literal::CreateR1<complex64>( 111 {{-1, 0}, {0, 1}, {0, 0}, {0, 0}, {-std::sqrt(0.5f), std::sqrt(0.5f)}}); 112 ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f)); 113 } 114 115 template <> 116 void UnaryOpTest::SignAbsTestHelper<complex64>() { 117 ComputationBuilder builder(client_, TestName()); 118 auto arg = 119 builder.ConstantR1<complex64>({{-2, 0}, {0, 25}, {0, 0}, {-0.4, 0.3}}); 120 auto sign = builder.Sign(arg); 121 auto abs = builder.Abs(arg); 122 builder.Sub(builder.Mul(sign, builder.ConvertElementType(abs, C64)), arg); 123 124 std::unique_ptr<Literal> expected = 125 Literal::CreateR1<complex64>({0, 0, 0, 0}); 126 ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f)); 127 } 128 129 XLA_TEST_F(UnaryOpTest, AbsTestR1Size0) { 130 AbsSize0TestHelper<int>(); 131 AbsSize0TestHelper<float>(); 132 AbsSize0TestHelper<complex64>(); 133 } 134 135 XLA_TEST_F(UnaryOpTest, AbsTestR1) { 136 AbsTestHelper<int>(); 137 AbsTestHelper<float>(); 138 AbsTestHelper<complex64>(); 139 } 140 141 XLA_TEST_F(UnaryOpTest, AbsTestR0) { 142 ComputationBuilder builder(client_, TestName()); 143 auto argi = builder.ConstantR0<int>(-5); 144 auto absi = builder.Abs(argi); 145 auto argf = builder.ConstantR0<float>(-3.0f); 146 auto absf = builder.Abs(argf); 147 auto argf0 = builder.ConstantR0<float>(-0.0f); 148 auto absf0 = builder.Abs(argf0); 149 auto argc = builder.ConstantR0<complex64>({-0.3f, 0.4f}); 150 auto absc = builder.Abs(argc); 151 builder.Add(builder.Add(absc, absf0), 152 builder.Add(absf, builder.ConvertElementType(absi, F32))); 153 154 ComputeAndCompareR0<float>(&builder, 8.5f, {}); 155 } 156 157 XLA_TEST_F(UnaryOpTest, SignTestR0) { 158 ComputationBuilder builder(client_, TestName()); 159 auto argi = builder.ConstantR0<int>(-5); 160 auto sgni = builder.Sign(argi); // -1 161 auto argf = builder.ConstantR0<float>(-4.0f); 162 auto sgnf = builder.Sign(argf); // -1 163 auto argf0 = builder.ConstantR0<float>(-0.0f); 164 auto sgnf0 = builder.Sign(argf0); // 0 165 auto argc = builder.ConstantR0<complex64>({-.3, .4}); 166 auto sgnc = builder.Sign(argc); // (-.6, .8) 167 builder.Add(sgnc, builder.ConvertElementType( 168 builder.Add(builder.Add(sgnf0, sgnf), 169 builder.ConvertElementType(sgni, F32)), 170 C64)); 171 172 std::unique_ptr<Literal> expected = 173 Literal::CreateR0<complex64>({-2.6f, 0.8f}); 174 ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f)); 175 } 176 177 XLA_TEST_F(UnaryOpTest, SignTestR1) { 178 SignTestHelper<int>(); 179 SignTestHelper<float>(); 180 SignTestHelper<complex64>(); 181 } 182 183 XLA_TEST_F(UnaryOpTest, SignAbsTestR1) { 184 SignAbsTestHelper<int>(); 185 SignAbsTestHelper<float>(); 186 SignAbsTestHelper<complex64>(); 187 } 188 189 XLA_TEST_F(UnaryOpTest, UnsignedAbsTestR1) { 190 ComputationBuilder builder(client_, TestName()); 191 auto arg = builder.ConstantR1<unsigned int>( 192 {2, 25, 0, 123, std::numeric_limits<unsigned int>::max()}); 193 auto abs = builder.Abs(arg); 194 195 ComputeAndCompareR1<unsigned int>( 196 &builder, {2, 25, 0, 123, std::numeric_limits<unsigned int>::max()}, {}); 197 } 198 199 XLA_TEST_F(UnaryOpTest, UnsignedSignTestR1) { 200 ComputationBuilder builder(client_, TestName()); 201 auto arg = builder.ConstantR1<unsigned int>( 202 {2, 25, 0, 123, std::numeric_limits<unsigned int>::max()}); 203 auto sign = builder.Sign(arg); 204 205 ComputeAndCompareR1<unsigned int>(&builder, {1, 1, 0, 1, 1}, {}); 206 } 207 208 XLA_TEST_F(UnaryOpTest, SignAbsTestR2) { 209 ComputationBuilder builder(client_, TestName()); 210 auto arg = builder.ConstantR2<float>({{1.0, -2.0}, {-3.0, 4.0}}); 211 auto sign = builder.Sign(arg); 212 auto abs = builder.Abs(arg); 213 builder.Sub(builder.Mul(sign, abs), arg); 214 215 ComputeAndCompareR2<float>(&builder, {{0, 0}, {0, 0}}, {}); 216 } 217 218 XLA_TEST_F(UnaryOpTest, ConvertElementTypePredToS32) { 219 ComputationBuilder builder(client_, TestName()); 220 auto lhs = builder.ConstantR1<int32>({0, 1}); 221 auto rhs = builder.ConstantR1<int32>({1, 1}); 222 builder.ConvertElementType(builder.Eq(lhs, rhs), S32); 223 224 ComputeAndCompareR1<int32>(&builder, {0, 1}, {}); 225 } 226 227 XLA_TEST_F(UnaryOpTest, ConvertElementTypePredToF32) { 228 ComputationBuilder builder(client_, TestName()); 229 auto lhs = builder.ConstantR1<int32>({0, 1}); 230 auto rhs = builder.ConstantR1<int32>({1, 1}); 231 builder.ConvertElementType(builder.Eq(lhs, rhs), F32); 232 233 ComputeAndCompareR1<float>(&builder, {0.0, 1.0}, {}); 234 } 235 236 } // namespace 237 } // namespace xla 238