/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// XLA-specific Ops for softmax.

#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"

namespace tensorflow {
namespace {

class SoftmaxOp : public XlaOpKernel {
 public:
  explicit SoftmaxOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
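    // One kernel implements both Softmax and LogSoftmax; the registered op
    // name (see REGISTER_XLA_OP below) selects the log-space variant.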
    log_ = StringPiece(type_string()).starts_with("Log");
  }

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape logits_shape = ctx->InputShape(0);
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_shape),
                errors::InvalidArgument("logits must be 2-dimensional"));

    const int kBatchDim = 0;
    const int kClassDim = 1;

    const DataType type = input_type(0);
    auto logits = ctx->Input(0);

    xla::ComputationBuilder* b = ctx->builder();
    const xla::Computation& max_func = *ctx->GetOrCreateMax(type);
    const xla::Computation& add_func = *ctx->GetOrCreateAdd(type);

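    // Subtracting the per-batch max before exponentiating is the standard
    // numerical-stability trick: softmax is invariant under a per-row shift,
    // and the shift keeps exp() from overflowing for large logits.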
    // Find the max in each batch, resulting in a tensor of shape [batch]
    auto logits_max =
        b->Reduce(logits, XlaHelpers::MinValue(b, type), max_func, {kClassDim});
    // Subtract the max in batch b from every element in batch b. Broadcasts
    // along the batch dimension.
    auto shifted_logits = b->Sub(logits, logits_max, {kBatchDim});
    xla::ComputationDataHandle softmax;
    if (log_) {
      // softmax = shifted_logits - log(sum(exp(shifted_logits)))
      auto log_sum_exp =
          b->Log(b->Reduce(b->Exp(shifted_logits), XlaHelpers::Zero(b, type),
                           add_func, {kClassDim}));
      softmax = b->Sub(shifted_logits, log_sum_exp, {kBatchDim});
    } else {
      // softmax = exp(shifted_logits) / sum(exp(shifted_logits))
      auto exp_shifted = b->Exp(shifted_logits);
      auto sum_exp = b->Reduce(exp_shifted, XlaHelpers::Zero(b, type), add_func,
                               {kClassDim});
      softmax = b->Div(exp_shifted, sum_exp, {kBatchDim});
    }

    ctx->SetOutput(0, softmax);
  }

 private:
  bool log_;
};

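// Worked example (values approximate): for logits [[1, 2, 3]], the shifted
// logits are [[-2, -1, 0]] and exp() gives [[0.135, 0.368, 1.0]] with row sum
// 1.503, so Softmax emits [[0.090, 0.245, 0.665]] and LogSoftmax emits
// [[-2.408, -1.408, -0.408]].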
REGISTER_XLA_OP(Name("Softmax"), SoftmaxOp);
REGISTER_XLA_OP(Name("LogSoftmax"), SoftmaxOp);

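// Computes a (loss, backprop) pair for softmax cross-entropy:
//   loss[b]        = -sum_c labels[b, c] * log(softmax(logits)[b, c])
//   backprop[b, c] = softmax(logits)[b, c] - labels[b, c]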
std::pair<xla::ComputationDataHandle, xla::ComputationDataHandle>
CrossEntropyWithLogits(XlaOpKernelContext* ctx, DataType type,
                       const xla::ComputationDataHandle& logits,
                       const xla::ComputationDataHandle& labels) {
  const xla::Computation& max_func = *ctx->GetOrCreateMax(type);
  const xla::Computation& add_func = *ctx->GetOrCreateAdd(type);

  const int kBatchDim = 0;
  const int kClassDim = 1;

  xla::ComputationBuilder* b = ctx->builder();
  // Find the max in each batch, resulting in a tensor of shape [batch]
  auto logits_max =
      b->Reduce(logits, XlaHelpers::MinValue(b, type), max_func, {kClassDim});

  // Subtract the max in batch b from every element in batch b.
  // Broadcasts along the batch dimension.
  auto shifted_logits = b->Sub(logits, logits_max, {kBatchDim});

  // exp(logits - max_logits)
  auto exp_shifted_logits = b->Exp(shifted_logits);

  // sum_{class} (exp(logits - max_logits))
  auto sum_exp = b->Reduce(exp_shifted_logits, XlaHelpers::Zero(b, type),
                           add_func, {kClassDim});

  // log(sum(exp(logits - max_logits)))
  auto log_sum_exp = b->Log(sum_exp);

  // sum(-labels *
  //    ((logits - max_logits) - log(sum(exp(logits - max_logits)))))
  // along classes
  // (The subtraction broadcasts along the batch dimension.)
  xla::ComputationDataHandle loss = b->Reduce(
      b->Mul(b->Neg(labels), b->Sub(shifted_logits, log_sum_exp, {kBatchDim})),
      XlaHelpers::Zero(b, type), add_func, {kClassDim});

  // backprop: prob - labels, where
  //   prob = exp(logits - max_logits) / sum(exp(logits - max_logits))
  //     (where the division broadcasts along the batch dimension)
  xla::ComputationDataHandle backprop =
      b->Sub(b->Div(exp_shifted_logits, sum_exp, {kBatchDim}), labels);
  return {loss, backprop};
}
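
// Worked example (values approximate): for logits [[1, 2, 3]] and labels
// [[0, 0, 1]], softmax(logits) is [[0.090, 0.245, 0.665]], so
// loss = [-log(0.665)] = [0.408] and backprop = [[0.090, 0.245, -0.335]].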

class SoftmaxXentWithLogitsOp : public XlaOpKernel {
 public:
  explicit SoftmaxXentWithLogitsOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape logits_shape = ctx->InputShape(0);
    const TensorShape labels_shape = ctx->InputShape(1);
    OP_REQUIRES(ctx, logits_shape.IsSameSize(labels_shape),
                errors::InvalidArgument(
                    "logits and labels must be same size: logits_size=",
                    logits_shape.DebugString(),
                    " labels_size=", labels_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_shape),
                errors::InvalidArgument("logits must be 2-dimensional"));
    // Since we already checked that both inputs have the same shape, there is
    // no need to check that "labels" is a matrix too.

    const DataType type = input_type(0);
    auto logits = ctx->Input(0);
    auto labels = ctx->Input(1);

    xla::ComputationDataHandle loss, backprop;
    std::tie(loss, backprop) =
        CrossEntropyWithLogits(ctx, type, logits, labels);
    ctx->SetOutput(0, loss);
    ctx->SetOutput(1, backprop);
  }
};

REGISTER_XLA_OP(Name("SoftmaxCrossEntropyWithLogits"), SoftmaxXentWithLogitsOp);

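// Like SoftmaxXentWithLogitsOp, but the labels arrive as a vector of class
// indices (one per batch element) rather than as dense one-hot rows.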
class SparseSoftmaxXentWithLogitsOp : public XlaOpKernel {
 public:
  explicit SparseSoftmaxXentWithLogitsOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape logits_shape = ctx->InputShape(0);
    const TensorShape labels_shape = ctx->InputShape(1);
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_shape),
                errors::InvalidArgument("logits must be 2-D, but got shape ",
                                        logits_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(labels_shape),
                errors::InvalidArgument("labels must be 1-D, but got shape ",
                                        labels_shape.DebugString()));
    OP_REQUIRES(ctx, logits_shape.dim_size(0) == labels_shape.dim_size(0),
                errors::InvalidArgument(
                    "logits and labels must have the same first dimension, "
                    "got logits shape ",
                    logits_shape.DebugString(), " and labels shape ",
                    labels_shape.DebugString()));
    OP_REQUIRES(ctx, logits_shape.dim_size(1) > 0,
                errors::InvalidArgument(
                    "Must have at least one class, but got logits shape ",
                    logits_shape.DebugString()));

    int64 batch_size = logits_shape.dim_size(0);
    int64 depth = logits_shape.dim_size(1);

    DataType logits_type = input_type(0);
    DataType indices_type = input_type(1);

    xla::ComputationDataHandle indices = ctx->Input(1);

    xla::ComputationBuilder* builder = ctx->builder();
    xla::ComputationDataHandle labels;
    OP_REQUIRES_OK(ctx,
                   XlaHelpers::OneHot(
                       builder, depth, /*axis=*/1, input_type(1), labels_shape,
                       indices, XlaHelpers::One(builder, logits_type),
                       XlaHelpers::Zero(builder, logits_type), &labels));
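    // For example, indices [2, 0] with depth 3 expand to the one-hot labels
    // [[0, 0, 1], [1, 0, 0]] in the logits type.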

    // If any of the indices are out of range, we must populate the labels
    // with NaN to obey the interface contract of
    // tf.nn.sparse_softmax_cross_entropy_with_logits: build a vector of shape
    // {batch_size} that is 0 where the index is in range and NaN otherwise,
    // then add that vector to the labels to force out-of-range rows to NaN.
    xla::ComputationDataHandle nan_or_zero = builder->Select(
        builder->And(
            builder->Le(XlaHelpers::Zero(builder, indices_type), indices),
            builder->Lt(indices, XlaHelpers::IntegerLiteral(
                                     builder, indices_type, depth))),
        builder->Broadcast(XlaHelpers::Zero(builder, logits_type),
                           {batch_size}),
        builder->Broadcast(XlaHelpers::FloatLiteral(builder, logits_type, NAN),
                           {batch_size}));
    labels = builder->Add(labels, nan_or_zero, {0});
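    // E.g. with depth 3, an index of -1 or 3 turns its entire labels row into
    // NaN, so the loss and backprop for that batch element come out as NaN.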

    xla::ComputationDataHandle loss, backprop;
    std::tie(loss, backprop) =
        CrossEntropyWithLogits(ctx, logits_type, ctx->Input(0), labels);
    ctx->SetOutput(0, loss);
    ctx->SetOutput(1, backprop);
  }
};

REGISTER_XLA_OP(Name("SparseSoftmaxCrossEntropyWithLogits"),
                SparseSoftmaxXentWithLogitsOp);

}  // namespace
}  // namespace tensorflow