/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/data_flow_ops.cc.
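//
// Illustrative example of the stitch semantics (the op docs in
// ../ops/data_flow_ops.cc are authoritative): with
//   indices[0] = [0, 2], data[0] = [[1, 2], [5, 6]]
//   indices[1] = [1],    data[1] = [[3, 4]]
// the merged output is [[1, 2], [3, 4], [5, 6]], i.e.
//   merged[indices[m][i, ..., j], :] = data[m][i, ..., j, :]
// Per the documented merge order of DynamicStitch, if an index appears in
// more than one input, the slice from the later input ends up in the result.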

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/lib/core/threadpool.h"

#ifdef GOOGLE_CUDA
#include "tensorflow/core/kernels/cuda_device_array.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
#ifdef GOOGLE_CUDA
typedef Eigen::GpuDevice GPUDevice;
#endif  // GOOGLE_CUDA

template <class T>
class DynamicStitchOpImplBase : public OpKernel {
 public:
  explicit DynamicStitchOpImplBase(OpKernelConstruction* c,
                                   const string& op_name)
      : OpKernel(c) {
    // Compute expected input signature
    const DataType dt = DataTypeToEnum<T>::v();
    const int n = c->num_inputs() / 2;
    DataTypeVector expected;
    for (int i = 0; i < n; i++) {
      expected.push_back(DT_INT32);
    }
    for (int i = 0; i < n; i++) {
      expected.push_back(dt);
    }
    OP_REQUIRES_OK(c, c->MatchSignature(expected, {dt}));
    OP_REQUIRES(c, c->num_inputs() > 0,
                errors::InvalidArgument(op_name + ": Must have some inputs"));
    OP_REQUIRES(c, c->num_inputs() % 2 == 0,
                errors::InvalidArgument(
                    op_name + ": Must have even number of arguments"));
  }

 protected:
  // Check if data0.shape[indices0.dims():] == data1.shape[indices1.dims():]
  static bool SameExtraShape(const Tensor& data0, const Tensor& indices0,
                             const Tensor& data1, const Tensor& indices1) {
    const int extra0 = data0.dims() - indices0.dims();
    const int extra1 = data1.dims() - indices1.dims();
    if (extra0 != extra1) return false;
    for (int i = 0; i < extra0; i++) {
      if (data0.dim_size(indices0.dims() + i) !=
          data1.dim_size(indices1.dims() + i)) {
        return false;
      }
    }
    return true;
  }

  void CheckArgsAndAllocateResult(OpKernelContext* c,
                                  OpInputList* indices_inputs,
                                  OpInputList* data_inputs,
                                  int* first_dim_size,
                                  int* data_elements_size,
                                  Tensor** result_ptr) {
    // Find maximum index in the indices vectors
    OP_REQUIRES_OK(c, c->input_list("indices", indices_inputs));

    int32 max_index = -1;
    if (data_elements_size) {
      *data_elements_size = 0;
    }
    for (const Tensor& indices : *indices_inputs) {
      if (indices.NumElements() > 0) {
        Eigen::Tensor<int32, 0, Eigen::RowMajor> m =
            indices.flat<int32>().maximum();
        max_index = std::max(m(), max_index);
      }
      if (data_elements_size) {
        *data_elements_size += indices.NumElements();
      }
    }

    *first_dim_size = max_index + 1;

    // Validate that data[i].shape = indices[i].shape + constant
    OP_REQUIRES_OK(c, c->input_list("data", data_inputs));
    const Tensor& data0 = (*data_inputs)[0];
    const Tensor& indices0 = (*indices_inputs)[0];
    for (int input_num = 0; input_num < indices_inputs->size(); input_num++) {
      const Tensor& indices = (*indices_inputs)[input_num];
      const Tensor& data = (*data_inputs)[input_num];
      OP_REQUIRES(
          c, TensorShapeUtils::StartsWith(data.shape(), indices.shape()),
          errors::InvalidArgument("data[", input_num,
                                  "].shape = ", data.shape().DebugString(),
                                  " does not start with indices[", input_num,
                                  "].shape = ", indices.shape().DebugString()));
      OP_REQUIRES(
          c, input_num == 0 || SameExtraShape(data0, indices0, data, indices),
          errors::InvalidArgument(
              "Need data[0].shape[", indices0.dims(), ":] = data[", input_num,
              "].shape[", indices.dims(),
              ":], got data[0].shape = ", data0.shape().DebugString(),
              ", data[", input_num, "].shape = ", data.shape().DebugString(),
              ", indices[0].shape = ", indices0.shape().DebugString(),
              ", indices[", input_num,
              "].shape = ", indices.shape().DebugString()));
    }

    // Allocate result tensor of shape
    //   [*first_dim_size] + data.shape[indices.dims:]
    TensorShape result_shape;
    result_shape.AddDim(*first_dim_size);
    for (int d = indices0.dims(); d < data0.dims(); d++) {
      result_shape.AddDim(data0.dim_size(d));
    }
    OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, result_ptr));
  }
};
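
// On the GPU path below, argument checking and the resolution of duplicate
// indices happen on the host: indices_flat maps each output row to the
// flattened index of the input slice that should fill it (with the last
// duplicate winning), and data_flat holds one device pointer per input slice.
// The DynamicStitchGPUImpl kernel then performs the actual copies on the
// device.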

#if GOOGLE_CUDA

template <typename T>
void DynamicStitchGPUImpl(const Eigen::GpuDevice& gpu_device,
                          const int32 slice_size, const int32 first_dim_size,
                          const CudaDeviceArrayStruct<int>& input_indices,
                          const CudaDeviceArrayStruct<const T*>& input_ptrs,
                          T* output);

template <class T>
class DynamicStitchOpGPU : public DynamicStitchOpImplBase<T> {
 public:
  explicit DynamicStitchOpGPU(OpKernelConstruction* c)
      : DynamicStitchOpImplBase<T>(c, "DynamicStitchOp") {}

  void Compute(OpKernelContext* c) override {
    OpInputList indices_inputs;
    OpInputList data_inputs;
    int first_dim_size;
    int data_elements_size;
    Tensor* merged = nullptr;
    this->CheckArgsAndAllocateResult(c, &indices_inputs, &data_inputs,
                                     &first_dim_size, &data_elements_size,
                                     &merged);
    if (!c->status().ok()) {
      // Avoid segmentation faults if merged cannot be allocated and an error
      // is passed back in the context.
      return;
    }

    // TODO(jeff): Currently we leave uninitialized any portions of
    // merged that aren't covered by an index in indices.  What should we do?
    if (first_dim_size > 0) {
      // Because the last of duplicated indices must win, we resolve
      // collisions on the CPU before sending the data to the GPU kernel.
      // TODO(ekelsen): Instead of doing a serial scan on the CPU to pick the
      // last of duplicated indices, it could instead be done on the GPU
      // implicitly using atomics to make sure the last index is the final
      // write.
      const int slice_size = merged->flat_outer_dims<T>().dimension(1);
      CudaDeviceArrayOnHost<int32> indices_flat(c, first_dim_size);
      CudaDeviceArrayOnHost<const T*> data_flat(c, data_elements_size);
      OP_REQUIRES_OK(c, indices_flat.Init());
      OP_REQUIRES_OK(c, data_flat.Init());
      // Initialize indices_flat (-1 represents a missing index).
      for (int i = 0; i < first_dim_size; ++i) {
        indices_flat.Set(i, -1);
      }

      // Index into data_flat.
      int32 idx = 0;
      // Running sum of indices_inputs[i].NumElements(), used to compute the
      // values stored in indices_flat.
      int32 base_size = 0;
      for (int i = 0; i < indices_inputs.size(); ++i) {
        auto indices_vec = indices_inputs[i].flat<int32>();
        auto data_ptr_base = data_inputs[i].template flat<T>().data();
        for (int j = 0; j < indices_vec.size(); ++j) {
          // indices_flat's positions are indices into the output; its values
          // are the flattened indices of the input slices where the data is
          // located.
          indices_flat.Set(indices_vec(j), base_size + j);
          data_flat.Set(
              idx, const_cast<T*>(reinterpret_cast<const T*>(data_ptr_base) +
                                  j * slice_size));
          ++idx;
        }
        base_size += indices_vec.size();
      }
      OP_REQUIRES_OK(c, indices_flat.Finalize());
      OP_REQUIRES_OK(c, data_flat.Finalize());

      auto output = merged->template flat<T>().data();
      DynamicStitchGPUImpl<T>(c->eigen_gpu_device(), slice_size,
                              first_dim_size, indices_flat.data(),
                              data_flat.data(), output);
    }
  }
};

#endif  // GOOGLE_CUDA
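
// The CPU implementation below backs both DynamicStitch and
// ParallelDynamicStitch via the Parallel template parameter: when Parallel is
// true, the per-input copies are sharded across the CPU worker thread pool;
// otherwise the inputs are processed serially. Types that can be copied with
// memcpy use a raw memcpy per slice; other types (e.g. string) fall back to
// Eigen slice assignment.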

template <class T, bool Parallel>
class DynamicStitchOpImplCPU : public DynamicStitchOpImplBase<T> {
 public:
  explicit DynamicStitchOpImplCPU(OpKernelConstruction* c)
      : DynamicStitchOpImplBase<T>(
            c, (Parallel ? "ParallelDynamicStitchOp" : "DynamicStitchOp")) {}

  void Compute(OpKernelContext* c) override {
    OpInputList indices_inputs;
    OpInputList data_inputs;
    int first_dim_size;
    Tensor* merged = nullptr;
    this->CheckArgsAndAllocateResult(c, &indices_inputs, &data_inputs,
                                     &first_dim_size, nullptr, &merged);
    if (!c->status().ok()) {
      // Avoid segmentation faults if merged cannot be allocated and an error
      // is passed back in the context.
      return;
    }

    // TODO(jeff): Currently we leave uninitialized any portions of
    // merged that aren't covered by an index in indices.  What should we do?
    if (first_dim_size > 0) {
      auto merged_flat = merged->flat_outer_dims<T>();
      const int slice_size = merged_flat.dimension(1);
      const size_t slice_bytes = slice_size * sizeof(T);
      auto OnInputNumber = [&](int input_num) {
        const Tensor& indices = indices_inputs[input_num];
        auto indices_vec = indices.flat<int32>();
        const Tensor& data = data_inputs[input_num];
        auto data_flat =
            data.shaped<T, 2>({indices_vec.dimension(0), slice_size});

        if (DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
          T* merged_base = &merged_flat(0, 0);
          const T* data_base = &data_flat(0, 0);
          for (int i = 0; i < indices_vec.size(); i++) {
            int32 index = internal::SubtleMustCopy(indices_vec(i));
            OP_REQUIRES(
                c, FastBoundsCheck(index, first_dim_size),
                errors::InvalidArgument("indices[", i, "] is out of range"));
            memcpy(merged_base + index * slice_size,
                   data_base + i * slice_size, slice_bytes);
          }
        } else {
          Eigen::DSizes<Eigen::DenseIndex, 2> sizes(1, slice_size);
          for (int i = 0; i < indices_vec.size(); i++) {
            // Copy slice data[i] to merged[indices[i]]
            Eigen::DSizes<Eigen::DenseIndex, 2> data_indices(i, 0);
            int32 index = internal::SubtleMustCopy(indices_vec(i));
            OP_REQUIRES(
                c, FastBoundsCheck(index, first_dim_size),
                errors::InvalidArgument("indices[", i, "] is out of range"));
            Eigen::DSizes<Eigen::DenseIndex, 2> merged_indices(index, 0);
            merged_flat.slice(merged_indices, sizes) =
                data_flat.slice(data_indices, sizes);
          }
        }
      };
      if (Parallel) {
        auto thread_pool =
            c->device()->tensorflow_cpu_worker_threads()->workers;
        size_t total_indices_size = 0;
        for (int input_num = 0; input_num < indices_inputs.size();
             ++input_num) {
          total_indices_size += indices_inputs[input_num].NumElements();
        }
        const double avg_indices_size =
            static_cast<double>(total_indices_size) / indices_inputs.size();
        auto bytes_processed = slice_bytes * avg_indices_size;
        auto LoopBody = [&](int first, int last) {
          for (int input_num = first; input_num < last; ++input_num) {
            OnInputNumber(input_num);
          }
        };
        thread_pool->ParallelFor(indices_inputs.size(), bytes_processed,
                                 LoopBody);
      } else {
        for (int input_num = 0; input_num < indices_inputs.size();
             input_num++) {
          OnInputNumber(input_num);
        }
      }
    }
  }
};

// Using inheritance rather than a typedef so that these classes might have
// more functionality later.

template <typename T>
struct DynamicStitchOpCPU : DynamicStitchOpImplCPU<T, false> {
  using DynamicStitchOpImplCPU<T, false>::DynamicStitchOpImplCPU;
};

template <typename T>
struct ParallelDynamicStitchOpCPU : DynamicStitchOpImplCPU<T, true> {
  using DynamicStitchOpImplCPU<T, true>::DynamicStitchOpImplCPU;
};

#define REGISTER_DYNAMIC_STITCH(type)                    \
  REGISTER_KERNEL_BUILDER(Name("DynamicStitch")          \
                              .Device(DEVICE_CPU)        \
                              .TypeConstraint<type>("T") \
                              .HostMemory("indices"),    \
                          DynamicStitchOpCPU<type>)      \
  REGISTER_KERNEL_BUILDER(Name("ParallelDynamicStitch")  \
                              .Device(DEVICE_CPU)        \
                              .TypeConstraint<type>("T") \
                              .HostMemory("indices"),    \
                          ParallelDynamicStitchOpCPU<type>)

TF_CALL_POD_STRING_TYPES(REGISTER_DYNAMIC_STITCH);
#undef REGISTER_DYNAMIC_STITCH
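
// On GPU, "indices" stays in host memory for DynamicStitch, while
// ParallelDynamicStitch is registered with the CPU kernel and keeps
// "indices", "data", and "merged" all in host memory.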

#if GOOGLE_CUDA
#define REGISTER_DYNAMIC_STITCH_GPU(type)                \
  REGISTER_KERNEL_BUILDER(Name("DynamicStitch")          \
                              .Device(DEVICE_GPU)        \
                              .TypeConstraint<type>("T") \
                              .HostMemory("indices"),    \
                          DynamicStitchOpGPU<type>)      \
  REGISTER_KERNEL_BUILDER(Name("ParallelDynamicStitch")  \
                              .Device(DEVICE_GPU)        \
                              .TypeConstraint<type>("T") \
                              .HostMemory("indices")     \
                              .HostMemory("data")        \
                              .HostMemory("merged"),     \
                          ParallelDynamicStitchOpCPU<type>)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_DYNAMIC_STITCH_GPU);
TF_CALL_complex64(REGISTER_DYNAMIC_STITCH_GPU);
TF_CALL_complex128(REGISTER_DYNAMIC_STITCH_GPU);
TF_CALL_int64(REGISTER_DYNAMIC_STITCH_GPU);
TF_CALL_int32(REGISTER_DYNAMIC_STITCH_GPU);
#undef REGISTER_DYNAMIC_STITCH_GPU

#endif  // GOOGLE_CUDA

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_DYNAMIC_STITCH_SYCL(type)               \
  REGISTER_KERNEL_BUILDER(Name("DynamicStitch")          \
                              .Device(DEVICE_SYCL)       \
                              .TypeConstraint<type>("T") \
                              .HostMemory("indices")     \
                              .HostMemory("data")        \
                              .HostMemory("merged"),     \
                          DynamicStitchOpCPU<type>)      \
  REGISTER_KERNEL_BUILDER(Name("ParallelDynamicStitch")  \
                              .Device(DEVICE_SYCL)       \
                              .TypeConstraint<type>("T") \
                              .HostMemory("indices")     \
                              .HostMemory("data")        \
                              .HostMemory("merged"),     \
                          ParallelDynamicStitchOpCPU<type>)

TF_CALL_POD_STRING_TYPES(REGISTER_DYNAMIC_STITCH_SYCL);
#undef REGISTER_DYNAMIC_STITCH_SYCL
#endif  // TENSORFLOW_USE_SYCL
}  // namespace tensorflow