Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // Native XLA implementations of simple unary Ops
     17 
     18 #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
     19 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
     20 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
     21 #include "tensorflow/compiler/xla/client/client_library.h"
     22 #include "tensorflow/compiler/xla/client/computation_builder.h"
     23 #include "tensorflow/core/framework/kernel_def_builder.h"
     24 
     25 namespace tensorflow {
     26 namespace {
     27 
     28 // A subclass of a TlaUnaryOp must build the lambda computation that
     29 // describes the scalar->scalar function to apply to each element of
     30 // the input.
     31 #define XLAJIT_MAKE_UNARY(NAME, COMPUTATION)                           \
     32   class NAME##Op : public XlaOpKernel {                                \
     33    public:                                                             \
     34     explicit NAME##Op(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} \
     35     void Compile(XlaOpKernelContext* ctx) {                            \
     36       xla::ComputationBuilder* b = ctx->builder();                     \
     37       xla::ComputationDataHandle x = ctx->Input(0);                    \
     38       xla::ComputationDataHandle y = COMPUTATION;                      \
     39       ctx->SetOutput(0, y);                                            \
     40     }                                                                  \
     41   };                                                                   \
     42   REGISTER_XLA_OP(Name(#NAME), NAME##Op);
     43 
// Magnitude of a complex number: |a + bi| = sqrt(a^2 + b^2), via XLA Abs.
XLAJIT_MAKE_UNARY(ComplexAbs, b->Abs(x));

// angle(x) = atan2(Im(x), Re(x))
XLAJIT_MAKE_UNARY(Angle, b->Atan2(b->Imag(x), b->Real(x)));

// Complex conjugate, via XLA Conj.
XLAJIT_MAKE_UNARY(Conj, b->Conj(x));

// Return x if x>0, otherwise -x.
XLAJIT_MAKE_UNARY(Abs, b->Abs(x));
     52 
// acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x))
// sqrt is expressed as pow(., 0.5); atan as atan2(numerator, denominator).
XLAJIT_MAKE_UNARY(
    Acos,
    b->Mul(XlaHelpers::FloatLiteral(b, input_type(0), 2.0),
           b->Atan2(b->Pow(b->Sub(XlaHelpers::One(b, input_type(0)),
                                  b->Mul(x, x)),
                           XlaHelpers::FloatLiteral(b, input_type(0), 0.5)),
                    b->Add(XlaHelpers::One(b, input_type(0)), x))));

// acosh(x) = log(x + sqrt(x^2 - 1))
// NOTE(review): x^2 can overflow for large |x| — presumably acceptable for
// the supported input range; confirm against the kernel tests.
XLAJIT_MAKE_UNARY(
    Acosh,
    b->Log(b->Add(x, b->Pow(b->Sub(b->Mul(x, x),
                                   XlaHelpers::One(b, input_type(0))),
                            XlaHelpers::FloatLiteral(b, input_type(0), 0.5)))));

// asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2)))
XLAJIT_MAKE_UNARY(
    Asin,
    b->Mul(XlaHelpers::FloatLiteral(b, input_type(0), 2.0),
           b->Atan2(x, b->Add(XlaHelpers::One(b, input_type(0)),
                              b->Pow(b->Sub(XlaHelpers::One(b, input_type(0)),
                                            b->Mul(x, x)),
                                     XlaHelpers::FloatLiteral(b, input_type(0),
                                                              0.5))))));
     78 
// asinh(x) = log(x + sqrt(x^2 + 1)), with sqrt expressed as pow(., 0.5).
XLAJIT_MAKE_UNARY(
    Asinh,
    b->Log(b->Add(x, b->Pow(b->Add(b->Mul(x, x),
                                   XlaHelpers::One(b, input_type(0))),
                            XlaHelpers::FloatLiteral(b, input_type(0), 0.5)))));

// atan(x) = atan2(x, 1)
XLAJIT_MAKE_UNARY(Atan, b->Atan2(x, XlaHelpers::One(b, input_type(0))));

// atanh(x) = 0.5 * log((1 + x) / (1 - x))
XLAJIT_MAKE_UNARY(
    Atanh, b->Mul(b->Log(b->Div(b->Add(XlaHelpers::One(b, input_type(0)), x),
                                b->Sub(XlaHelpers::One(b, input_type(0)), x))),
                  XlaHelpers::FloatLiteral(b, input_type(0), 0.5)));
XLAJIT_MAKE_UNARY(Ceil, b->Ceil(x));
XLAJIT_MAKE_UNARY(Cos, b->Cos(x));
// cosh(x) = (e^x + e^-x) / 2
XLAJIT_MAKE_UNARY(Cosh,
                  b->Mul(b->Add(b->Exp(x), b->Exp(b->Neg(x))),
                         XlaHelpers::FloatLiteral(b, input_type(0), 0.5)));
XLAJIT_MAKE_UNARY(Sin, b->Sin(x));
XLAJIT_MAKE_UNARY(Exp, b->Exp(x));

// expm1(x) = e^x - 1, computed naively (inaccurate for x near 0).
// TODO(b/34703906): use a more accurate implementation of expm1.
XLAJIT_MAKE_UNARY(Expm1, b->Sub(b->Exp(x), XlaHelpers::One(b, input_type(0))));
    103 
XLAJIT_MAKE_UNARY(Floor, b->Floor(x));
XLAJIT_MAKE_UNARY(IsFinite, b->IsFinite(x));
// IsInf: true where |x| equals +infinity (covers both +inf and -inf).
XLAJIT_MAKE_UNARY(IsInf, b->Eq(b->Abs(x),
                               XlaHelpers::FloatLiteral(
                                   b, input_type(0),
                                   std::numeric_limits<double>::infinity())));
// IsNan: NaN is the only floating-point value that is not equal to itself.
XLAJIT_MAKE_UNARY(IsNan, b->Ne(x, x));
// Return 1/x
XLAJIT_MAKE_UNARY(Inv, b->Div(XlaHelpers::One(b, input_type(0)), x));
XLAJIT_MAKE_UNARY(Reciprocal, b->Div(XlaHelpers::One(b, input_type(0)), x));
XLAJIT_MAKE_UNARY(Log, b->Log(x));

// log1p(x) = log(1 + x), computed naively (inaccurate for x near 0).
// TODO(b/34703906): use a more accurate implementation of log1p.
XLAJIT_MAKE_UNARY(Log1p, b->Log(b->Add(XlaHelpers::One(b, input_type(0)), x)));

// Invert (presumably bitwise NOT, per the TF op name — registration
// determines accepted dtypes) and LogicalNot both lower to XLA Not.
XLAJIT_MAKE_UNARY(Invert, b->Not(x));
XLAJIT_MAKE_UNARY(LogicalNot, b->Not(x));
XLAJIT_MAKE_UNARY(Neg, b->Neg(x));
    122 
    123 // Implements Banker's rounding: numbers that are equidistant between two
    124 // integers are rounded towards even.
    125 static xla::ComputationDataHandle Round(xla::ComputationBuilder* b,
    126                                         DataType dtype,
    127                                         const xla::ComputationDataHandle& x) {
    128   auto half = XlaHelpers::FloatLiteral(b, dtype, 0.5);
    129   auto one = XlaHelpers::FloatLiteral(b, dtype, 1.0);
    130   auto two = XlaHelpers::FloatLiteral(b, dtype, 2.0);
    131 
    132   auto round_val = b->Floor(x);
    133   auto fraction = b->Sub(x, round_val);
    134   auto nearest_even_int =
    135       b->Sub(round_val, b->Mul(two, b->Floor(b->Mul(half, x))));
    136   auto is_odd = b->Eq(nearest_even_int, one);
    137   return b->Select(
    138       b->Or(b->Gt(fraction, half), b->And(b->Eq(fraction, half), is_odd)),
    139       b->Add(round_val, one), round_val);
    140 }
    141 
// Rint and Round both use round-half-to-even (Banker's rounding) via the
// shared Round helper above.
XLAJIT_MAKE_UNARY(Rint, Round(b, input_type(0), x));
XLAJIT_MAKE_UNARY(Round, Round(b, input_type(0), x));

// rsqrt(x) = x^(-1/2), expressed as a Pow.
XLAJIT_MAKE_UNARY(Rsqrt,
                  b->Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), -0.5)));
    147 
    148 // Expresses sigmoid as a rescaled tanh: sigmoid(x) == (tanh(x/2) + 1) / 2.
    149 static xla::ComputationDataHandle Sigmoid(xla::ComputationBuilder* b,
    150                                           DataType dtype,
    151                                           const xla::ComputationDataHandle& x) {
    152   auto half = XlaHelpers::FloatLiteral(b, dtype, 0.5);
    153   return b->Add(half, b->Mul(half, b->Tanh(b->Mul(half, x))));
    154 }
XLAJIT_MAKE_UNARY(Sigmoid, Sigmoid(b, input_type(0), x));

// Returns 0 if x is 0, -1 if x < 0 and 1 if x > 0.
XLAJIT_MAKE_UNARY(Sign, b->Sign(x));
// sinh(x) = (e^x - e^-x) / 2
XLAJIT_MAKE_UNARY(Sinh,
                  b->Mul(b->Sub(b->Exp(x), b->Exp(b->Neg(x))),
                         XlaHelpers::FloatLiteral(b, input_type(0), 0.5)));
    162 
// softplus(x) = log(1 + exp(x)), computed piecewise so that exp(x) is only
// evaluated where it neither overflows (large x) nor matters (the result is
// selected away): the identity is used for large x and exp(x) itself for
// very negative x.
static xla::ComputationDataHandle Softplus(
    xla::ComputationBuilder* b, DataType dtype,
    const xla::ComputationDataHandle& features) {
  // threshold = log(epsilon) + 2, a negative cutoff derived from the machine
  // epsilon of the input type.
  xla::ComputationDataHandle threshold =
      b->Add(b->Log(XlaHelpers::Epsilon(b, dtype)),
             XlaHelpers::FloatLiteral(b, dtype, 2.0));
  // Value above which exp(x) may overflow, but softplus(x) == x
  // is within machine epsilon.
  xla::ComputationDataHandle too_large = b->Gt(features, b->Neg(threshold));
  // Value below which exp(x) may underflow, but softplus(x) == exp(x)
  // is within machine epsilon.
  xla::ComputationDataHandle too_small = b->Lt(features, threshold);
  xla::ComputationDataHandle features_exp = b->Exp(features);
  // Middle range: the straightforward log(exp(x) + 1).
  xla::ComputationDataHandle output = b->Select(
      too_large, features,
      b->Select(too_small, features_exp,
                b->Log(b->Add(features_exp, XlaHelpers::One(b, dtype)))));
  return output;
}
XLAJIT_MAKE_UNARY(Softplus, Softplus(b, input_type(0), x));
    183 
// softsign(x) = x / (abs(x) + 1)
XLAJIT_MAKE_UNARY(Softsign,
                  b->Div(x,
                         b->Add(b->Abs(x), XlaHelpers::One(b, input_type(0)))));
// sqrt(x) = x^(1/2), expressed as a Pow.
XLAJIT_MAKE_UNARY(Sqrt,
                  b->Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), 0.5)));
XLAJIT_MAKE_UNARY(Square, b->Mul(x, x));
// tan(x) = sin(x) / cos(x)
XLAJIT_MAKE_UNARY(Tan, b->Div(b->Sin(x), b->Cos(x)));
XLAJIT_MAKE_UNARY(Tanh, b->Tanh(x));

// Extract the real / imaginary components of a complex input.
XLAJIT_MAKE_UNARY(Real, b->Real(x));
XLAJIT_MAKE_UNARY(Imag, b->Imag(x));
    196 
    197 #undef XLAJIT_MAKE_UNARY
    198 
    199 }  // namespace
    200 }  // namespace tensorflow
    201