Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // See docs in ../ops/math_ops.cc.
     17 
     18 #define EIGEN_USE_THREADS
     19 
     20 #if GOOGLE_CUDA
     21 #define EIGEN_USE_GPU
     22 #endif  // GOOGLE_CUDA
     23 
     24 #include "tensorflow/core/kernels/argmax_op.h"
     25 
     26 #include <memory>
     27 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     28 #include "tensorflow/core/framework/op_kernel.h"
     29 #include "tensorflow/core/framework/register_types.h"
     30 #include "tensorflow/core/framework/tensor.h"
     31 #include "tensorflow/core/framework/tensor_shape.h"
     32 #include "tensorflow/core/framework/tensor_types.h"
     33 #include "tensorflow/core/framework/types.h"
     34 #include "tensorflow/core/kernels/bounds_check.h"
     35 #include "tensorflow/core/platform/logging.h"
     36 #include "tensorflow/core/platform/macros.h"
     37 
     38 namespace tensorflow {
     39 
     40 typedef Eigen::ThreadPoolDevice CPUDevice;
     41 typedef Eigen::GpuDevice GPUDevice;
     42 
     43 template <typename Device, typename T, typename Tout, typename ArgFunctor>
     44 class ArgOp : public OpKernel {
     45  public:
     46   explicit ArgOp(OpKernelConstruction* context) : OpKernel(context) {}
     47 
     48   void Compute(OpKernelContext* context) override {
     49     const Tensor& input = context->input(0);
     50     const Tensor& dimension = context->input(1);
     51 
     52     OP_REQUIRES(context, TensorShapeUtils::IsScalar(dimension.shape()),
     53                 errors::InvalidArgument(
     54                     "dim must be a scalar, but received tensor of shape: ",
     55                     dimension.shape().DebugString()));
     56 
     57     const int32 dim = internal::SubtleMustCopy(dimension.scalar<int32>()());
     58     const int input_dims = input.dims();
     59 
     60     int axis = dim < 0 ? dim + input_dims : dim;
     61 
     62     OP_REQUIRES(context, axis >= 0 && axis < input_dims,
     63                 errors::InvalidArgument("Expected dimension in the range [",
     64                                         -input_dims, ", ", input_dims,
     65                                         "), but got ", dim));
     66     OP_REQUIRES(
     67         context, input.dim_size(axis) > 0,
     68         errors::InvalidArgument("Reduction axis ", dim, " is empty in shape ",
     69                                 input.shape().DebugString()));
     70 
     71     TensorShape output_shape;
     72     const TensorShape& input_shape = input.shape();
     73     for (int d = 0; d < input_dims - 1; ++d) {
     74       output_shape.AddDim(input_shape.dim_size((d < axis) ? d : d + 1));
     75     }
     76     Tensor* output = nullptr;
     77     OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
     78 
     79 #define HANDLE_DIM(NDIM)                                        \
     80   case NDIM:                                                    \
     81     ArgFunctor::Reduce##NDIM(context->eigen_device<Device>(),   \
     82                              input.tensor<T, NDIM>(), axis,     \
     83                              output->tensor<Tout, NDIM - 1>()); \
     84     break;
     85 
     86     switch (input_dims) {
     87       HANDLE_DIM(1);
     88       HANDLE_DIM(2);
     89       HANDLE_DIM(3);
     90       HANDLE_DIM(4);
     91       HANDLE_DIM(5);
     92 
     93       default:
     94         OP_REQUIRES(context, false,
     95                     errors::InvalidArgument(
     96                         "ArgOp : Unhandled input dimensions: ", input_dims));
     97     }
     98   }
     99 #undef HANDLE_DIM
    100 
    101  private:
    102   TF_DISALLOW_COPY_AND_ASSIGN(ArgOp);
    103 };
    104 
    105 template <typename Device, typename T, typename Tout>
    106 class ArgMaxOp
    107     : public ArgOp<Device, T, Tout, functor::ArgMax<Device, T, Tout> > {
    108  public:
    109   explicit ArgMaxOp(OpKernelConstruction* context)
    110       : ArgOp<Device, T, Tout, functor::ArgMax<Device, T, Tout> >(context) {}
    111 };
    112 
    113 template <typename Device, typename T, typename Tout>
    114 class ArgMinOp
    115     : public ArgOp<Device, T, Tout, functor::ArgMin<Device, T, Tout> > {
    116  public:
    117   explicit ArgMinOp(OpKernelConstruction* context)
    118       : ArgOp<Device, T, Tout, functor::ArgMin<Device, T, Tout> >(context) {}
    119 };
    120 
    121 #define REGISTER_ARGMAX(type)                                       \
    122   REGISTER_KERNEL_BUILDER(Name("ArgMax")                            \
    123                               .Device(DEVICE_CPU)                   \
    124                               .TypeConstraint<type>("T")            \
    125                               .TypeConstraint<int64>("output_type") \
    126                               .HostMemory("dimension"),             \
    127                           ArgMaxOp<CPUDevice, type, int64>);        \
    128   REGISTER_KERNEL_BUILDER(Name("ArgMin")                            \
    129                               .Device(DEVICE_CPU)                   \
    130                               .TypeConstraint<type>("T")            \
    131                               .TypeConstraint<int64>("output_type") \
    132                               .HostMemory("dimension"),             \
    133                           ArgMinOp<CPUDevice, type, int64>);        \
    134   REGISTER_KERNEL_BUILDER(Name("ArgMax")                            \
    135                               .Device(DEVICE_CPU)                   \
    136                               .TypeConstraint<type>("T")            \
    137                               .TypeConstraint<int32>("output_type") \
    138                               .HostMemory("dimension"),             \
    139                           ArgMaxOp<CPUDevice, type, int32>);        \
    140   REGISTER_KERNEL_BUILDER(Name("ArgMin")                            \
    141                               .Device(DEVICE_CPU)                   \
    142                               .TypeConstraint<type>("T")            \
    143                               .TypeConstraint<int32>("output_type") \
    144                               .HostMemory("dimension"),             \
    145                           ArgMinOp<CPUDevice, type, int32>);
    146 
    147 TF_CALL_REAL_NUMBER_TYPES(REGISTER_ARGMAX);
    148 
    149 #if GOOGLE_CUDA
    150 
    151 // Forward declarations of the functor specializations for GPU.
    152 namespace functor {
    153 
    154 #define DECLARE_GPU_SPEC(T, Tout, Dims)                                       \
    155   template <>                                                                 \
    156   void ArgMax<GPUDevice, T, Tout>::Reduce##Dims(                              \
    157       const GPUDevice& d, typename TTypes<T, Dims>::ConstTensor input,        \
    158       const int32 dimension, typename TTypes<Tout, Dims - 1>::Tensor output); \
    159   template <>                                                                 \
    160   void ArgMin<GPUDevice, T, Tout>::Reduce##Dims(                              \
    161       const GPUDevice& d, typename TTypes<T, Dims>::ConstTensor input,        \
    162       const int32 dimension, typename TTypes<Tout, Dims - 1>::Tensor output);
    163 
    164 #define DECLARE_GPU_SPECS(T)     \
    165   DECLARE_GPU_SPEC(T, int64, 1); \
    166   DECLARE_GPU_SPEC(T, int64, 2); \
    167   DECLARE_GPU_SPEC(T, int64, 3); \
    168   DECLARE_GPU_SPEC(T, int64, 4); \
    169   DECLARE_GPU_SPEC(T, int64, 5); \
    170   DECLARE_GPU_SPEC(T, int32, 1); \
    171   DECLARE_GPU_SPEC(T, int32, 2); \
    172   DECLARE_GPU_SPEC(T, int32, 3); \
    173   DECLARE_GPU_SPEC(T, int32, 4); \
    174   DECLARE_GPU_SPEC(T, int32, 5);
    175 
    176 #define DECLARE_GPU_CLASS(T)                          \
    177   extern template struct ArgMax<GPUDevice, T, int64>; \
    178   extern template struct ArgMin<GPUDevice, T, int64>; \
    179   extern template struct ArgMax<GPUDevice, T, int32>; \
    180   extern template struct ArgMin<GPUDevice, T, int32>;
    181 
    182 TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
    183 TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_CLASS);
    184 
    185 #undef DECLARE_GPU_SPECS
    186 #undef DECLARE_GPU_CLASS
    187 
    188 }  // namespace functor
    189 
    190 // Registration of the GPU implementations.
    191 #define REGISTER_ARGMAX_GPU(type)                                   \
    192   REGISTER_KERNEL_BUILDER(Name("ArgMax")                            \
    193                               .Device(DEVICE_GPU)                   \
    194                               .TypeConstraint<type>("T")            \
    195                               .TypeConstraint<int64>("output_type") \
    196                               .TypeConstraint<int32>("Tidx")        \
    197                               .HostMemory("dimension"),             \
    198                           ArgMaxOp<GPUDevice, type, int64>);        \
    199   REGISTER_KERNEL_BUILDER(Name("ArgMin")                            \
    200                               .Device(DEVICE_GPU)                   \
    201                               .TypeConstraint<type>("T")            \
    202                               .TypeConstraint<int64>("output_type") \
    203                               .TypeConstraint<int32>("Tidx")        \
    204                               .HostMemory("dimension"),             \
    205                           ArgMinOp<GPUDevice, type, int64>);        \
    206   REGISTER_KERNEL_BUILDER(Name("ArgMax")                            \
    207                               .Device(DEVICE_GPU)                   \
    208                               .TypeConstraint<type>("T")            \
    209                               .TypeConstraint<int32>("output_type") \
    210                               .TypeConstraint<int32>("Tidx")        \
    211                               .HostMemory("dimension"),             \
    212                           ArgMaxOp<GPUDevice, type, int32>);        \
    213   REGISTER_KERNEL_BUILDER(Name("ArgMin")                            \
    214                               .Device(DEVICE_GPU)                   \
    215                               .TypeConstraint<type>("T")            \
    216                               .TypeConstraint<int32>("output_type") \
    217                               .TypeConstraint<int32>("Tidx")        \
    218                               .HostMemory("dimension"),             \
    219                           ArgMinOp<GPUDevice, type, int32>);
    220 
    221 TF_CALL_GPU_NUMBER_TYPES(REGISTER_ARGMAX_GPU);
    222 
    223 #undef REGISTER_ARGMAX_GPU
    224 
    225 #endif  // GOOGLE_CUDA
    226 
    227 }  // namespace tensorflow
    228