/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/data_flow_ops.cc.
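//
// As a brief reminder of the semantics documented there: DynamicStitch
// interleaves the values of the `data` tensors into a single `merged` tensor
// such that
//
//   merged[indices[m][i, ..., j], ...] = data[m][i, ..., j, ...]
//
// For example, with
//   indices[0] = [0, 2], indices[1] = [1, 3]
//   data[0]    = [10, 30], data[1]  = [20, 40]
// the result is merged = [10, 20, 30, 40].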

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/lib/core/threadpool.h"

#ifdef GOOGLE_CUDA
#include "tensorflow/core/kernels/cuda_device_array.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
#ifdef GOOGLE_CUDA
typedef Eigen::GpuDevice GPUDevice;
#endif  // GOOGLE_CUDA

template <class T>
class DynamicStitchOpImplBase : public OpKernel {
 public:
  explicit DynamicStitchOpImplBase(OpKernelConstruction* c,
                                   const string& op_name)
      : OpKernel(c) {
    // Compute expected input signature
    const DataType dt = DataTypeToEnum<T>::v();
    const int n = c->num_inputs() / 2;
    DataTypeVector expected;
    for (int i = 0; i < n; i++) {
      expected.push_back(DT_INT32);
    }
    for (int i = 0; i < n; i++) {
      expected.push_back(dt);
    }
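    // For example, with n = 2 and T = float the expected signature is
    // {DT_INT32, DT_INT32, DT_FLOAT, DT_FLOAT} -> {DT_FLOAT}.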
    OP_REQUIRES_OK(c, c->MatchSignature(expected, {dt}));
    OP_REQUIRES(c, c->num_inputs() > 0,
                errors::InvalidArgument(op_name + ": Must have some inputs"));
    OP_REQUIRES(c, c->num_inputs() % 2 == 0,
                errors::InvalidArgument(
                    op_name + ": Must have even number of arguments"));
  }

 protected:
  // Check if data0.shape[indices0.dims():] == data1.shape[indices1.dims():]
  static bool SameExtraShape(const Tensor& data0, const Tensor& indices0,
                             const Tensor& data1, const Tensor& indices1) {
    const int extra0 = data0.dims() - indices0.dims();
    const int extra1 = data1.dims() - indices1.dims();
    if (extra0 != extra1) return false;
    for (int i = 0; i < extra0; i++) {
      if (data0.dim_size(indices0.dims() + i) !=
          data1.dim_size(indices1.dims() + i)) {
        return false;
      }
    }
    return true;
  }

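  // Shared validation/allocation helper.  Reads the "indices" and "data" input
  // lists, computes *first_dim_size as (maximum index) + 1 and, if requested,
  // *data_elements_size as the total number of indices (i.e. slices) across
  // all inputs, checks that every data[i].shape starts with indices[i].shape
  // and that all inputs agree on the trailing ("extra") dimensions, and
  // finally allocates the output of shape
  //   [*first_dim_size] + data[0].shape[indices[0].dims():].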
  void CheckArgsAndAllocateResult(OpKernelContext* c,
                                  OpInputList* indices_inputs,
                                  OpInputList* data_inputs, int* first_dim_size,
                                  int* data_elements_size,
                                  Tensor** result_ptr) {
    // Find maximum index in the indices vectors
    OP_REQUIRES_OK(c, c->input_list("indices", indices_inputs));

    int32 max_index = -1;
    if (data_elements_size) {
      *data_elements_size = 0;
    }
    for (const Tensor& indices : *indices_inputs) {
      if (indices.NumElements() > 0) {
        Eigen::Tensor<int32, 0, Eigen::RowMajor> m =
            indices.flat<int32>().maximum();
        max_index = std::max(m(), max_index);
      }
      if (data_elements_size) {
        *data_elements_size += indices.NumElements();
      }
    }

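    // The output's first dimension must cover every index seen above, so it is
    // max_index + 1 (0 if all of the index tensors are empty, since max_index
    // starts at -1).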
    *first_dim_size = max_index + 1;

    // Validate that data[i].shape = indices[i].shape + constant
    OP_REQUIRES_OK(c, c->input_list("data", data_inputs));
    const Tensor& data0 = (*data_inputs)[0];
    const Tensor& indices0 = (*indices_inputs)[0];
    for (int input_num = 0; input_num < indices_inputs->size(); input_num++) {
      const Tensor& indices = (*indices_inputs)[input_num];
      const Tensor& data = (*data_inputs)[input_num];
      OP_REQUIRES(
          c, TensorShapeUtils::StartsWith(data.shape(), indices.shape()),
          errors::InvalidArgument("data[", input_num,
                                  "].shape = ", data.shape().DebugString(),
                                  " does not start with indices[", input_num,
                                  "].shape = ", indices.shape().DebugString()));
      OP_REQUIRES(
          c, input_num == 0 || SameExtraShape(data0, indices0, data, indices),
          errors::InvalidArgument(
              "Need data[0].shape[", indices0.dims(), ":] = data[", input_num,
              "].shape[", indices.dims(),
              ":], got data[0].shape = ", data0.shape().DebugString(),
              ", data[", input_num, "].shape = ", data.shape().DebugString(),
              ", indices[0].shape = ", indices0.shape().DebugString(),
              ", indices[", input_num,
              "].shape = ", indices.shape().DebugString()));
    }

    // Allocate result tensor of shape
    //   [*first_dim_size] + data.shape[indices.dims:]
    TensorShape result_shape;
    result_shape.AddDim(*first_dim_size);
    for (int d = indices0.dims(); d < data0.dims(); d++) {
      result_shape.AddDim(data0.dim_size(d));
    }
    OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, result_ptr));
  }
};

#if GOOGLE_CUDA

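// Forward declaration of the host-side wrapper around the CUDA kernel; the
// definition lives in this op's GPU compilation unit.  An informal summary of
// the expected contract, inferred from how it is called below: for every
// output row r with input_indices[r] >= 0, copy slice_size elements starting
// at input_ptrs[input_indices[r]] into output + r * slice_size; rows with
// input_indices[r] == -1 are left untouched.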
template <typename T>
void DynamicStitchGPUImpl(const Eigen::GpuDevice& gpu_device,
                          const int32 slice_size, const int32 first_dim_size,
                          const CudaDeviceArrayStruct<int>& input_indices,
                          const CudaDeviceArrayStruct<const T*>& input_ptrs,
                          T* output);

template <class T>
class DynamicStitchOpGPU : public DynamicStitchOpImplBase<T> {
 public:
  explicit DynamicStitchOpGPU(OpKernelConstruction* c)
      : DynamicStitchOpImplBase<T>(c, "DynamicStitchOp") {}

  void Compute(OpKernelContext* c) override {
    OpInputList indices_inputs;
    OpInputList data_inputs;
    int first_dim_size;
    int data_elements_size;
    Tensor* merged = nullptr;
    this->CheckArgsAndAllocateResult(c, &indices_inputs, &data_inputs,
                                     &first_dim_size, &data_elements_size,
                                     &merged);
    if (!c->status().ok()) {
      // Avoid segmentation faults if merged cannot be allocated and an error
      // is passed back in the context.
      return;
    }

    // TODO(jeff): Currently we leave uninitialized any portions of
    // merged that aren't covered by an index in indices.  What should we do?
    if (first_dim_size > 0) {
      // Because of the collision requirements, we have to resolve any
      // colliding indices on the host before handing the data to the GPU
      // kernel.
      // TODO(ekelsen): Instead of doing a serial scan on the CPU to pick the
      // last of duplicated indices, it could instead be done on the GPU
      // implicitly using atomics to make sure the last index is the final
      // write.
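      // slice_size is the number of elements in one output row, i.e. the
      // product of the trailing (non-indexed) dimensions.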
      const int slice_size = merged->flat_outer_dims<T>().dimension(1);
      CudaDeviceArrayOnHost<int32> indices_flat(c, first_dim_size);
      CudaDeviceArrayOnHost<const T*> data_flat(c, data_elements_size);
      OP_REQUIRES_OK(c, indices_flat.Init());
      OP_REQUIRES_OK(c, data_flat.Init());
      // Initialize indices_flat (-1 represents a missing index).
      for (int i = 0; i < first_dim_size; ++i) {
        indices_flat.Set(i, -1);
      }

      // Running index into data_flat.
      int32 idx = 0;
      // Running sum of indices_inputs[i].NumElements(), used to compute the
      // values stored in indices_flat.
      int32 base_size = 0;
      for (int i = 0; i < indices_inputs.size(); ++i) {
        auto indices_vec = indices_inputs[i].flat<int32>();
        auto data_ptr_base = data_inputs[i].template flat<T>().data();
        for (int j = 0; j < indices_vec.size(); ++j) {
          // indices_flat is indexed by output row; its value is the flattened
          // position in the input data where that row's slice is located.
          indices_flat.Set(indices_vec(j), base_size + j);
          data_flat.Set(
              idx, const_cast<T*>(reinterpret_cast<const T*>(data_ptr_base) +
                                  j * slice_size));
          ++idx;
        }
        base_size += indices_vec.size();
      }
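      // At this point indices_flat[r] holds the flattened slice index whose
      // data should end up in output row r (the last writer wins when the same
      // output index appears in more than one input), or -1 if row r is not
      // covered by any index.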
      OP_REQUIRES_OK(c, indices_flat.Finalize());
      OP_REQUIRES_OK(c, data_flat.Finalize());

      auto output = merged->template flat<T>().data();
      DynamicStitchGPUImpl<T>(c->eigen_gpu_device(), slice_size, first_dim_size,
                              indices_flat.data(), data_flat.data(), output);
    }
  }
};

#endif  // GOOGLE_CUDA

template <class T, bool Parallel>
class DynamicStitchOpImplCPU : public DynamicStitchOpImplBase<T> {
 public:
  explicit DynamicStitchOpImplCPU(OpKernelConstruction* c)
      : DynamicStitchOpImplBase<T>(
            c, (Parallel ? "ParallelDynamicStitchOp" : "DynamicStitchOp")) {}

  void Compute(OpKernelContext* c) override {
    OpInputList indices_inputs;
    OpInputList data_inputs;
    int first_dim_size;
    Tensor* merged = nullptr;
    this->CheckArgsAndAllocateResult(c, &indices_inputs, &data_inputs,
                                     &first_dim_size, nullptr, &merged);
    if (!c->status().ok()) {
      // Avoid segmentation faults if merged cannot be allocated and an error
      // is passed back in the context.
      return;
    }

    // TODO(jeff): Currently we leave uninitialized any portions of
    // merged that aren't covered by an index in indices.  What should we do?
    if (first_dim_size > 0) {
      auto merged_flat = merged->flat_outer_dims<T>();
      const int slice_size = merged_flat.dimension(1);
      const size_t slice_bytes = slice_size * sizeof(T);
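      // Processes one (indices, data) input pair, copying slice data[i] into
      // merged[indices[i]].  Types that are safe to memcpy take a raw memcpy
      // fast path; other types (e.g. string) go through Eigen slice
      // assignment.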
      auto OnInputNumber = [&](int input_num) {
        const Tensor& indices = indices_inputs[input_num];
        auto indices_vec = indices.flat<int32>();
        const Tensor& data = data_inputs[input_num];
        auto data_flat =
            data.shaped<T, 2>({indices_vec.dimension(0), slice_size});

        if (DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
          T* merged_base = &merged_flat(0, 0);
          const T* data_base = &data_flat(0, 0);
          for (int i = 0; i < indices_vec.size(); i++) {
            int32 index = internal::SubtleMustCopy(indices_vec(i));
            OP_REQUIRES(
                c, FastBoundsCheck(index, first_dim_size),
                errors::InvalidArgument("indices[", i, "] is out of range"));
            memcpy(merged_base + index * slice_size, data_base + i * slice_size,
                   slice_bytes);
          }
        } else {
          Eigen::DSizes<Eigen::DenseIndex, 2> sizes(1, slice_size);
          for (int i = 0; i < indices_vec.size(); i++) {
            // Copy slice data[i] to merged[indices[i]]
            Eigen::DSizes<Eigen::DenseIndex, 2> data_indices(i, 0);
            int32 index = internal::SubtleMustCopy(indices_vec(i));
            OP_REQUIRES(
                c, FastBoundsCheck(index, first_dim_size),
                errors::InvalidArgument("indices[", i, "] is out of range"));
            Eigen::DSizes<Eigen::DenseIndex, 2> merged_indices(index, 0);
            merged_flat.slice(merged_indices, sizes) =
                data_flat.slice(data_indices, sizes);
          }
        }
      };
      if (Parallel) {
        auto thread_pool =
            c->device()->tensorflow_cpu_worker_threads()->workers;
        size_t total_indices_size = 0;
        for (int input_num = 0; input_num < indices_inputs.size();
             ++input_num) {
          total_indices_size += indices_inputs[input_num].NumElements();
        }
        const double avg_indices_size =
            static_cast<double>(total_indices_size) / indices_inputs.size();
        auto bytes_processed = slice_bytes * avg_indices_size;
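        // bytes_processed is a rough per-input cost estimate (bytes copied for
        // an average-sized input) passed to ParallelFor so the thread pool can
        // decide how finely to shard the inputs across threads.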
        auto LoopBody = [&](int first, int last) {
          for (int input_num = first; input_num < last; ++input_num) {
            OnInputNumber(input_num);
          }
        };
        thread_pool->ParallelFor(indices_inputs.size(), bytes_processed,
                                 LoopBody);
      } else {
        for (int input_num = 0; input_num < indices_inputs.size();
             input_num++) {
          OnInputNumber(input_num);
        }
      }
    }
  }
};

// Using inheritance rather than a typedef so that these classes might have more
// functionality later.

template <typename T>
struct DynamicStitchOpCPU : DynamicStitchOpImplCPU<T, false> {
  using DynamicStitchOpImplCPU<T, false>::DynamicStitchOpImplCPU;
};

template <typename T>
struct ParallelDynamicStitchOpCPU : DynamicStitchOpImplCPU<T, true> {
  using DynamicStitchOpImplCPU<T, true>::DynamicStitchOpImplCPU;
};

#define REGISTER_DYNAMIC_STITCH(type)                    \
  REGISTER_KERNEL_BUILDER(Name("DynamicStitch")          \
                              .Device(DEVICE_CPU)        \
                              .TypeConstraint<type>("T") \
                              .HostMemory("indices"),    \
                          DynamicStitchOpCPU<type>)      \
  REGISTER_KERNEL_BUILDER(Name("ParallelDynamicStitch")  \
                              .Device(DEVICE_CPU)        \
                              .TypeConstraint<type>("T") \
                              .HostMemory("indices"),    \
                          ParallelDynamicStitchOpCPU<type>)

TF_CALL_POD_STRING_TYPES(REGISTER_DYNAMIC_STITCH);
#undef REGISTER_DYNAMIC_STITCH

#if GOOGLE_CUDA
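// Note that on GPU only DynamicStitch uses the CUDA kernel above;
// ParallelDynamicStitch is registered with the CPU implementation and all of
// its arguments pinned to host memory, so the stitch itself runs on the host
// even when the op is placed on a GPU device.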
#define REGISTER_DYNAMIC_STITCH_GPU(type)                \
  REGISTER_KERNEL_BUILDER(Name("DynamicStitch")          \
                              .Device(DEVICE_GPU)        \
                              .TypeConstraint<type>("T") \
                              .HostMemory("indices"),    \
                          DynamicStitchOpGPU<type>)      \
  REGISTER_KERNEL_BUILDER(Name("ParallelDynamicStitch")  \
                              .Device(DEVICE_GPU)        \
                              .TypeConstraint<type>("T") \
                              .HostMemory("indices")     \
                              .HostMemory("data")        \
                              .HostMemory("merged"),     \
                          ParallelDynamicStitchOpCPU<type>)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_DYNAMIC_STITCH_GPU);
TF_CALL_complex64(REGISTER_DYNAMIC_STITCH_GPU);
TF_CALL_complex128(REGISTER_DYNAMIC_STITCH_GPU);
TF_CALL_int64(REGISTER_DYNAMIC_STITCH_GPU);
TF_CALL_int32(REGISTER_DYNAMIC_STITCH_GPU);
#undef REGISTER_DYNAMIC_STITCH_GPU

#endif  // GOOGLE_CUDA

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_DYNAMIC_STITCH_SYCL(type)               \
  REGISTER_KERNEL_BUILDER(Name("DynamicStitch")          \
                              .Device(DEVICE_SYCL)       \
                              .TypeConstraint<type>("T") \
                              .HostMemory("indices")     \
                              .HostMemory("data")        \
                              .HostMemory("merged"),     \
                          DynamicStitchOpCPU<type>)      \
  REGISTER_KERNEL_BUILDER(Name("ParallelDynamicStitch")  \
                              .Device(DEVICE_SYCL)       \
                              .TypeConstraint<type>("T") \
                              .HostMemory("indices")     \
                              .HostMemory("data")        \
                              .HostMemory("merged"),     \
                          ParallelDynamicStitchOpCPU<type>)

TF_CALL_POD_STRING_TYPES(REGISTER_DYNAMIC_STITCH_SYCL);
#undef REGISTER_DYNAMIC_STITCH_SYCL
#endif  // TENSORFLOW_USE_SYCL
}  // namespace tensorflow