/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.
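//
// Implements the "Split" op, which splits a tensor into num_split equal
// pieces along split_dim. For example, splitting a [4, 6] tensor along
// dimension 1 into 3 pieces yields three [4, 2] tensors.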

#define EIGEN_USE_THREADS

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/kernels/split_lib.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/util/work_sharder.h"
#if GOOGLE_CUDA
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/kernels/cuda_device_array.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL

template <typename Device, typename T>
class SplitOpBase : public OpKernel {
 public:
  explicit SplitOpBase(OpKernelConstruction* c) : OpKernel(c) {}

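  // Handles the cases that need no data movement: a single output (identity
  // split) and an aligned split along dimension 0, where each output can
  // alias a contiguous slice of the input buffer. Sets *done to true when
  // the outputs have already been produced here.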
  void ComputeEasyCases(OpKernelContext* context, bool* done) {
    const Tensor& input = context->input(1);
    const TensorShape& input_shape = input.shape();
    const int32 split_dim_orig = context->input(0).flat<int32>()(0);
    const int32 split_dim =
        split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;
    const int32 num_split = num_outputs();

    OP_REQUIRES(
        context, 0 <= split_dim && split_dim < input_shape.dims(),
        errors::InvalidArgument("-input rank(-", input.dims(),
                                ") <= split_dim < input rank (", input.dims(),
                                "), but got ", split_dim_orig));

    OP_REQUIRES(
        context, num_split > 0,
        errors::InvalidArgument(
            "Number of ways to split should be > 0, but got ", num_split));

    OP_REQUIRES(context, input_shape.dim_size(split_dim) % num_split == 0,
                errors::InvalidArgument(
                    "Number of ways to split should evenly divide the split "
                    "dimension, but got split_dim ",
                    split_dim, " (size = ", input_shape.dim_size(split_dim),
                    ") ", "and num_split ", num_split));
    // Special case 1: num_split == 1. Nothing to do.
    if (num_split == 1) {
      VLOG(1) << "Split identity";
      context->set_output(0, context->input(1));
      *done = true;
      return;
    }

    // Special case 2: split along dimension 0. We can share the
    // underlying buffer.
    //
    // Apply this optimization conservatively: if the input is aligned, the
    // resulting tensors must be aligned too. It's conservative because if
    // the immediate consumers of the resulting tensors do not use Eigen for
    // computation, it is perfectly fine to avoid the copy.
    if ((split_dim == 0) && IsInnerDimsSizeAligned<T>(input_shape)) {
      VLOG(1) << "Slice dim 0: " << input_shape.DebugString();
      const int64 delta = input_shape.dim_size(0) / num_split;
      for (int i = 0; i < num_split; ++i) {
        context->set_output(i, input.Slice(i * delta, (i + 1) * delta));
      }
      *done = true;
      return;
    }
  }

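  // Collapses input_shape around split_dim into a (prefix, split_dim, suffix)
  // triple so the split can be expressed as slicing a rank-3 tensor. For
  // example, shape [2, 3, 4, 5] with split_dim = 2 yields (6, 4, 5).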
  template <typename IndexType>
  std::tuple<IndexType, IndexType, IndexType> SetDims(
      const TensorShape& input_shape, int32 split_dim) const {
    static_assert(std::is_integral<IndexType>::value,
                  "IndexType must be an integer type");
    int32 prefix_dim_size = 1;
    for (int i = 0; i < split_dim; ++i) {
      prefix_dim_size *= input_shape.dim_size(i);
    }

    // Caller must ensure that dim_size and suffix_dim_size are <
    // std::numeric_limits<IndexType>::max()
    IndexType split_dim_size =
        static_cast<IndexType>(input_shape.dim_size(split_dim));

    IndexType suffix_dim_size = 1;
    for (int i = split_dim + 1; i < input_shape.dims(); ++i) {
      suffix_dim_size *= static_cast<IndexType>(input_shape.dim_size(i));
    }
    return std::make_tuple(prefix_dim_size, split_dim_size, suffix_dim_size);
  }
};

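// CPU implementation: reshapes the input to rank 3 and copies each output
// slice, choosing between sharding across outputs and per-copy internal
// parallelism (see the heuristic in Compute below).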
template <typename T>
class SplitOpCPU : public SplitOpBase<CPUDevice, T> {
 public:
  typedef SplitOpBase<CPUDevice, T> Base;
  explicit SplitOpCPU(OpKernelConstruction* c) : Base(c) {}

  void Compute(OpKernelContext* context) override {
    bool done = false;
    Base::ComputeEasyCases(context, &done);
    if (!context->status().ok() || done) {
      return;
    }
    const int32 num_split = Base::num_outputs();
    const Tensor& input = context->input(1);
    const TensorShape& input_shape = input.shape();
    const int32 split_dim_orig = context->input(0).flat<int32>()(0);
    const int32 split_dim =
        split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;

    // Android builds also use int32 indexing, so check the size here as well.
    OP_REQUIRES(
        context,
        FastBoundsCheck(input.NumElements(),
                        std::numeric_limits<Eigen::DenseIndex>::max()),
        errors::InvalidArgument("Split requires input size < ",
                                std::numeric_limits<Eigen::DenseIndex>::max()));

    Eigen::DenseIndex prefix_dim_size;
    Eigen::DenseIndex split_dim_size;
    Eigen::DenseIndex suffix_dim_size;

    std::tie(prefix_dim_size, split_dim_size, suffix_dim_size) =
        Base::template SetDims<Eigen::DenseIndex>(input_shape, split_dim);
    auto input_reshaped =
        input.shaped<T, 3>({prefix_dim_size, split_dim_size, suffix_dim_size});

    const int64 split_dim_output_size = split_dim_size / num_split;
    TensorShape output_shape(input_shape);
    output_shape.set_dim(split_dim, split_dim_output_size);

    Eigen::DSizes<Eigen::DenseIndex, 3> indices{0, 0, 0};
    const Eigen::DSizes<Eigen::DenseIndex, 3> sizes{
        prefix_dim_size, split_dim_output_size, suffix_dim_size};

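    // Decide between two strategies: shard the per-output copies across the
    // CPU worker threads (chosen when there are at least 4 outputs and the
    // input is moderately sized), or iterate over the outputs sequentially
    // and let the Split functor parallelize each copy internally. For
    // example, with 8 worker threads, num_split = 8, and a 1M-element input,
    // the per-output copies are sharded across threads.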
    const auto num_threads =
        context->device()->tensorflow_cpu_worker_threads()->num_threads;
    // TODO(jewillco): Tune heuristic further.
    const auto input_element_count = input_shape.num_elements();
    const bool use_parallelism_between_outputs =
        (num_split >= 4 &&
         input_element_count >= std::max(num_threads, num_split) * 4096 &&
         input_element_count < num_split * 180 * 1024);

    auto range_output_func = [&indices, context, &output_shape, prefix_dim_size,
                              split_dim_output_size, suffix_dim_size, &sizes,
                              use_parallelism_between_outputs,
                              &input_reshaped](int64 start, int64 limit) {
      for (int64 i = start; i < limit; ++i) {
        Tensor* result = nullptr;
        OP_REQUIRES_OK(context,
                       context->allocate_output(i, output_shape, &result));
        if (prefix_dim_size * split_dim_output_size * suffix_dim_size > 0) {
          Eigen::DSizes<Eigen::DenseIndex, 3> slice_indices;
          Eigen::DSizes<Eigen::DenseIndex, 3> slice_sizes;
          for (int j = 0; j < 3; ++j) {
            slice_indices[j] =
                (j == 1 ? i * split_dim_output_size : indices[j]);
            slice_sizes[j] = sizes[j];
          }

          auto result_shaped = result->shaped<T, 3>(
              {prefix_dim_size, split_dim_output_size, suffix_dim_size});

          if (use_parallelism_between_outputs) {
            // Copy this output with a plain (sequential) Eigen assignment;
            // parallelism comes from sharding across outputs.
            result_shaped = input_reshaped.slice(slice_indices, slice_sizes);
          } else {
            // This implementation may be parallel internally.
            functor::Split<CPUDevice, T>()(context->eigen_device<CPUDevice>(),
                                           result_shaped, input_reshaped,
                                           slice_indices, slice_sizes);
          }
        }
      }
    };
    if (use_parallelism_between_outputs) {
      // Run in parallel, disabling parallelism in functor.
      Shard(num_split,
            context->device()->tensorflow_cpu_worker_threads()->workers,
            num_split, input_element_count / num_split, range_output_func);
    } else {
      // Run sequentially, but allow internal parallelism in functor.
      range_output_func(0, num_split);
    }
  }
};

#if GOOGLE_CUDA

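// Host-side launcher for the GPU split kernel (defined in a separate CUDA
// translation unit). It takes the flattened (prefix, split, suffix) sizes and
// a device-side array of output pointers to write into.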
template <typename T>
struct SplitOpGPULaunch {
  void Run(const Eigen::GpuDevice& d, const T* input, int32 prefix_dim_size,
           int32 split_dim_size, int32 suffix_dim_size,
           const CudaDeviceArrayStruct<T*>& output_ptr_data);
};

// Partial specialization for GPU
template <typename T>
class SplitOpGPU : public SplitOpBase<GPUDevice, T> {
 public:
  typedef SplitOpBase<GPUDevice, T> Base;
  explicit SplitOpGPU(OpKernelConstruction* c) : Base(c) {}

  void Compute(OpKernelContext* context) override {
    bool done = false;
    Base::ComputeEasyCases(context, &done);
    if (!context->status().ok() || done) {
      return;
    }
    const Tensor& input = context->input(1);
    const TensorShape& input_shape = input.shape();
    const int32 split_dim_orig = context->input(0).flat<int32>()(0);
    const int32 split_dim =
        split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;
    const int32 num_split = Base::num_outputs();
    OP_REQUIRES(
        context,
        FastBoundsCheck(input.NumElements(), std::numeric_limits<int32>::max()),
        errors::InvalidArgument("Split on GPU requires input size "
                                "< max int32"));
    int32 prefix_dim_size;
    int32 split_dim_size;
    int32 suffix_dim_size;
    std::tie(prefix_dim_size, split_dim_size, suffix_dim_size) =
        Base::template SetDims<int32>(input_shape, split_dim);

    const int32 split_dim_output_size = split_dim_size / num_split;
    TensorShape output_shape(input_shape);
    output_shape.set_dim(split_dim, split_dim_output_size);

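    // Stage the per-output data pointers on the host; Finalize() below makes
    // them available to the device so one kernel launch can write all
    // outputs.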
    CudaDeviceArrayOnHost<T*> ptrs(context, num_split);
    OP_REQUIRES_OK(context, ptrs.Init());

    for (int i = 0; i < num_split; ++i) {
      Tensor* result = nullptr;
      OP_REQUIRES_OK(context,
                     context->allocate_output(i, output_shape, &result));
      ptrs.Set(i, result->flat<T>().data());
    }
    if (prefix_dim_size * split_dim_output_size * suffix_dim_size == 0) {
      return;
    }
    OP_REQUIRES_OK(context, ptrs.Finalize());

    SplitOpGPULaunch<T>().Run(context->eigen_device<GPUDevice>(),
                              input.flat<T>().data(), prefix_dim_size,
                              split_dim_size, suffix_dim_size, ptrs.data());
    OP_REQUIRES(context, context->op_device_context()->stream()->ok(),
                errors::Internal("Launch of gpu kernel for SplitOp failed"));
  }
};
#endif  // GOOGLE_CUDA

#ifdef TENSORFLOW_USE_SYCL
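// SYCL implementation: like the CPU path, but each output slice is copied
// through the Split functor on the SYCL device, one output at a time.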
template <typename T>
class SplitOpSYCL : public SplitOpBase<SYCLDevice, T> {
 public:
  typedef SplitOpBase<SYCLDevice, T> Base;
  explicit SplitOpSYCL(OpKernelConstruction* c) : Base(c) {}

  void Compute(OpKernelContext* context) override {
    bool done = false;
    Base::ComputeEasyCases(context, &done);
    if (!context->status().ok() || done) {
      return;
    }
    const Tensor& input = context->input(1);
    const TensorShape& input_shape = input.shape();
    const int32 split_dim_orig = context->input(0).flat<int32>()(0);
    const int32 split_dim =
        split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;
    const int32 num_split = Base::num_outputs();

    // Android builds also use int32 indexing, so check the size here as well.
    OP_REQUIRES(
        context,
        FastBoundsCheck(input.NumElements(),
                        std::numeric_limits<Eigen::DenseIndex>::max()),
        errors::InvalidArgument("Split requires input size < ",
                                std::numeric_limits<Eigen::DenseIndex>::max()));

    Eigen::DenseIndex prefix_dim_size;
    Eigen::DenseIndex split_dim_size;
    Eigen::DenseIndex suffix_dim_size;

    std::tie(prefix_dim_size, split_dim_size, suffix_dim_size) =
        Base::template SetDims<Eigen::DenseIndex>(input_shape, split_dim);
    auto input_reshaped =
        input.shaped<T, 3>({prefix_dim_size, split_dim_size, suffix_dim_size});

    const int64 split_dim_output_size = split_dim_size / num_split;
    TensorShape output_shape(input_shape);
    output_shape.set_dim(split_dim, split_dim_output_size);

    Eigen::DSizes<Eigen::DenseIndex, 3> indices{0, 0, 0};
    Eigen::DSizes<Eigen::DenseIndex, 3> sizes{
        prefix_dim_size, split_dim_output_size, suffix_dim_size};

    for (int i = 0; i < num_split; ++i) {
      Tensor* result = nullptr;
      OP_REQUIRES_OK(context,
                     context->allocate_output(i, output_shape, &result));
      if (prefix_dim_size * split_dim_output_size * suffix_dim_size > 0) {
        Eigen::DSizes<Eigen::DenseIndex, 3> slice_indices;
        Eigen::DSizes<Eigen::DenseIndex, 3> slice_sizes;
        for (int j = 0; j < 3; ++j) {
          slice_indices[j] = indices[j];
          slice_sizes[j] = sizes[j];
        }

        auto result_shaped = result->shaped<T, 3>(
            {prefix_dim_size, split_dim_output_size, suffix_dim_size});

        functor::Split<SYCLDevice, T>()(context->eigen_device<SYCLDevice>(),
                                        result_shaped, input_reshaped,
                                        slice_indices, slice_sizes);
      }
      indices[1] += split_dim_output_size;
    }
  }
};
#endif  // TENSORFLOW_USE_SYCL

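// Kernel registrations. "split_dim" is pinned to host memory because Compute
// reads its value directly on the host to decide how to slice the input.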
#define REGISTER_SPLIT(type)                             \
  REGISTER_KERNEL_BUILDER(Name("Split")                  \
                              .Device(DEVICE_CPU)        \
                              .TypeConstraint<type>("T") \
                              .HostMemory("split_dim"),  \
                          SplitOpCPU<type>)

TF_CALL_ALL_TYPES(REGISTER_SPLIT);
REGISTER_SPLIT(quint8);

#undef REGISTER_SPLIT

#if GOOGLE_CUDA

#define REGISTER_GPU(type)                               \
  REGISTER_KERNEL_BUILDER(Name("Split")                  \
                              .Device(DEVICE_GPU)        \
                              .TypeConstraint<type>("T") \
                              .HostMemory("split_dim"),  \
                          SplitOpGPU<type>)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
TF_CALL_complex64(REGISTER_GPU);
TF_CALL_complex128(REGISTER_GPU);
REGISTER_GPU(bfloat16);
#undef REGISTER_GPU

#endif  // GOOGLE_CUDA

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL(type)                              \
  REGISTER_KERNEL_BUILDER(Name("Split")                  \
                              .Device(DEVICE_SYCL)       \
                              .TypeConstraint<type>("T") \
                              .HostMemory("split_dim"),  \
                          SplitOpSYCL<type>)

TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL);
#undef REGISTER_SYCL

#endif  // TENSORFLOW_USE_SYCL

}  // end namespace tensorflow