/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "tensorflow/core/kernels/argmax_op.h"

#include <memory>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename T, typename Tout, typename ArgFunctor>
class ArgOp : public OpKernel {
 public:
  explicit ArgOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& dimension = context->input(1);

    OP_REQUIRES(context, TensorShapeUtils::IsScalar(dimension.shape()),
                errors::InvalidArgument(
                    "dim must be a scalar, but received tensor of shape: ",
                    dimension.shape().DebugString()));

    const int32 dim = internal::SubtleMustCopy(dimension.scalar<int32>()());
    const int input_dims = input.dims();

    int axis = dim < 0 ? dim + input_dims : dim;

    OP_REQUIRES(context, axis >= 0 && axis < input_dims,
                errors::InvalidArgument("Expected dimension in the range [",
                                        -input_dims, ", ", input_dims,
                                        "), but got ", dim));
    OP_REQUIRES(
        context, input.dim_size(axis) > 0,
        errors::InvalidArgument("Reduction axis ", dim, " is empty in shape ",
                                input.shape().DebugString()));

    TensorShape output_shape;
    const TensorShape& input_shape = input.shape();
    for (int d = 0; d < input_dims - 1; ++d) {
      output_shape.AddDim(input_shape.dim_size((d < axis) ? d : d + 1));
    }
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, output_shape, &output));
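    // The rank-based dispatch below hands the functor a statically ranked
    // Eigen tensor: a rank-N input reduced along `axis` yields a rank-(N-1)
    // output of indices (e.g. argmax over axis 1 of a [2, 3, 4] input
    // produces a [2, 4] output).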
#define HANDLE_DIM(NDIM)                                        \
  case NDIM:                                                    \
    ArgFunctor::Reduce##NDIM(context->eigen_device<Device>(),   \
                             input.tensor<T, NDIM>(), axis,     \
                             output->tensor<Tout, NDIM - 1>()); \
    break;

    switch (input_dims) {
      HANDLE_DIM(1);
      HANDLE_DIM(2);
      HANDLE_DIM(3);
      HANDLE_DIM(4);
      HANDLE_DIM(5);

      default:
        OP_REQUIRES(context, false,
                    errors::InvalidArgument(
                        "ArgOp: Unhandled input dimensions: ", input_dims));
    }
  }
#undef HANDLE_DIM

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(ArgOp);
};

template <typename Device, typename T, typename Tout>
class ArgMaxOp
    : public ArgOp<Device, T, Tout, functor::ArgMax<Device, T, Tout> > {
 public:
  explicit ArgMaxOp(OpKernelConstruction* context)
      : ArgOp<Device, T, Tout, functor::ArgMax<Device, T, Tout> >(context) {}
};

template <typename Device, typename T, typename Tout>
class ArgMinOp
    : public ArgOp<Device, T, Tout, functor::ArgMin<Device, T, Tout> > {
 public:
  explicit ArgMinOp(OpKernelConstruction* context)
      : ArgOp<Device, T, Tout, functor::ArgMin<Device, T, Tout> >(context) {}
};

#define REGISTER_ARGMAX(type)                                       \
  REGISTER_KERNEL_BUILDER(Name("ArgMax")                            \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<type>("T")            \
                              .TypeConstraint<int64>("output_type") \
                              .HostMemory("dimension"),             \
                          ArgMaxOp<CPUDevice, type, int64>);        \
  REGISTER_KERNEL_BUILDER(Name("ArgMin")                            \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<type>("T")            \
                              .TypeConstraint<int64>("output_type") \
                              .HostMemory("dimension"),             \
                          ArgMinOp<CPUDevice, type, int64>);        \
  REGISTER_KERNEL_BUILDER(Name("ArgMax")                            \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<type>("T")            \
                              .TypeConstraint<int32>("output_type") \
                              .HostMemory("dimension"),             \
                          ArgMaxOp<CPUDevice, type, int32>);        \
  REGISTER_KERNEL_BUILDER(Name("ArgMin")                            \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<type>("T")            \
                              .TypeConstraint<int32>("output_type") \
                              .HostMemory("dimension"),             \
                          ArgMinOp<CPUDevice, type, int32>);

TF_CALL_REAL_NUMBER_TYPES(REGISTER_ARGMAX);
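// Note: TF_CALL_REAL_NUMBER_TYPES expands REGISTER_ARGMAX once per real
// (non-complex) numeric dtype; the exact type list lives in register_types.h,
// included above. Each such type therefore gets CPU kernels for both the
// int32 and int64 output_type variants.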
#if GOOGLE_CUDA

// Forward declarations of the functor specializations for GPU.
namespace functor {

#define DECLARE_GPU_SPEC(T, Tout, Dims)                                       \
  template <>                                                                 \
  void ArgMax<GPUDevice, T, Tout>::Reduce##Dims(                              \
      const GPUDevice& d, typename TTypes<T, Dims>::ConstTensor input,        \
      const int32 dimension, typename TTypes<Tout, Dims - 1>::Tensor output); \
  template <>                                                                 \
  void ArgMin<GPUDevice, T, Tout>::Reduce##Dims(                              \
      const GPUDevice& d, typename TTypes<T, Dims>::ConstTensor input,        \
      const int32 dimension, typename TTypes<Tout, Dims - 1>::Tensor output);

#define DECLARE_GPU_SPECS(T)     \
  DECLARE_GPU_SPEC(T, int64, 1); \
  DECLARE_GPU_SPEC(T, int64, 2); \
  DECLARE_GPU_SPEC(T, int64, 3); \
  DECLARE_GPU_SPEC(T, int64, 4); \
  DECLARE_GPU_SPEC(T, int64, 5); \
  DECLARE_GPU_SPEC(T, int32, 1); \
  DECLARE_GPU_SPEC(T, int32, 2); \
  DECLARE_GPU_SPEC(T, int32, 3); \
  DECLARE_GPU_SPEC(T, int32, 4); \
  DECLARE_GPU_SPEC(T, int32, 5);

#define DECLARE_GPU_CLASS(T)                            \
  extern template struct ArgMax<GPUDevice, T, int64>;   \
  extern template struct ArgMin<GPUDevice, T, int64>;   \
  extern template struct ArgMax<GPUDevice, T, int32>;   \
  extern template struct ArgMin<GPUDevice, T, int32>;

TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_CLASS);

#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_CLASS

}  // namespace functor

// Registration of the GPU implementations.
#define REGISTER_ARGMAX_GPU(type)                                    \
  REGISTER_KERNEL_BUILDER(Name("ArgMax")                             \
                              .Device(DEVICE_GPU)                    \
                              .TypeConstraint<type>("T")             \
                              .TypeConstraint<int64>("output_type")  \
                              .TypeConstraint<int32>("Tidx")         \
                              .HostMemory("dimension"),              \
                          ArgMaxOp<GPUDevice, type, int64>);         \
  REGISTER_KERNEL_BUILDER(Name("ArgMin")                             \
                              .Device(DEVICE_GPU)                    \
                              .TypeConstraint<type>("T")             \
                              .TypeConstraint<int64>("output_type")  \
                              .TypeConstraint<int32>("Tidx")         \
                              .HostMemory("dimension"),              \
                          ArgMinOp<GPUDevice, type, int64>);         \
  REGISTER_KERNEL_BUILDER(Name("ArgMax")                             \
                              .Device(DEVICE_GPU)                    \
                              .TypeConstraint<type>("T")             \
                              .TypeConstraint<int32>("output_type")  \
                              .TypeConstraint<int32>("Tidx")         \
                              .HostMemory("dimension"),              \
                          ArgMaxOp<GPUDevice, type, int32>);         \
  REGISTER_KERNEL_BUILDER(Name("ArgMin")                             \
                              .Device(DEVICE_GPU)                    \
                              .TypeConstraint<type>("T")             \
                              .TypeConstraint<int32>("output_type")  \
                              .TypeConstraint<int32>("Tidx")         \
                              .HostMemory("dimension"),              \
                          ArgMinOp<GPUDevice, type, int32>);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_ARGMAX_GPU);

#undef REGISTER_ARGMAX_GPU

#endif  // GOOGLE_CUDA

}  // namespace tensorflow