Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // See docs in ../ops/array_ops.cc
     17 
     18 #define EIGEN_USE_THREADS
     19 
     20 #if GOOGLE_CUDA
     21 #define EIGEN_USE_GPU
     22 #endif  // GOOGLE_CUDA
     23 
     24 #include "tensorflow/core/kernels/one_hot_op.h"
     25 
     26 #include <memory>
     27 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     28 #include "tensorflow/core/framework/op_kernel.h"
     29 #include "tensorflow/core/framework/register_types.h"
     30 #include "tensorflow/core/framework/tensor.h"
     31 #include "tensorflow/core/framework/tensor_shape.h"
     32 #include "tensorflow/core/framework/tensor_types.h"
     33 #include "tensorflow/core/framework/types.h"
     34 #include "tensorflow/core/platform/logging.h"
     35 #include "tensorflow/core/platform/macros.h"
     36 #include "tensorflow/core/util/overflow.h"
     37 
     38 namespace tensorflow {
     39 
     40 typedef Eigen::ThreadPoolDevice CPUDevice;
     41 typedef Eigen::GpuDevice GPUDevice;
     42 
     43 template <typename Device, typename T, typename TI>
     44 class OneHotOp : public OpKernel {
     45  public:
     46   explicit OneHotOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
     47     OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis_));
     48   }
     49 
     50   void Compute(OpKernelContext* ctx) override {
     51     const Tensor& indices = ctx->input(0);
     52     const Tensor& depth = ctx->input(1);
     53     const Tensor& on_value = ctx->input(2);
     54     const Tensor& off_value = ctx->input(3);
     55     const TensorShape& indices_shape = indices.shape();
     56 
     57     const int indices_dims = indices_shape.dims();
     58     const int output_dims = indices_dims + 1;
     59 
     60     // Preliminary validation of sizes.
     61     OP_REQUIRES(
     62         ctx, axis_ == -1 || (axis_ >= 0 && axis_ < output_dims),
     63         errors::InvalidArgument("Expected axis to be -1 or between [0, ",
     64                                 output_dims, ").  But received: ", axis_));
     65     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(depth.shape()),
     66                 errors::InvalidArgument("depth must be a scalar, but got: ",
     67                                         depth.shape().DebugString()));
     68     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(on_value.shape()),
     69                 errors::InvalidArgument("on_value must be a scalar, but got: ",
     70                                         on_value.shape().DebugString()));
     71     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(off_value.shape()),
     72                 errors::InvalidArgument("off_value must be a scalar, but got: ",
     73                                         off_value.shape().DebugString()));
     74 
     75     const int axis = (axis_ == -1) ? indices_dims : axis_;
     76 
     77     // The one-hot dimension.
     78     const int32 depth_v = depth.scalar<int32>()();
     79     OP_REQUIRES(
     80         ctx, depth_v >= 0,
     81         errors::InvalidArgument("depth must be non-negative, got: ", depth_v));
     82     OP_REQUIRES(
     83         ctx,
     84         MultiplyWithoutOverflow(indices_shape.num_elements(), depth_v) >= 0,
     85         errors::InvalidArgument("OneHot result would have shape ",
     86                                 indices_shape.DebugString(), " + [", depth_v,
     87                                 "], which exceeds 2**63 - 1 elements"));
     88 
     89     TensorShape output_shape = indices_shape;
     90     output_shape.InsertDim(axis, depth_v);
     91 
     92     auto on_value_t = on_value.scalar<T>();
     93     auto off_value_t = off_value.scalar<T>();
     94 
     95     Tensor* output;
     96     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &output));
     97 
     98     if (output_shape.num_elements() > 0) {
     99       // prefix_dim_size == # of elements before the axis
    100       // depth_v == # of elements per axis
    101       // suffix_dim_size == # of elements after the axis
    102       int64 prefix_dim_size = 1;
    103       for (int i = 0; i < axis; ++i) {
    104         prefix_dim_size *= indices_shape.dim_size(i);
    105       }
    106       TI suffix_dim_size = indices_shape.num_elements() / prefix_dim_size;
    107 
    108       // Split indices into matrix of size prefix_dim_size x suffix_dim_size
    109       auto indices_t =
    110           indices.shaped<TI, 2>({prefix_dim_size, suffix_dim_size});
    111       // Split output into 3-Tensor of size:
    112       //   prefix_dim_size x depth x suffix_dim_size.
    113       auto output_t =
    114           output->shaped<T, 3>({prefix_dim_size, depth_v, suffix_dim_size});
    115 
    116       functor::OneHot<Device, T, TI>::Compute(ctx->eigen_device<Device>(),
    117                                               indices_t, on_value_t,
    118                                               off_value_t, &output_t);
    119     }
    120   }
    121 
    122  private:
    123   int32 axis_;
    124 
    125   TF_DISALLOW_COPY_AND_ASSIGN(OneHotOp);
    126 };
    127 
    128 #define REGISTER_ONE_HOT_INDEX(type, index_type)                \
    129   REGISTER_KERNEL_BUILDER(Name("OneHot")                        \
    130                               .Device(DEVICE_CPU)               \
    131                               .TypeConstraint<index_type>("TI") \
    132                               .TypeConstraint<type>("T")        \
    133                               .HostMemory("depth"),             \
    134                           OneHotOp<CPUDevice, type, index_type>);
    135 
    136 #define REGISTER_ONE_HOT(type)         \
    137   REGISTER_ONE_HOT_INDEX(type, uint8); \
    138   REGISTER_ONE_HOT_INDEX(type, int32); \
    139   REGISTER_ONE_HOT_INDEX(type, int64)
    140 
    141 TF_CALL_ALL_TYPES(REGISTER_ONE_HOT);
    142 
    143 #if GOOGLE_CUDA
    144 
    145 // Forward declarations of the functor specializations for GPU.
    146 namespace functor {
    147 #define DECLARE_GPU_SPEC_INDEX(T, TI)                                      \
    148   template <>                                                              \
    149   void OneHot<GPUDevice, T, TI>::Compute(                                  \
    150       const GPUDevice& d, const typename TTypes<TI>::ConstMatrix& indices, \
    151       const typename TTypes<T>::ConstScalar& on_value,                     \
    152       const typename TTypes<T>::ConstScalar& off_value,                    \
    153       typename TTypes<T, 3>::Tensor* output);                              \
    154   extern template struct OneHot<GPUDevice, T, TI>;
    155 
    156 #define DECLARE_GPU_SPEC(T)         \
    157   DECLARE_GPU_SPEC_INDEX(T, uint8); \
    158   DECLARE_GPU_SPEC_INDEX(T, int32); \
    159   DECLARE_GPU_SPEC_INDEX(T, int64);
    160 
    161 TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
    162 TF_CALL_int32(DECLARE_GPU_SPEC);
    163 TF_CALL_int64(DECLARE_GPU_SPEC);
    164 
    165 #undef DECLARE_GPU_SPEC_INDEX
    166 #undef DECLARE_GPU_SPEC
    167 
    168 }  // namespace functor
    169 
    170 // Registration of the GPU implementations.
    171 #define REGISTER_ONE_HOT_GPU_INDEX(type, index_type)            \
    172   REGISTER_KERNEL_BUILDER(Name("OneHot")                        \
    173                               .Device(DEVICE_GPU)               \
    174                               .TypeConstraint<index_type>("TI") \
    175                               .TypeConstraint<type>("T")        \
    176                               .HostMemory("depth"),             \
    177                           OneHotOp<GPUDevice, type, index_type>);
    178 
    179 #define REGISTER_ONE_HOT_GPU(type)         \
    180   REGISTER_ONE_HOT_GPU_INDEX(type, uint8); \
    181   REGISTER_ONE_HOT_GPU_INDEX(type, int32); \
    182   REGISTER_ONE_HOT_GPU_INDEX(type, int64);
    183 
    184 TF_CALL_GPU_NUMBER_TYPES(REGISTER_ONE_HOT_GPU);
    185 TF_CALL_int32(REGISTER_ONE_HOT_GPU);
    186 TF_CALL_int64(REGISTER_ONE_HOT_GPU);
    187 
    188 #undef REGISTER_ONE_HOT_GPU_INDEX
    189 #undef REGISTER_ONE_HOT_GPU
    190 
    191 #endif  // GOOGLE_CUDA
    192 
    193 }  // namespace tensorflow
    194