/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "tensorflow/core/kernels/where_op.h"

#include <memory>
#include <numeric>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

#if GOOGLE_CUDA
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/platform/cuda.h"

using ::perftools::gputools::cuda::ScopedActivateExecutorContext;
#endif  // GOOGLE_CUDA

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace functor {

namespace {
template <typename T>
int64 CountAccumulator(const T* begin, const T* end) {
  return std::accumulate(begin, end, 0LL, [](int64 accum, const T& val) {
    return accum + (val != T(0));
  });
}

template <>
int64 CountAccumulator<bool>(const bool* begin, const bool* end) {
  return std::accumulate(begin, end, 0LL);
}

}  // namespace

template <typename T>
struct NumTrue<CPUDevice, T, int64> {
  static Status Compute(OpKernelContext* ctx, const CPUDevice& d,
                        typename TTypes<T>::ConstFlat input,
                        TTypes<int64>::Scalar num_true) {
    num_true() = CountAccumulator<T>(input.data(), input.data() + input.size());
    return Status::OK();
  }
};

template <int DIMS, typename T, typename TIndex>
struct Where<CPUDevice, DIMS, T, TIndex> {
  EIGEN_ALWAYS_INLINE static void WriteIndexRowMajor(
      typename TTypes<int64>::Matrix output,
      const typename Eigen::DSizes<TIndex, DIMS>& strides, TIndex true_n,
      TIndex index) {
    for (int i = 0; i < DIMS; ++i) {
      output(true_n, i) = index / strides[i];
      index -= output(true_n, i) * strides[i];
    }
  }

  EIGEN_ALWAYS_INLINE static Status Compute(
      OpKernelContext* ctx, const CPUDevice& d,
      typename TTypes<T, DIMS>::ConstTensor input,
      typename TTypes<int64>::Matrix output, TIndex* found_true) {
    Eigen::DSizes<Eigen::DenseIndex, DIMS> dims = input.dimensions();
    Eigen::DSizes<TIndex, DIMS> strides;

    EIGEN_STATIC_ASSERT((static_cast<int>(decltype(input)::Layout) ==
                         static_cast<int>(Eigen::RowMajor)),
                        INTERNAL_ERROR_INPUT_SHOULD_BE_ROWMAJOR);

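    // Compute row-major strides: strides[i] is the number of flat elements
    // spanned by a unit step along dimension i, so the innermost stride is 1.
    // E.g. for dims {2, 3, 4} the strides are {12, 4, 1}, and
    // WriteIndexRowMajor unflattens flat index 17 as
    // (17 / 12, 5 / 4, 1 / 1) == (1, 1, 1).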
    strides[DIMS - 1] = 1;
    for (int i = DIMS - 2; i >= 0; --i) {
      strides[i] = strides[i + 1] * dims[i + 1];
    }

    Eigen::DenseIndex output_size = output.dimension(0);
    for (Eigen::DenseIndex n = 0; n < input.size(); ++n) {
      if (input.data()[n] != T(0)) {
        if (FastBoundsCheck(*found_true, output_size)) {
          WriteIndexRowMajor(output, strides, *found_true, n);
        }
        ++*found_true;
      }
    }
    return Status::OK();
  }
};

}  // namespace functor

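// Emits one output row per nonzero/true element, in row-major scan order.
// E.g. for the 2-D input
//   [[true, false],
//    [false, true]]
// the output is the num_true x input_dims coordinate matrix
//   [[0, 0],
//    [1, 1]].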
template <typename T>
class WhereCPUOp : public OpKernel {
 public:
  explicit WhereCPUOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);

    OP_REQUIRES(
        context, input.dtype() != DT_HALF,
        errors::Unimplemented("No WhereOp available for float16/half type on "
                              "CPU; dying in CPU WhereOp to avoid silently "
                              "creating costly copies from device."));

    const int input_dims = input.dims();

    Tensor num_true;
    OP_REQUIRES_OK(
        context, context->allocate_temp(DT_INT64, TensorShape({}), &num_true));
    auto num_true_t = num_true.scalar<int64>();

    Status s = functor::NumTrue<CPUDevice, T, int64>::Compute(
        context, context->eigen_device<CPUDevice>(), input.flat<T>(),
        num_true_t);
    OP_REQUIRES_OK(context, s);
    TensorShape output_shape({num_true_t(), input_dims});
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));

    // TODO(ebrevdo): Replace single-threaded copy with a
    // multithreaded block copy by getting block counts above instead
    // of a global NumTrue, then having each block filled in in
    // separate threads below.
    int64 found_true = 0;

#define HANDLE_DIM(NDIM)                                                      \
  case NDIM: {                                                                \
    Status s = functor::Where<CPUDevice, NDIM, T, int64>::Compute(            \
        context, context->eigen_device<CPUDevice>(), input.tensor<T, NDIM>(), \
        output->matrix<int64>(), &found_true);                                \
    OP_REQUIRES_OK(context, s);                                               \
  } break;

    switch (input_dims) {
      HANDLE_DIM(1);
      HANDLE_DIM(2);
      HANDLE_DIM(3);
      HANDLE_DIM(4);
      HANDLE_DIM(5);

      default:
        OP_REQUIRES(context, false,
                    errors::InvalidArgument(
                        "WhereOp: Unhandled input dimensions: ", input_dims));
    }
#undef HANDLE_DIM

    OP_REQUIRES(
        context, found_true == num_true_t(),
        errors::InvalidArgument(
            "WhereOp: Race condition between counting the number of true "
            "elements and writing them. When counting, saw ",
            num_true_t(), " elements; but when writing their indices, saw ",
            found_true, " elements."));
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(WhereCPUOp);
};

#define REGISTER_WHERE_OP(T) \
  REGISTER_KERNEL_BUILDER(   \
      Name("Where").Device(DEVICE_CPU).TypeConstraint<T>("T"), WhereCPUOp<T>);

TF_CALL_NUMBER_TYPES(REGISTER_WHERE_OP);
TF_CALL_bool(REGISTER_WHERE_OP);

#undef REGISTER_WHERE_OP

#if GOOGLE_CUDA

namespace functor {

#define DECLARE_GPU_NUMTRUE(T, Tindex)                                      \
  template <>                                                               \
  Status NumTrue<GPUDevice, T, Tindex>::Compute(                            \
      OpKernelContext* ctx, const GPUDevice& d, TTypes<T>::ConstFlat input, \
      TTypes<Tindex>::Scalar num_true);                                     \
  extern template struct NumTrue<GPUDevice, T, Tindex>

#define DECLARE_GPU_NUMTRUE_TYPE(T) \
  DECLARE_GPU_NUMTRUE(T, int32);    \
  DECLARE_GPU_NUMTRUE(T, int64);

TF_CALL_NUMBER_TYPES(DECLARE_GPU_NUMTRUE_TYPE);
TF_CALL_bool(DECLARE_GPU_NUMTRUE_TYPE);

#undef DECLARE_GPU_NUMTRUE_TYPE
#undef DECLARE_GPU_NUMTRUE

#define DECLARE_GPU_WHERE_INDEX(Dims, T, Tindex)                  \
  template <>                                                     \
  Status Where<GPUDevice, Dims, T, Tindex>::Compute(              \
      OpKernelContext* ctx, const GPUDevice& d,                   \
      typename TTypes<T, Dims>::ConstTensor input,                \
      typename TTypes<int64>::Matrix output, Tindex* found_true); \
  extern template struct Where<GPUDevice, Dims, T, Tindex>;

#define DECLARE_GPU_WHERE(Dims, T)         \
  DECLARE_GPU_WHERE_INDEX(Dims, T, int32); \
  DECLARE_GPU_WHERE_INDEX(Dims, T, int64);

#define DECLARE_GPU_WHERE_TYPES(T) \
  DECLARE_GPU_WHERE(1, T);         \
  DECLARE_GPU_WHERE(2, T);         \
  DECLARE_GPU_WHERE(3, T);         \
  DECLARE_GPU_WHERE(4, T);         \
  DECLARE_GPU_WHERE(5, T);

TF_CALL_WHERE_GPU_TYPES(DECLARE_GPU_WHERE_TYPES);

#undef DECLARE_GPU_WHERE_TYPES
#undef DECLARE_GPU_WHERE
#undef DECLARE_GPU_WHERE_INDEX

}  // namespace functor

template <typename T>
class WhereGPUOp : public AsyncOpKernel {
 public:
  explicit WhereGPUOp(OpKernelConstruction* context) : AsyncOpKernel(context) {}

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    const Tensor& input = context->input(0);
    const int input_dims = input.dims();

    // Dispatch on index width: 32-bit indexing is cheaper on the GPU, so use
    // it whenever the flattened input fits in an int32.
    if (input.NumElements() < std::numeric_limits<int32>::max()) {
      ComputeAsyncType<int32>(input, input_dims, context, done);
    } else {
      ComputeAsyncType<int64>(input, input_dims, context, done);
    }
  }

  template <typename Tindex>
  void ComputeAsyncType(const Tensor& input, const int input_dims,
                        OpKernelContext* context, DoneCallback done) {
    // Step 0: alloc nnz
    // Step 1: call nnz kernel
    // Step 2: copy nnz to host
    // Step 3: call create_output
    // Step 4: call where kernel
    Tensor num_true;
    OP_REQUIRES_OK_ASYNC(context,
                         context->allocate_temp(DataTypeToEnum<Tindex>::v(),
                                                TensorShape({}), &num_true),
                         done);

    auto num_true_t = num_true.scalar<Tindex>();

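    // num_true is computed on the device, but the host needs its value to
    // size the output tensor. Wrap the scalar's device memory so it can be
    // copied back to the host on the stream below.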
    perftools::gputools::DeviceMemoryBase num_true_ptr(
        static_cast<void*>(num_true_t.data()));
    // Push kernel to stream to get number of true elements.
    const GPUDevice& d = context->eigen_device<GPUDevice>();
    Status s = functor::NumTrue<GPUDevice, T, Tindex>::Compute(
        context, d, input.flat<T>(), num_true_t);
    OP_REQUIRES_OK_ASYNC(context, s, done);

    // Copy num_true to host.
    ScratchSpace<Tindex> num_true_host(context, 1, /* on_host */ true);

    auto stream = context->op_device_context()->stream();
    OP_REQUIRES_ASYNC(
        context,
        stream
            ->ThenMemcpy(num_true_host.mutable_data(), num_true_ptr,
                         sizeof(Tindex))
            .ok(),
        errors::Internal("WhereOp: failed to copy num_true from device"), done);

    auto create_and_check_output = [context, &d, &input, input_dims,
                                    num_true_host, done]() {
      // Ensure that within the callback, the proper GPU settings are
      // configured.
      auto stream = context->op_device_context()->stream();
      ScopedActivateExecutorContext scoped_activation{stream->parent()};

      Tindex num_true = *num_true_host.data();

      // TODO(ebrevdo): Properly copy back found_true value to CPU for
      // validation checking. Currently Where<GPUDevice>::Compute()
      // does not perform this copy back to CPU.
      Tindex found_true = -1;

      // Steps 3 and 4: allocate the output, then perform the selection/copy.
      Tensor* output;
      OP_REQUIRES_OK_ASYNC(context,
                           context->allocate_output(
                               0, TensorShape({num_true, input_dims}), &output),
                           done);

#define HANDLE_DIM(NDIM)                                              \
  case NDIM: {                                                        \
    Status s = functor::Where<GPUDevice, NDIM, T, Tindex>::Compute(   \
        context, d, input.tensor<T, NDIM>(), output->matrix<int64>(), \
        &found_true);                                                 \
    OP_REQUIRES_OK_ASYNC(context, s, done);                           \
  } break;

      switch (input_dims) {
        HANDLE_DIM(1);
        HANDLE_DIM(2);
        HANDLE_DIM(3);
        HANDLE_DIM(4);
        HANDLE_DIM(5);

        default:
          OP_REQUIRES_ASYNC(
              context, false,
              errors::InvalidArgument("WhereOp: Unhandled input dimensions: ",
                                      input_dims),
              done);
      }
#undef HANDLE_DIM

      // TODO(ebrevdo): Fix the copy back to host.

      // OP_REQUIRES_ASYNC(
      //     context, found_true == num_true,
      //     errors::InvalidArgument(
      //         "WhereOp: Race condition between counting the number of true "
      //         "elements and writing them. When counting, saw ",
      //         num_true, " elements; but when writing their indices, saw ",
      //         found_true, " elements."),
      //     done);

      done();
    };
    // Defer output creation until the memcpy above has completed, i.e. until
    // all work queued on the stream so far has finished.
    context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
        stream, create_and_check_output);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(WhereGPUOp);
};

#define REGISTER_GPU_WHERE_OP(T) \
  REGISTER_KERNEL_BUILDER(       \
      Name("Where").Device(DEVICE_GPU).TypeConstraint<T>("T"), WhereGPUOp<T>);

TF_CALL_WHERE_GPU_TYPES(REGISTER_GPU_WHERE_OP);

#undef REGISTER_GPU_WHERE_OP

#endif  // GOOGLE_CUDA

}  // namespace tensorflow