/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// XLA-specific Ops for softmax.

#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"

namespace tensorflow {
namespace {

class SoftmaxOp : public XlaOpKernel {
 public:
  explicit SoftmaxOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    log_ = StringPiece(type_string()).starts_with("Log");
  }

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape logits_shape = ctx->InputShape(0);
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_shape),
                errors::InvalidArgument("logits must be 2-dimensional"));

    const int kBatchDim = 0;
    const int kClassDim = 1;

    const DataType type = input_type(0);
    auto logits = ctx->Input(0);

    xla::ComputationBuilder* b = ctx->builder();
    const xla::Computation& max_func = *ctx->GetOrCreateMax(type);
    const xla::Computation& add_func = *ctx->GetOrCreateAdd(type);

    // Find the max in each batch, resulting in a tensor of shape [batch].
    auto logits_max = b->Reduce(logits, XlaHelpers::MinValue(b, type),
                                max_func, {kClassDim});
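    // Rationale: softmax(x) == softmax(x - c) for any constant c, so shifting
    // by the per-batch max (the standard log-sum-exp stabilization) leaves the
    // result unchanged while keeping exp() from overflowing. For instance
    // (illustrative values), float32 exp(1000) overflows, but logits
    // {1000, 1001} shifted by their max become {-1, 0}, which is safe.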
    // Subtract the max in batch b from every element in batch b. Broadcasts
    // along the batch dimension.
    auto shifted_logits = b->Sub(logits, logits_max, {kBatchDim});
    xla::ComputationDataHandle softmax;
    if (log_) {
      // softmax = shifted_logits - log(sum(exp(shifted_logits)))
      auto log_sum_exp =
          b->Log(b->Reduce(b->Exp(shifted_logits), XlaHelpers::Zero(b, type),
                           add_func, {kClassDim}));
      softmax = b->Sub(shifted_logits, log_sum_exp, {kBatchDim});
    } else {
      // softmax = exp(shifted_logits) / sum(exp(shifted_logits))
      auto exp_shifted = b->Exp(shifted_logits);
      auto sum_exp = b->Reduce(exp_shifted, XlaHelpers::Zero(b, type), add_func,
                               {kClassDim});
      softmax = b->Div(exp_shifted, sum_exp, {kBatchDim});
    }

    ctx->SetOutput(0, softmax);
  }

 private:
  bool log_;
};

REGISTER_XLA_OP(Name("Softmax"), SoftmaxOp);
REGISTER_XLA_OP(Name("LogSoftmax"), SoftmaxOp);

std::pair<xla::ComputationDataHandle, xla::ComputationDataHandle>
CrossEntropyWithLogits(XlaOpKernelContext* ctx, DataType type,
                       const xla::ComputationDataHandle& logits,
                       const xla::ComputationDataHandle& labels) {
  const xla::Computation& max_func = *ctx->GetOrCreateMax(type);
  const xla::Computation& add_func = *ctx->GetOrCreateAdd(type);

  const int kBatchDim = 0;
  const int kClassDim = 1;

  xla::ComputationBuilder* b = ctx->builder();
  // Find the max in each batch, resulting in a tensor of shape [batch].
  auto logits_max =
      b->Reduce(logits, XlaHelpers::MinValue(b, type), max_func, {kClassDim});

  // Subtract the max in batch b from every element in batch b.
  // Broadcasts along the batch dimension.
  auto shifted_logits = b->Sub(logits, logits_max, {kBatchDim});

  // exp(logits - max_logits)
  auto exp_shifted_logits = b->Exp(shifted_logits);

  // sum_{class} (exp(logits - max_logits))
  auto sum_exp = b->Reduce(exp_shifted_logits, XlaHelpers::Zero(b, type),
                           add_func, {kClassDim});

  // log(sum(exp(logits - max_logits)))
  auto log_sum_exp = b->Log(sum_exp);

  // sum(-labels *
  //     ((logits - max_logits) - log(sum(exp(logits - max_logits)))))
  // along classes; the inner subtraction broadcasts along the batch dimension.
  xla::ComputationDataHandle loss = b->Reduce(
      b->Mul(b->Neg(labels), b->Sub(shifted_logits, log_sum_exp, {kBatchDim})),
      XlaHelpers::Zero(b, type), add_func, {kClassDim});

  // backprop: prob - labels, where
  //   prob = exp(logits - max_logits) / sum(exp(logits - max_logits))
  // (the division broadcasts along the batch dimension).
  xla::ComputationDataHandle backprop =
      b->Sub(b->Div(exp_shifted_logits, sum_exp, {kBatchDim}), labels);
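  // (Sanity check on the gradient: with p = softmax(logits), the derivative
  // of -sum_c(labels_c * log(p_c)) with respect to logit j is
  // p_j * sum_c(labels_c) - labels_j, so `backprop` is the exact gradient
  // whenever each row of labels sums to 1.)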
  return {loss, backprop};
}

class SoftmaxXentWithLogitsOp : public XlaOpKernel {
 public:
  explicit SoftmaxXentWithLogitsOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape logits_shape = ctx->InputShape(0);
    const TensorShape labels_shape = ctx->InputShape(1);
    OP_REQUIRES(ctx, logits_shape.IsSameSize(labels_shape),
                errors::InvalidArgument(
                    "logits and labels must be same size: logits_size=",
                    logits_shape.DebugString(),
                    " labels_size=", labels_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_shape),
                errors::InvalidArgument("logits must be 2-dimensional"));
    // Since we have already checked that both inputs have the same shape,
    // there is no need to also check that "labels" is a matrix.

    const DataType type = input_type(0);
    auto logits = ctx->Input(0);
    auto labels = ctx->Input(1);

    xla::ComputationDataHandle loss, backprop;
    std::tie(loss, backprop) =
        CrossEntropyWithLogits(ctx, type, logits, labels);
    ctx->SetOutput(0, loss);
    ctx->SetOutput(1, backprop);
  }
};

REGISTER_XLA_OP(Name("SoftmaxCrossEntropyWithLogits"), SoftmaxXentWithLogitsOp);

class SparseSoftmaxXentWithLogitsOp : public XlaOpKernel {
 public:
  explicit SparseSoftmaxXentWithLogitsOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape logits_shape = ctx->InputShape(0);
    const TensorShape labels_shape = ctx->InputShape(1);
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_shape),
                errors::InvalidArgument("logits must be 2-D, but got shape ",
                                        logits_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(labels_shape),
                errors::InvalidArgument("labels must be 1-D, but got shape ",
                                        labels_shape.DebugString()));
    OP_REQUIRES(ctx, logits_shape.dim_size(0) == labels_shape.dim_size(0),
                errors::InvalidArgument(
                    "logits and labels must have the same first dimension, "
                    "got logits shape ",
                    logits_shape.DebugString(), " and labels shape ",
                    labels_shape.DebugString()));
    OP_REQUIRES(ctx, logits_shape.dim_size(1) > 0,
                errors::InvalidArgument(
                    "Must have at least one class, but got logits shape ",
                    logits_shape.DebugString()));

    int64 batch_size = logits_shape.dim_size(0);
    int64 depth = logits_shape.dim_size(1);

    DataType logits_type = input_type(0);
    DataType indices_type = input_type(1);

    xla::ComputationDataHandle indices = ctx->Input(1);

    // Expand the sparse class indices into dense one-hot rows of width
    // `depth` so they can be fed to the shared CrossEntropyWithLogits helper.
    xla::ComputationBuilder* builder = ctx->builder();
    xla::ComputationDataHandle labels;
    OP_REQUIRES_OK(ctx,
                   XlaHelpers::OneHot(
                       builder, depth, /*axis=*/1, input_type(1), labels_shape,
                       indices, XlaHelpers::One(builder, logits_type),
                       XlaHelpers::Zero(builder, logits_type), &labels));

    // If any of the indices are out of range, we must populate the labels with
    // NaNs to obey the interface contract of
    // tf.nn.sparse_softmax_cross_entropy_with_logits.
    // Build a vector of length batch_size that is 0 where the index is in
    // range and NaN otherwise; adding it to the one-hot labels forces every
    // out-of-range row to NaN.
    xla::ComputationDataHandle nan_or_zero = builder->Select(
        builder->And(
            builder->Le(XlaHelpers::Zero(builder, indices_type), indices),
            builder->Lt(indices, XlaHelpers::IntegerLiteral(
                                     builder, indices_type, depth))),
        builder->Broadcast(XlaHelpers::Zero(builder, logits_type),
                           {batch_size}),
        builder->Broadcast(XlaHelpers::FloatLiteral(builder, logits_type, NAN),
                           {batch_size}));
    labels = builder->Add(labels, nan_or_zero, {0});

    xla::ComputationDataHandle loss, backprop;
    std::tie(loss, backprop) =
        CrossEntropyWithLogits(ctx, logits_type, ctx->Input(0), labels);
    ctx->SetOutput(0, loss);
    ctx->SetOutput(1, backprop);
  }
};

REGISTER_XLA_OP(Name("SparseSoftmaxCrossEntropyWithLogits"),
                SparseSoftmaxXentWithLogitsOp);

}  // namespace
}  // namespace tensorflow