/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "external/cub_archive/cub/device/device_reduce.cuh"
#include "external/cub_archive/cub/device/device_select.cuh"
#include "external/cub_archive/cub/iterator/counting_input_iterator.cuh"
#include "external/cub_archive/cub/iterator/transform_input_iterator.cuh"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/where_op.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

namespace functor {

template <int NDIM, typename TIndex>
__global__ void PropagateWhereIndicesKernel(
    const TIndex output_rows, const typename Eigen::array<TIndex, NDIM> strides,
    int64* output) {
  // TODO(ebrevdo): Use a multi-dimensional loop, increasing the
  // dimensions of individual indices manually, instead of relying on
  // a scalar loop variable and using integer division.
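  // Each row i of `output` arrives with a single flat input index in its
  // first column, stored there via WhereOutputIterator by
  // cub::DeviceSelect::Flagged below.  Peel off one coordinate per dimension
  // by dividing by the row-major strides.  For example, for a 2x3 input
  // (strides = {3, 1}), flat index 4 expands to (1, 1).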
  CUDA_1D_KERNEL_LOOP(i, output_rows) {
    TIndex index_value = ldg(output + NDIM * i);
#pragma unroll
    for (int c = 0; c < NDIM; ++c) {
      *(output + NDIM * i + c) = index_value / strides[c];
      index_value %= strides[c];
    }
  }
}

namespace {

template <typename T>
struct IsNonzero {
  EIGEN_DEVICE_FUNC IsNonzero() : zero(T(0)) {}
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x) const {
    return (x != zero);
  }
  const T zero;
};

template <typename T, typename TIndex>
struct CubDeviceReduceCount {
  cudaError_t operator()(void* d_temp_storage, size_t& temp_storage_bytes,
                         const T* d_in, TIndex* d_out, int num_items,
                         cudaStream_t stream = 0,
                         bool debug_synchronous = false) {
    IsNonzero<T> is_nonzero;
    cub::TransformInputIterator<bool, IsNonzero<T>, const T*> is_nonzero_iter(
        d_in, is_nonzero);
    return cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes,
                                  is_nonzero_iter, d_out, num_items, stream,
                                  debug_synchronous);
  }
};

template <typename TIndex>
struct CubDeviceReduceCount<bool, TIndex> {
  cudaError_t operator()(void* d_temp_storage, size_t& temp_storage_bytes,
                         const bool* d_in, TIndex* d_out, int num_items,
                         cudaStream_t stream = 0,
                         bool debug_synchronous = false) {
    return cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes, d_in,
                                  d_out, num_items, stream, debug_synchronous);
  }
};

template <typename T, typename TIndex, typename OutputIterator,
          bool IsConvertibleToBool>
struct CubDeviceSelectFlaggedCounter;

template <typename T, typename TIndex, typename OutputIterator>
struct CubDeviceSelectFlaggedCounter<T, TIndex, OutputIterator,
                                     false /*IsConvertibleToBool*/> {
  cudaError_t operator()(void* d_temp_storage, size_t& temp_storage_bytes,
                         const T* d_flags, OutputIterator d_out,
                         TIndex* d_num_selected_out, int num_items,
                         cudaStream_t stream = 0,
                         bool debug_synchronous = false) {
    cub::CountingInputIterator<TIndex> select_counter(0);
    IsNonzero<T> is_nonzero;
    cub::TransformInputIterator<bool, IsNonzero<T>, const T*> is_nonzero_iter(
        d_flags, is_nonzero);
    return cub::DeviceSelect::Flagged(
        d_temp_storage, temp_storage_bytes, select_counter /*d_in*/,
        is_nonzero_iter /*d_flags*/, d_out, d_num_selected_out, num_items,
        stream, debug_synchronous);
  }
};

template <typename T, typename TIndex, typename OutputIterator>
struct CubDeviceSelectFlaggedCounter<T, TIndex, OutputIterator,
                                     true /*IsConvertibleToBool*/> {
  cudaError_t operator()(void* d_temp_storage, size_t& temp_storage_bytes,
                         const T* d_flags, OutputIterator d_out,
                         TIndex* d_num_selected_out, int num_items,
                         cudaStream_t stream = 0,
                         bool debug_synchronous = false) {
    cub::CountingInputIterator<TIndex> select_counter(0);
    return cub::DeviceSelect::Flagged(
        d_temp_storage, temp_storage_bytes, select_counter /*d_in*/, d_flags,
        d_out, d_num_selected_out, num_items, stream, debug_synchronous);
  }
};

}  // namespace

template <typename T, typename TIndex>
struct NumTrue<GPUDevice, T, TIndex> {
  EIGEN_ALWAYS_INLINE static Status Compute(
      OpKernelContext* ctx, const GPUDevice& d,
      typename TTypes<T>::ConstFlat input,
      typename TTypes<TIndex>::Scalar num_true) {
    const cudaStream_t& cu_stream = GetCudaStream(ctx);

    std::size_t temp_storage_bytes = 0;
    const T* input_data = input.data();
    TIndex* num_true_data = num_true.data();
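    // CUB's standard two-pass calling convention: the first call, with
    // d_temp_storage == nullptr, only writes the required scratch size into
    // temp_storage_bytes; the second call, with the scratch buffer
    // allocated, performs the actual reduction.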
    // TODO(ebrevdo): sum doesn't work; perhaps need a different
    // iterator?
    auto reducer = CubDeviceReduceCount<T, TIndex>();
    auto first_success = reducer(/*temp_storage*/ nullptr, temp_storage_bytes,
                                 /*d_in*/ input_data,
                                 /*d_out*/ num_true_data,
                                 /*num_items*/ input.size(),
                                 /*stream*/ cu_stream);

    if (first_success != cudaSuccess) {
      return errors::Internal(
          "WhereOp: Could not launch cub::DeviceReduce::Sum to calculate "
          "temp_storage_bytes, status: ",
          cudaGetErrorString(first_success));
    }

    Tensor temp_storage;
    TF_RETURN_IF_ERROR(ctx->allocate_temp(
        DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}),
        &temp_storage));

    auto second_success = reducer(
        /*temp_storage*/ temp_storage.flat<int8>().data(), temp_storage_bytes,
        /*d_in*/ input_data,
        /*d_out*/ num_true_data,
        /*num_items*/ input.size(),
        /*stream*/ cu_stream);

    if (second_success != cudaSuccess) {
      return errors::Internal(
          "WhereOp: Could not launch cub::DeviceReduce::Sum to count "
          "number of true / nonzero indices. temp_storage_bytes: ",
          temp_storage_bytes, ", status: ", cudaGetErrorString(second_success));
    }

    return Status::OK();
  }
};

#define NUMTRUE_GPU_FUNCTOR(T)                  \
  template struct NumTrue<GPUDevice, T, int32>; \
  template struct NumTrue<GPUDevice, T, int64>;

// We only need to declare the NumTrue functor once, but this file is
// included from where_op_gpu_impl_X.cu.cc for X = 1, 2, ...
// Only declare for X = 1.
#if GPU_PROVIDED_DIM == 1

TF_CALL_WHERE_GPU_TYPES(NUMTRUE_GPU_FUNCTOR);

#endif  // GPU_PROVIDED_DIM == 1

#undef NUMTRUE_GPU_FUNCTOR
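
// Output iterator handed to cub::DeviceSelect::Flagged as d_out.  Assigning
// the k-th selected item (a flat input index produced by the counting
// iterator) through operator[](k) stores it at output.data()[k * NDIM], i.e.
// in column 0 of row k of the output matrix; PropagateWhereIndicesKernel
// later expands each stored flat index into NDIM coordinates in place.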
template <int NDIM>
class WhereOutputIterator {
 public:
  // Required iterator traits
  typedef WhereOutputIterator self_type;
  typedef std::ptrdiff_t difference_type;
  typedef void value_type;
  typedef void pointer;
  typedef int64& reference;

#if (THRUST_VERSION >= 100700)
  // Use Thrust's iterator categories so we can use these iterators in Thrust
  // 1.7 (or newer) methods
  typedef typename thrust::detail::iterator_facade_category<
      thrust::device_system_tag, thrust::random_access_traversal_tag,
      value_type,
      reference>::type iterator_category;  ///< The iterator category
#else
  typedef std::random_access_iterator_tag
      iterator_category;  ///< The iterator category
#endif  // THRUST_VERSION

  WhereOutputIterator(int64* ptr, const Eigen::DenseIndex max_row)
      : ptr_(ptr), max_row_(max_row) {}

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int64& operator[](int n) const {
    // If the selection mechanism finds too many true values (because
    // the input tensor changed between allocation of output and now),
    // we may accidentally try to write past the allowable memory. If
    // valid is false, then we don't do this. Instead, we'll read off
    // the number of items found in Flagged()'s d_num_selected_out at
    // the end and confirm that it matches the number of rows of output.
    const bool valid = FastBoundsCheck(n, max_row_);
    return *(ptr_ + (valid ? (NDIM * n) : 0));
  }

 private:
  int64* ptr_;
  const Eigen::DenseIndex max_row_;
};

template <typename TIndex, typename T, int NDIM>
Eigen::array<TIndex, NDIM> CalculateStrides(
    typename TTypes<T, NDIM>::ConstTensor input) {
  const Eigen::DSizes<Eigen::DenseIndex, NDIM> dims = input.dimensions();
  Eigen::array<TIndex, NDIM> strides;
  EIGEN_STATIC_ASSERT((static_cast<int>(decltype(input)::Layout) ==
                       static_cast<int>(Eigen::RowMajor)),
                      INTERNAL_ERROR_INPUT_SHOULD_BE_ROWMAJOR);
  strides[NDIM - 1] = 1;
  for (int i = NDIM - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * dims[i + 1];
  }
  return strides;
}

template <int NDIM, typename T, typename TIndex>
struct Where<GPUDevice, NDIM, T, TIndex> {
  EIGEN_ALWAYS_INLINE static Status Compute(
      OpKernelContext* ctx, const GPUDevice& d,
      typename TTypes<T, NDIM>::ConstTensor input,
      typename TTypes<int64>::Matrix output, TIndex* found_true_host) {
    if (output.dimension(0) == 0) {
      // Nothing to do.
      return Status::OK();
    }

    const cudaStream_t& cu_stream = GetCudaStream(ctx);

    std::size_t temp_storage_bytes = 0;

    Tensor found_true_t;
    TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum<TIndex>::v(),
                                          TensorShape({}), &found_true_t));
    TIndex* found_true_device = found_true_t.scalar<TIndex>().data();

    WhereOutputIterator<NDIM> output_iterator(
        output.data(),
        /* max_row */ output.dimension(0));

    typedef typename std::decay<T>::type DT;
    CubDeviceSelectFlaggedCounter<
        T, TIndex, decltype(output_iterator) /*OutputIterator*/,
        std::is_convertible<DT, bool>::value /*IsConvertibleToBool*/>
        counter;
    auto first_success = counter(/*temp_storage*/ nullptr, temp_storage_bytes,
                                 /*d_flags*/ input.data(),
                                 /*d_out*/ output_iterator,
                                 /*d_num_selected_out*/ found_true_device,
                                 /*num_items*/ input.size(),
                                 /*stream*/ cu_stream);
    if (first_success != cudaSuccess) {
      return errors::Internal(
          "WhereOp: Could not launch cub::DeviceSelect::Flagged to calculate "
          "temp_storage_bytes, status: ",
          cudaGetErrorString(first_success));
    }

    Tensor temp_storage;
    TF_RETURN_IF_ERROR(ctx->allocate_temp(
        DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}),
        &temp_storage));

    auto second_success = counter(
        /*temp_storage*/ temp_storage.flat<int8>().data(), temp_storage_bytes,
        /*d_flags*/ input.data(),
        /*d_out*/ output_iterator,
        /*d_num_selected_out*/ found_true_device,
        /*num_items*/ input.size(),
        /*stream*/ cu_stream);

    if (second_success != cudaSuccess) {
      return errors::Internal(
          "WhereOp: Could not launch cub::DeviceSelect::Flagged to copy "
          "indices out, status: ",
          cudaGetErrorString(second_success));
    }

    // TODO(ebrevdo): Find a way to synchronously copy back data from
    // found_true_device to *found_true_host.
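
    // Expand the flat indices written by Flagged() into NDIM coordinates.
    // CalculateStrides returns row-major strides, e.g. {12, 4, 1} for a
    // 2x3x4 input.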
    const Eigen::array<TIndex, NDIM> strides =
        CalculateStrides<TIndex, T, NDIM>(input);
    const TIndex output_rows = output.dimension(0);
    CudaLaunchConfig config = GetCudaLaunchConfig(output_rows, d);
    PropagateWhereIndicesKernel<NDIM, TIndex>
        <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
            output_rows, strides, output.data());

    return Status::OK();
  }
};

#define DECLARE_GPU_SPEC_INDEX(Dims, T, TIndex) \
  template struct Where<GPUDevice, Dims, T, TIndex>

#define DECLARE_GPU_SPEC(T)                           \
  DECLARE_GPU_SPEC_INDEX(GPU_PROVIDED_DIM, T, int32); \
  DECLARE_GPU_SPEC_INDEX(GPU_PROVIDED_DIM, T, int64)

TF_CALL_WHERE_GPU_TYPES(DECLARE_GPU_SPEC);

#undef DECLARE_GPU_SPEC
#undef DECLARE_GPU_SPEC_INDEX

}  // namespace functor

}  // namespace tensorflow

#endif  // GOOGLE_CUDA