1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // See docs in ../ops/nn_ops.cc. 17 18 #define EIGEN_USE_THREADS 19 20 #include "tensorflow/core/kernels/sparse_xent_op.h" 21 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 22 #include "tensorflow/core/framework/op_kernel.h" 23 #include "tensorflow/core/framework/tensor.h" 24 #include "tensorflow/core/framework/tensor_shape.h" 25 #include "tensorflow/core/framework/tensor_types.h" 26 27 namespace tensorflow { 28 29 typedef Eigen::ThreadPoolDevice CPUDevice; 30 typedef Eigen::GpuDevice GPUDevice; 31 32 template <typename Index> 33 Status CheckInvalidLabelIndex(const Tensor& labels, int64 max_index) { 34 if (labels.NumElements() == 0) return Status::OK(); 35 const auto label_values = labels.vec<Index>(); 36 int64 bad_index; 37 auto min_max_dim_value = std::minmax_element( 38 label_values.data(), label_values.data() + label_values.size()); 39 if (*min_max_dim_value.first < 0 || *min_max_dim_value.second >= max_index) { 40 bad_index = (*min_max_dim_value.first < 0) ? *min_max_dim_value.first 41 : *min_max_dim_value.second; 42 return errors::InvalidArgument( 43 "Received a label value of ", bad_index, 44 " which is outside the valid range of [0, ", max_index, 45 "). Label values: ", labels.SummarizeValue(labels.NumElements())); 46 } 47 return Status::OK(); 48 } 49 50 template <typename Device, typename T, typename Index> 51 class SparseSoftmaxXentWithLogitsOp : public OpKernel { 52 public: 53 explicit SparseSoftmaxXentWithLogitsOp(OpKernelConstruction* context) 54 : OpKernel(context) {} 55 56 void Compute(OpKernelContext* context) override { 57 const Tensor& logits = context->input(0); 58 const Tensor& labels = context->input(1); 59 OP_REQUIRES(context, TensorShapeUtils::IsMatrix(logits.shape()), 60 errors::InvalidArgument("logits must be 2-D, but got shape ", 61 logits.shape().DebugString())); 62 OP_REQUIRES(context, TensorShapeUtils::IsVector(labels.shape()), 63 errors::InvalidArgument("labels must be 1-D, but got shape ", 64 labels.shape().DebugString())); 65 OP_REQUIRES(context, logits.dim_size(0) == labels.dim_size(0), 66 errors::InvalidArgument( 67 "logits and labels must have the same first dimension, " 68 "got logits shape ", 69 logits.shape().DebugString(), " and labels shape ", 70 labels.shape().DebugString())); 71 OP_REQUIRES(context, logits.dim_size(1) > 0, 72 errors::InvalidArgument( 73 "Must have at least one class, but got logits shape ", 74 logits.shape().DebugString())); 75 76 Tensor scratch; 77 OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::value, 78 labels.shape(), &scratch)); 79 80 Tensor* loss_out = nullptr; 81 OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( 82 {1}, 0, labels.shape(), &loss_out)); 83 Tensor* back_out = nullptr; 84 OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( 85 {0}, 1, logits.shape(), &back_out)); 86 87 if (logits.dim_size(0) > 0) { 88 if (std::is_same<Device, CPUDevice>::value) { 89 OP_REQUIRES_OK( 90 context, CheckInvalidLabelIndex<Index>(labels, logits.dim_size(1))); 91 } 92 functor::SparseXentFunctor<Device, T, Index> functor; 93 functor(context->eigen_device<Device>(), logits.matrix<T>(), 94 labels.vec<Index>(), scratch.vec<T>(), loss_out->vec<T>(), 95 back_out->matrix<T>()); 96 } 97 } 98 }; 99 100 // Partial specialization for a CPUDevice, that uses the Eigen implementation 101 // from XentEigenImpl. 102 namespace functor { 103 template <typename T, typename Index> 104 struct SparseXentFunctor<CPUDevice, T, Index> { 105 void operator()(const CPUDevice& d, typename TTypes<T>::ConstMatrix logits, 106 typename TTypes<Index>::ConstVec labels, 107 typename TTypes<T>::Vec scratch, typename TTypes<T>::Vec loss, 108 typename TTypes<T>::Matrix backprop) { 109 SparseXentEigenImpl<CPUDevice, T, Index>::Compute(d, logits, labels, 110 scratch, loss, backprop); 111 } 112 }; 113 } // namespace functor 114 115 #define REGISTER(Dev, T, Index) \ 116 REGISTER_KERNEL_BUILDER( \ 117 Name("SparseSoftmaxCrossEntropyWithLogits") \ 118 .Device(DEVICE_##Dev) \ 119 .TypeConstraint<T>("T") \ 120 .TypeConstraint<Index>("Tlabels"), \ 121 SparseSoftmaxXentWithLogitsOp<Dev##Device, T, Index>); 122 REGISTER(CPU, float, int32) 123 REGISTER(CPU, float, int64) 124 REGISTER(CPU, double, int32) 125 REGISTER(CPU, double, int64) 126 REGISTER(CPU, Eigen::half, int32) 127 REGISTER(CPU, Eigen::half, int64) 128 129 #if GOOGLE_CUDA 130 REGISTER(GPU, float, int32) 131 REGISTER(GPU, float, int64) 132 REGISTER(GPU, Eigen::half, int32) 133 REGISTER(GPU, Eigen::half, int64) 134 #endif // GOOGLE_CUDA 135 136 #undef REGISTER 137 138 } // namespace tensorflow 139