1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // See docs in ../ops/nn_ops.cc. 17 18 #define EIGEN_USE_THREADS 19 20 #include "tensorflow/core/kernels/xent_op.h" 21 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 22 #include "tensorflow/core/framework/op_kernel.h" 23 #include "tensorflow/core/framework/register_types.h" 24 #include "tensorflow/core/framework/tensor.h" 25 #include "tensorflow/core/framework/tensor_shape.h" 26 27 namespace tensorflow { 28 29 typedef Eigen::ThreadPoolDevice CPUDevice; 30 typedef Eigen::GpuDevice GPUDevice; 31 #ifdef TENSORFLOW_USE_SYCL 32 typedef Eigen::SyclDevice SYCLDevice; 33 #endif // TENSORFLOW_USE_SYCL 34 35 template <typename Device, typename T> 36 class SoftmaxXentWithLogitsOp : public OpKernel { 37 public: 38 explicit SoftmaxXentWithLogitsOp(OpKernelConstruction* context) 39 : OpKernel(context) {} 40 41 void Compute(OpKernelContext* context) override { 42 const Tensor& logits_in = context->input(0); 43 const Tensor& labels_in = context->input(1); 44 OP_REQUIRES(context, logits_in.IsSameSize(labels_in), 45 errors::InvalidArgument( 46 "logits and labels must be same size: logits_size=", 47 logits_in.shape().DebugString(), 48 " labels_size=", labels_in.shape().DebugString())); 49 OP_REQUIRES(context, TensorShapeUtils::IsMatrix(logits_in.shape()), 50 errors::InvalidArgument("logits must be 2-dimensional")); 51 // As we already tested that both inputs have the same shape no need to 52 // check that "labels" is a matrix too. 53 54 // loss is 1-D (one per example), and size is batch_size. 55 56 Tensor scratch; 57 OP_REQUIRES_OK( 58 context, context->allocate_temp(DataTypeToEnum<T>::value, 59 TensorShape({logits_in.dim_size(0), 1}), 60 &scratch)); 61 62 Tensor* loss_out = nullptr; 63 OP_REQUIRES_OK(context, 64 context->allocate_output( 65 0, TensorShape({logits_in.dim_size(0)}), &loss_out)); 66 Tensor* back_out = nullptr; 67 // Try to reuse the logits_in buffer for the backprop output. 68 OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( 69 {0}, 1, logits_in.shape(), &back_out)); 70 if (logits_in.dim_size(0) > 0) { 71 functor::XentFunctor<Device, T> functor; 72 functor(context->eigen_device<Device>(), logits_in.matrix<T>(), 73 labels_in.matrix<T>(), scratch.matrix<T>(), loss_out->vec<T>(), 74 back_out->matrix<T>()); 75 } 76 } 77 }; 78 79 // Partial specialization for a CPUDevice, that uses the Eigen implementation 80 // from XentEigenImpl. 81 namespace functor { 82 template <typename Device, typename T> 83 struct XentFunctorBase { 84 void operator()(const Device& d, typename TTypes<T>::ConstMatrix logits, 85 typename TTypes<T>::ConstMatrix labels, 86 typename TTypes<T>::Matrix scratch, 87 typename TTypes<T>::Vec loss, 88 typename TTypes<T>::Matrix backprop) { 89 XentEigenImpl<Device, T>::Compute(d, logits, labels, scratch, loss, 90 backprop); 91 } 92 }; 93 94 template <typename T> 95 struct XentFunctor<CPUDevice, T> : XentFunctorBase<CPUDevice, T> {}; 96 97 #ifdef TENSORFLOW_USE_SYCL 98 template <typename T> 99 struct XentFunctor<SYCLDevice, T> : XentFunctorBase<SYCLDevice, T> {}; 100 #endif // TENSORFLOW_USE_SYCL 101 } // namespace functor 102 103 #define REGISTER_CPU(T) \ 104 REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits") \ 105 .Device(DEVICE_CPU) \ 106 .TypeConstraint<T>("T"), \ 107 SoftmaxXentWithLogitsOp<CPUDevice, T>); 108 TF_CALL_half(REGISTER_CPU); 109 TF_CALL_float(REGISTER_CPU); 110 TF_CALL_double(REGISTER_CPU); 111 112 #if GOOGLE_CUDA 113 REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits") 114 .Device(DEVICE_GPU) 115 .TypeConstraint<Eigen::half>("T"), 116 SoftmaxXentWithLogitsOp<GPUDevice, Eigen::half>); 117 REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits") 118 .Device(DEVICE_GPU) 119 .TypeConstraint<float>("T"), 120 SoftmaxXentWithLogitsOp<GPUDevice, float>); 121 REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits") 122 .Device(DEVICE_GPU) 123 .TypeConstraint<double>("T"), 124 SoftmaxXentWithLogitsOp<GPUDevice, double>); 125 #endif // GOOGLE_CUDA 126 127 #ifdef TENSORFLOW_USE_SYCL 128 REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits") 129 .Device(DEVICE_SYCL) 130 .TypeConstraint<float>("T"), 131 SoftmaxXentWithLogitsOp<SYCLDevice, float>); 132 #endif // TENSORFLOW_USE_SYCL 133 134 } // namespace tensorflow 135