1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // Native XLA implementations of simple unary Ops 17 18 #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" 19 #include "tensorflow/compiler/tf2xla/xla_helpers.h" 20 #include "tensorflow/compiler/tf2xla/xla_op_registry.h" 21 #include "tensorflow/compiler/xla/client/client_library.h" 22 #include "tensorflow/compiler/xla/client/computation_builder.h" 23 #include "tensorflow/core/framework/kernel_def_builder.h" 24 25 namespace tensorflow { 26 namespace { 27 28 // A subclass of a TlaUnaryOp must build the lambda computation that 29 // describes the scalar->scalar function to apply to each element of 30 // the input. 31 #define XLAJIT_MAKE_UNARY(NAME, COMPUTATION) \ 32 class NAME##Op : public XlaOpKernel { \ 33 public: \ 34 explicit NAME##Op(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} \ 35 void Compile(XlaOpKernelContext* ctx) { \ 36 xla::ComputationBuilder* b = ctx->builder(); \ 37 xla::ComputationDataHandle x = ctx->Input(0); \ 38 xla::ComputationDataHandle y = COMPUTATION; \ 39 ctx->SetOutput(0, y); \ 40 } \ 41 }; \ 42 REGISTER_XLA_OP(Name(#NAME), NAME##Op); 43 44 XLAJIT_MAKE_UNARY(ComplexAbs, b->Abs(x)); 45 46 XLAJIT_MAKE_UNARY(Angle, b->Atan2(b->Imag(x), b->Real(x))); 47 48 XLAJIT_MAKE_UNARY(Conj, b->Conj(x)); 49 50 // Return x if x>0, otherwise -x. 51 XLAJIT_MAKE_UNARY(Abs, b->Abs(x)); 52 53 // acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) 54 XLAJIT_MAKE_UNARY( 55 Acos, 56 b->Mul(XlaHelpers::FloatLiteral(b, input_type(0), 2.0), 57 b->Atan2(b->Pow(b->Sub(XlaHelpers::One(b, input_type(0)), 58 b->Mul(x, x)), 59 XlaHelpers::FloatLiteral(b, input_type(0), 0.5)), 60 b->Add(XlaHelpers::One(b, input_type(0)), x)))); 61 62 // acosh(x) = log(x + sqrt(x^2 - 1)) 63 XLAJIT_MAKE_UNARY( 64 Acosh, 65 b->Log(b->Add(x, b->Pow(b->Sub(b->Mul(x, x), 66 XlaHelpers::One(b, input_type(0))), 67 XlaHelpers::FloatLiteral(b, input_type(0), 0.5))))); 68 69 // asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2))) 70 XLAJIT_MAKE_UNARY( 71 Asin, 72 b->Mul(XlaHelpers::FloatLiteral(b, input_type(0), 2.0), 73 b->Atan2(x, b->Add(XlaHelpers::One(b, input_type(0)), 74 b->Pow(b->Sub(XlaHelpers::One(b, input_type(0)), 75 b->Mul(x, x)), 76 XlaHelpers::FloatLiteral(b, input_type(0), 77 0.5)))))); 78 79 // asinh(x) = log(x + sqrt(x^2 + 1)) 80 XLAJIT_MAKE_UNARY( 81 Asinh, 82 b->Log(b->Add(x, b->Pow(b->Add(b->Mul(x, x), 83 XlaHelpers::One(b, input_type(0))), 84 XlaHelpers::FloatLiteral(b, input_type(0), 0.5))))); 85 86 XLAJIT_MAKE_UNARY(Atan, b->Atan2(x, XlaHelpers::One(b, input_type(0)))); 87 88 // atanh(x) = 0.5 * log((1 + x) / (1 - x)) 89 XLAJIT_MAKE_UNARY( 90 Atanh, b->Mul(b->Log(b->Div(b->Add(XlaHelpers::One(b, input_type(0)), x), 91 b->Sub(XlaHelpers::One(b, input_type(0)), x))), 92 XlaHelpers::FloatLiteral(b, input_type(0), 0.5))); 93 XLAJIT_MAKE_UNARY(Ceil, b->Ceil(x)); 94 XLAJIT_MAKE_UNARY(Cos, b->Cos(x)); 95 XLAJIT_MAKE_UNARY(Cosh, 96 b->Mul(b->Add(b->Exp(x), b->Exp(b->Neg(x))), 97 XlaHelpers::FloatLiteral(b, input_type(0), 0.5))); 98 XLAJIT_MAKE_UNARY(Sin, b->Sin(x)); 99 XLAJIT_MAKE_UNARY(Exp, b->Exp(x)); 100 101 // TODO(b/34703906): use a more accurate implementation of expm1. 102 XLAJIT_MAKE_UNARY(Expm1, b->Sub(b->Exp(x), XlaHelpers::One(b, input_type(0)))); 103 104 XLAJIT_MAKE_UNARY(Floor, b->Floor(x)); 105 XLAJIT_MAKE_UNARY(IsFinite, b->IsFinite(x)); 106 XLAJIT_MAKE_UNARY(IsInf, b->Eq(b->Abs(x), 107 XlaHelpers::FloatLiteral( 108 b, input_type(0), 109 std::numeric_limits<double>::infinity()))); 110 XLAJIT_MAKE_UNARY(IsNan, b->Ne(x, x)); 111 // Return 1/x 112 XLAJIT_MAKE_UNARY(Inv, b->Div(XlaHelpers::One(b, input_type(0)), x)); 113 XLAJIT_MAKE_UNARY(Reciprocal, b->Div(XlaHelpers::One(b, input_type(0)), x)); 114 XLAJIT_MAKE_UNARY(Log, b->Log(x)); 115 116 // TODO(b/34703906): use a more accurate implementation of log1p. 117 XLAJIT_MAKE_UNARY(Log1p, b->Log(b->Add(XlaHelpers::One(b, input_type(0)), x))); 118 119 XLAJIT_MAKE_UNARY(Invert, b->Not(x)); 120 XLAJIT_MAKE_UNARY(LogicalNot, b->Not(x)); 121 XLAJIT_MAKE_UNARY(Neg, b->Neg(x)); 122 123 // Implements Banker's rounding: numbers that are equidistant between two 124 // integers are rounded towards even. 125 static xla::ComputationDataHandle Round(xla::ComputationBuilder* b, 126 DataType dtype, 127 const xla::ComputationDataHandle& x) { 128 auto half = XlaHelpers::FloatLiteral(b, dtype, 0.5); 129 auto one = XlaHelpers::FloatLiteral(b, dtype, 1.0); 130 auto two = XlaHelpers::FloatLiteral(b, dtype, 2.0); 131 132 auto round_val = b->Floor(x); 133 auto fraction = b->Sub(x, round_val); 134 auto nearest_even_int = 135 b->Sub(round_val, b->Mul(two, b->Floor(b->Mul(half, x)))); 136 auto is_odd = b->Eq(nearest_even_int, one); 137 return b->Select( 138 b->Or(b->Gt(fraction, half), b->And(b->Eq(fraction, half), is_odd)), 139 b->Add(round_val, one), round_val); 140 } 141 142 XLAJIT_MAKE_UNARY(Rint, Round(b, input_type(0), x)); 143 XLAJIT_MAKE_UNARY(Round, Round(b, input_type(0), x)); 144 145 XLAJIT_MAKE_UNARY(Rsqrt, 146 b->Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), -0.5))); 147 148 // Expresses sigmoid as a rescaled tanh: sigmoid(x) == (tanh(x/2) + 1) / 2. 149 static xla::ComputationDataHandle Sigmoid(xla::ComputationBuilder* b, 150 DataType dtype, 151 const xla::ComputationDataHandle& x) { 152 auto half = XlaHelpers::FloatLiteral(b, dtype, 0.5); 153 return b->Add(half, b->Mul(half, b->Tanh(b->Mul(half, x)))); 154 } 155 XLAJIT_MAKE_UNARY(Sigmoid, Sigmoid(b, input_type(0), x)); 156 157 // Returns 0 if x is 0, -1 if x < 0 and 1 if x > 0. 158 XLAJIT_MAKE_UNARY(Sign, b->Sign(x)); 159 XLAJIT_MAKE_UNARY(Sinh, 160 b->Mul(b->Sub(b->Exp(x), b->Exp(b->Neg(x))), 161 XlaHelpers::FloatLiteral(b, input_type(0), 0.5))); 162 163 static xla::ComputationDataHandle Softplus( 164 xla::ComputationBuilder* b, DataType dtype, 165 const xla::ComputationDataHandle& features) { 166 xla::ComputationDataHandle threshold = 167 b->Add(b->Log(XlaHelpers::Epsilon(b, dtype)), 168 XlaHelpers::FloatLiteral(b, dtype, 2.0)); 169 // Value above which exp(x) may overflow, but softplus(x) == x 170 // is within machine epsilon. 171 xla::ComputationDataHandle too_large = b->Gt(features, b->Neg(threshold)); 172 // Value below which exp(x) may underflow, but softplus(x) == exp(x) 173 // is within machine epsilon. 174 xla::ComputationDataHandle too_small = b->Lt(features, threshold); 175 xla::ComputationDataHandle features_exp = b->Exp(features); 176 xla::ComputationDataHandle output = b->Select( 177 too_large, features, 178 b->Select(too_small, features_exp, 179 b->Log(b->Add(features_exp, XlaHelpers::One(b, dtype))))); 180 return output; 181 } 182 XLAJIT_MAKE_UNARY(Softplus, Softplus(b, input_type(0), x)); 183 184 // softsign(x) = x / (abs(x) + 1) 185 XLAJIT_MAKE_UNARY(Softsign, 186 b->Div(x, 187 b->Add(b->Abs(x), XlaHelpers::One(b, input_type(0))))); 188 XLAJIT_MAKE_UNARY(Sqrt, 189 b->Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), 0.5))); 190 XLAJIT_MAKE_UNARY(Square, b->Mul(x, x)); 191 XLAJIT_MAKE_UNARY(Tan, b->Div(b->Sin(x), b->Cos(x))); 192 XLAJIT_MAKE_UNARY(Tanh, b->Tanh(x)); 193 194 XLAJIT_MAKE_UNARY(Real, b->Real(x)); 195 XLAJIT_MAKE_UNARY(Imag, b->Imag(x)); 196 197 #undef XLAJIT_MAKE_UNARY 198 199 } // namespace 200 } // namespace tensorflow 201