1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // See docs in ../ops/array_ops.cc 17 18 #define EIGEN_USE_THREADS 19 20 #if GOOGLE_CUDA 21 #define EIGEN_USE_GPU 22 #endif // GOOGLE_CUDA 23 24 #include "tensorflow/core/kernels/one_hot_op.h" 25 26 #include <memory> 27 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 28 #include "tensorflow/core/framework/op_kernel.h" 29 #include "tensorflow/core/framework/register_types.h" 30 #include "tensorflow/core/framework/tensor.h" 31 #include "tensorflow/core/framework/tensor_shape.h" 32 #include "tensorflow/core/framework/tensor_types.h" 33 #include "tensorflow/core/framework/types.h" 34 #include "tensorflow/core/platform/logging.h" 35 #include "tensorflow/core/platform/macros.h" 36 #include "tensorflow/core/util/overflow.h" 37 38 namespace tensorflow { 39 40 typedef Eigen::ThreadPoolDevice CPUDevice; 41 typedef Eigen::GpuDevice GPUDevice; 42 43 template <typename Device, typename T, typename TI> 44 class OneHotOp : public OpKernel { 45 public: 46 explicit OneHotOp(OpKernelConstruction* ctx) : OpKernel(ctx) { 47 OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis_)); 48 } 49 50 void Compute(OpKernelContext* ctx) override { 51 const Tensor& indices = ctx->input(0); 52 const Tensor& depth = ctx->input(1); 53 const Tensor& on_value = ctx->input(2); 54 const Tensor& off_value = ctx->input(3); 55 const TensorShape& indices_shape = indices.shape(); 56 57 const int indices_dims = indices_shape.dims(); 58 const int output_dims = indices_dims + 1; 59 60 // Preliminary validation of sizes. 61 OP_REQUIRES( 62 ctx, axis_ == -1 || (axis_ >= 0 && axis_ < output_dims), 63 errors::InvalidArgument("Expected axis to be -1 or between [0, ", 64 output_dims, "). But received: ", axis_)); 65 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(depth.shape()), 66 errors::InvalidArgument("depth must be a scalar, but got: ", 67 depth.shape().DebugString())); 68 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(on_value.shape()), 69 errors::InvalidArgument("on_value must be a scalar, but got: ", 70 on_value.shape().DebugString())); 71 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(off_value.shape()), 72 errors::InvalidArgument("off_value must be a scalar, but got: ", 73 off_value.shape().DebugString())); 74 75 const int axis = (axis_ == -1) ? indices_dims : axis_; 76 77 // The one-hot dimension. 78 const int32 depth_v = depth.scalar<int32>()(); 79 OP_REQUIRES( 80 ctx, depth_v >= 0, 81 errors::InvalidArgument("depth must be non-negative, got: ", depth_v)); 82 OP_REQUIRES( 83 ctx, 84 MultiplyWithoutOverflow(indices_shape.num_elements(), depth_v) >= 0, 85 errors::InvalidArgument("OneHot result would have shape ", 86 indices_shape.DebugString(), " + [", depth_v, 87 "], which exceeds 2**63 - 1 elements")); 88 89 TensorShape output_shape = indices_shape; 90 output_shape.InsertDim(axis, depth_v); 91 92 auto on_value_t = on_value.scalar<T>(); 93 auto off_value_t = off_value.scalar<T>(); 94 95 Tensor* output; 96 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &output)); 97 98 if (output_shape.num_elements() > 0) { 99 // prefix_dim_size == # of elements before the axis 100 // depth_v == # of elements per axis 101 // suffix_dim_size == # of elements after the axis 102 int64 prefix_dim_size = 1; 103 for (int i = 0; i < axis; ++i) { 104 prefix_dim_size *= indices_shape.dim_size(i); 105 } 106 TI suffix_dim_size = indices_shape.num_elements() / prefix_dim_size; 107 108 // Split indices into matrix of size prefix_dim_size x suffix_dim_size 109 auto indices_t = 110 indices.shaped<TI, 2>({prefix_dim_size, suffix_dim_size}); 111 // Split output into 3-Tensor of size: 112 // prefix_dim_size x depth x suffix_dim_size. 113 auto output_t = 114 output->shaped<T, 3>({prefix_dim_size, depth_v, suffix_dim_size}); 115 116 functor::OneHot<Device, T, TI>::Compute(ctx->eigen_device<Device>(), 117 indices_t, on_value_t, 118 off_value_t, &output_t); 119 } 120 } 121 122 private: 123 int32 axis_; 124 125 TF_DISALLOW_COPY_AND_ASSIGN(OneHotOp); 126 }; 127 128 #define REGISTER_ONE_HOT_INDEX(type, index_type) \ 129 REGISTER_KERNEL_BUILDER(Name("OneHot") \ 130 .Device(DEVICE_CPU) \ 131 .TypeConstraint<index_type>("TI") \ 132 .TypeConstraint<type>("T") \ 133 .HostMemory("depth"), \ 134 OneHotOp<CPUDevice, type, index_type>); 135 136 #define REGISTER_ONE_HOT(type) \ 137 REGISTER_ONE_HOT_INDEX(type, uint8); \ 138 REGISTER_ONE_HOT_INDEX(type, int32); \ 139 REGISTER_ONE_HOT_INDEX(type, int64) 140 141 TF_CALL_ALL_TYPES(REGISTER_ONE_HOT); 142 143 #if GOOGLE_CUDA 144 145 // Forward declarations of the functor specializations for GPU. 146 namespace functor { 147 #define DECLARE_GPU_SPEC_INDEX(T, TI) \ 148 template <> \ 149 void OneHot<GPUDevice, T, TI>::Compute( \ 150 const GPUDevice& d, const typename TTypes<TI>::ConstMatrix& indices, \ 151 const typename TTypes<T>::ConstScalar& on_value, \ 152 const typename TTypes<T>::ConstScalar& off_value, \ 153 typename TTypes<T, 3>::Tensor* output); \ 154 extern template struct OneHot<GPUDevice, T, TI>; 155 156 #define DECLARE_GPU_SPEC(T) \ 157 DECLARE_GPU_SPEC_INDEX(T, uint8); \ 158 DECLARE_GPU_SPEC_INDEX(T, int32); \ 159 DECLARE_GPU_SPEC_INDEX(T, int64); 160 161 TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); 162 TF_CALL_int32(DECLARE_GPU_SPEC); 163 TF_CALL_int64(DECLARE_GPU_SPEC); 164 165 #undef DECLARE_GPU_SPEC_INDEX 166 #undef DECLARE_GPU_SPEC 167 168 } // namespace functor 169 170 // Registration of the GPU implementations. 171 #define REGISTER_ONE_HOT_GPU_INDEX(type, index_type) \ 172 REGISTER_KERNEL_BUILDER(Name("OneHot") \ 173 .Device(DEVICE_GPU) \ 174 .TypeConstraint<index_type>("TI") \ 175 .TypeConstraint<type>("T") \ 176 .HostMemory("depth"), \ 177 OneHotOp<GPUDevice, type, index_type>); 178 179 #define REGISTER_ONE_HOT_GPU(type) \ 180 REGISTER_ONE_HOT_GPU_INDEX(type, uint8); \ 181 REGISTER_ONE_HOT_GPU_INDEX(type, int32); \ 182 REGISTER_ONE_HOT_GPU_INDEX(type, int64); 183 184 TF_CALL_GPU_NUMBER_TYPES(REGISTER_ONE_HOT_GPU); 185 TF_CALL_int32(REGISTER_ONE_HOT_GPU); 186 TF_CALL_int64(REGISTER_ONE_HOT_GPU); 187 188 #undef REGISTER_ONE_HOT_GPU_INDEX 189 #undef REGISTER_ONE_HOT_GPU 190 191 #endif // GOOGLE_CUDA 192 193 } // namespace tensorflow 194