/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.
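//
// SplitV splits a tensor into num_split pieces along split_dim, with the
// (possibly unequal) piece sizes given by the size_splits input. For example,
// splitting a [5, 30] tensor along dim 1 with size_splits = {4, 15, 11}
// produces three outputs of shape [5, 4], [5, 15] and [5, 11]. A single
// size_splits entry may be -1, in which case it is inferred from the
// remaining size of the input along split_dim.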

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include <numeric>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/kernels/split_lib.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/util/work_sharder.h"
#if GOOGLE_CUDA
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/kernels/cuda_device_array.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename T, typename Tlen>
class SplitVOpBase : public OpKernel {
 public:
  explicit SplitVOpBase(OpKernelConstruction* c) : OpKernel(c) {}

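  // Validates the inputs, resolves an optional -1 entry in size_splits, and
  // handles the cases that do not require copying: a single output (the input
  // is forwarded as-is) and a split along the first dimension with suitably
  // aligned inner dimensions (each output aliases a slice of the input
  // buffer). Sets *done to true if the outputs have already been produced.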
  void ComputeEasyCases(OpKernelContext* context, bool* done,
                        std::vector<Tlen>* split_sizes_vec) {
    const int32 num_split = context->num_outputs();
    const Tensor& input = context->input(0);
    const TensorShape& input_shape = input.shape();
    const Tensor& split_tensor = context->input(1);

    const int32 split_dim_orig = context->input(2).flat<int32>()(0);
    const int32 split_dim =
        split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;

    OP_REQUIRES(
        context,
        split_tensor.dims() == 1 && split_tensor.NumElements() == num_split,
        errors::InvalidArgument("split_tensor must be 1-D and have the same "
                                "number of elements as there are outputs; got ",
                                split_tensor.dims(), "-D and ",
                                split_tensor.NumElements(), " elements"));

    auto split_sizes_d = split_tensor.vec<Tlen>();

    split_sizes_vec->resize(split_sizes_d.size());

    std::copy(split_sizes_d.data(), split_sizes_d.data() + split_sizes_d.size(),
              split_sizes_vec->begin());

    OP_REQUIRES(
        context, num_split > 0,
        errors::InvalidArgument(
            "Number of ways to split should be > 0, but got ", num_split));

    OP_REQUIRES(
        context, 0 <= split_dim && split_dim < input.dims(),
        errors::InvalidArgument("-input rank(-", input.dims(),
                                ") <= split_dim < input rank (", input.dims(),
                                "), but got ", split_dim_orig));

    Tlen input_size_split_dim = input_shape.dim_size(split_dim);

    // Special case 1: num_split == 1. Nothing to do.
    if (num_split == 1) {
      context->set_output(0, context->input(0));
      OP_REQUIRES(
          context, (*split_sizes_vec)[0] == input_size_split_dim,
          errors::InvalidArgument("If there is only one output, it must have "
                                  "the same size as the input. Input size: ",
                                  input_size_split_dim,
                                  " output size: ", (*split_sizes_vec)[0]));
      *done = true;
      return;
    }

    // Determine sizes of output, in case of a -1 input value
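    // For example, with an input of size 10 along split_dim and
    // size_splits = {3, -1, 2}, the -1 entry is resolved below to
    // 10 - (3 + 2) = 5.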
    int neg_one_dim = -1;
    Tlen determined_size = 0;
    for (int d = 0; d < split_sizes_vec->size(); ++d) {
      Tlen size = (*split_sizes_vec)[d];

      if (size == -1) {
        OP_REQUIRES(context, neg_one_dim == -1,
                    errors::InvalidArgument("There can only be one -1 in the "
                                            "input."));
        neg_one_dim = d;
      } else {
        determined_size += size;
      }
    }

    OP_REQUIRES(
        context,
        (neg_one_dim == -1 && determined_size == input_size_split_dim) ||
            (neg_one_dim >= 0 && determined_size <= input_size_split_dim),
        errors::InvalidArgument("Determined shape must either match "
                                "input shape along split_dim exactly if "
                                "fully specified, or be less than the size of "
                                "the input along split_dim if not fully "
                                "specified.  Got: ",
                                determined_size));

    if (neg_one_dim >= 0) {
      (*split_sizes_vec)[neg_one_dim] = input_size_split_dim - determined_size;
    }

    // Special case 2: split along the 1st dimension. We can share the
    // underlying buffer.
    //
    // Apply this optimization conservatively: if the input is aligned,
    // the resulting tensors must be aligned. It's conservative
    // because if the immediate consumers of the resulting tensors are
    // not using Eigen for computation, it's perfectly fine to avoid
    // the copying.
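    //
    // For example, splitting a [6, 8] float tensor along dim 0 into
    // sizes {2, 4} yields two outputs that alias the input buffer, so no
    // data is copied.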
    if ((split_dim == 0) && IsInnerDimsSizeAligned<T>(input_shape)) {
      Tlen start = 0;
      for (int i = 0; i < num_split; ++i) {
        context->set_output(i,
                            input.Slice(start, start + (*split_sizes_vec)[i]));
        start += (*split_sizes_vec)[i];
      }
      *done = true;
      return;
    }
  }

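  // Collapses input_shape into a 3-D view {prefix, split_dim, suffix}, where
  // prefix is the product of the dimensions before split_dim and suffix is
  // the product of the dimensions after it. The split then operates on the
  // middle dimension of this view.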
  template <typename IndexType>
  std::tuple<IndexType, IndexType, IndexType> SetDims(
      const TensorShape& input_shape, const int32 split_dim) const {
    static_assert(std::is_integral<IndexType>::value,
                  "IndexType must be an integer type");
    IndexType prefix_dim_size = 1;
    for (int i = 0; i < split_dim; ++i) {
      prefix_dim_size *= static_cast<IndexType>(input_shape.dim_size(i));
    }

    // Caller must ensure that dim_size and suffix_dim_size are <
    // std::numeric_limits<IndexType>::max()
    IndexType split_dim_size =
        static_cast<IndexType>(input_shape.dim_size(split_dim));

    IndexType suffix_dim_size = 1;
    for (int i = split_dim + 1; i < input_shape.dims(); ++i) {
      suffix_dim_size *= static_cast<IndexType>(input_shape.dim_size(i));
    }
    return std::make_tuple(prefix_dim_size, split_dim_size, suffix_dim_size);
  }
};

template <typename T, typename Tlen>
class SplitVOpCPU : public SplitVOpBase<CPUDevice, T, Tlen> {
 public:
  typedef SplitVOpBase<CPUDevice, T, Tlen> Base;
  explicit SplitVOpCPU(OpKernelConstruction* c) : Base(c) {}

  void Compute(OpKernelContext* context) override {
    bool done = false;
    std::vector<Tlen> split_sizes_vec;
    Base::ComputeEasyCases(context, &done, &split_sizes_vec);
    if (!context->status().ok() || done) {
      return;
    }
    const int32 num_split = Base::num_outputs();
    const Tensor& input = context->input(0);
    const TensorShape& input_shape = input.shape();
    const int32 split_dim_orig = context->input(2).flat<int32>()(0);
    const int32 split_dim =
        split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;

    // Android also uses int32 indexing, so check here also.
    OP_REQUIRES(
        context,
        FastBoundsCheck(input.NumElements(),
                        std::numeric_limits<Eigen::DenseIndex>::max()),
        errors::InvalidArgument("Split requires input size < ",
                                std::numeric_limits<Eigen::DenseIndex>::max()));

    Eigen::DenseIndex prefix_dim_size;
    Eigen::DenseIndex split_dim_size;
    Eigen::DenseIndex suffix_dim_size;

    std::tie(prefix_dim_size, split_dim_size, suffix_dim_size) =
        Base::template SetDims<Eigen::DenseIndex>(input_shape, split_dim);
    auto input_reshaped =
        input.shaped<T, 3>({prefix_dim_size, split_dim_size, suffix_dim_size});

    Eigen::DSizes<Eigen::DenseIndex, 3> indices{0, 0, 0};
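    // split_start_points[i] is the offset of output i along the split
    // dimension, i.e. the exclusive prefix sum of split_sizes_vec.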
    std::vector<int64> split_start_points(num_split);
    for (int i = 0; i < num_split; ++i) {
      if (i == 0) {
        split_start_points[i] = 0;
      } else {
        split_start_points[i] =
            split_start_points[i - 1] + split_sizes_vec[i - 1];
      }
    }

    const auto num_threads =
        context->device()->tensorflow_cpu_worker_threads()->num_threads;
    // TODO(jewillco): Tune heuristic further.
    const auto input_element_count = input_shape.num_elements();
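    // Heuristic: parallelize across outputs only when there are enough
    // outputs to shard (>= 4), the input is large enough for threading to pay
    // off, and each output is small enough that the functor's internal
    // parallelism would not be more effective.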
    const bool use_parallelism_between_outputs =
        (num_split >= 4 &&
         input_element_count >= std::max(num_threads, num_split) * 4096 &&
         input_element_count < num_split * 180 * 1024);

    auto range_output_func = [&indices, context, &input_shape, prefix_dim_size,
                              split_dim, &split_sizes_vec, &split_start_points,
                              suffix_dim_size, use_parallelism_between_outputs,
                              &input_reshaped](int64 start, int64 limit) {
      for (int64 i = start; i < limit; ++i) {
        TensorShape output_shape(input_shape);
        output_shape.set_dim(split_dim, split_sizes_vec[i]);
        Tensor* result = nullptr;
        OP_REQUIRES_OK(context,
                       context->allocate_output(i, output_shape, &result));

        Eigen::DSizes<Eigen::DenseIndex, 3> sizes{
            prefix_dim_size, split_sizes_vec[i], suffix_dim_size};

        if (sizes.TotalSize() > 0) {
          auto result_shaped = result->shaped<T, 3>(
              {prefix_dim_size, split_sizes_vec[i], suffix_dim_size});

          auto current_indices = indices;
          current_indices[1] = split_start_points[i];
          if (use_parallelism_between_outputs) {
            // Each output is copied with a sequential slice; parallelism
            // comes from sharding the outputs across threads.
            result_shaped = input_reshaped.slice(current_indices, sizes);
          } else {
            // This implementation may be parallel internally.
            functor::Split<CPUDevice, T>()(context->eigen_device<CPUDevice>(),
                                           result_shaped, input_reshaped,
                                           current_indices, sizes);
          }
        }
      }
    };
    if (use_parallelism_between_outputs) {
      // Run in parallel, disabling parallelism in functor.
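      // Shard distributes the [0, num_split) output range across the worker
      // thread pool; the cost hint is the average element count per output.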
      Shard(num_split,
            context->device()->tensorflow_cpu_worker_threads()->workers,
            num_split, input_element_count / num_split, range_output_func);
    } else {
      // Run sequentially, but allow internal parallelism in functor.
      range_output_func(0, num_split);
    }
  }
};

#if GOOGLE_CUDA

template <typename T, typename IntType>
struct SplitVOpGPULaunch {
  void Run(const Eigen::GpuDevice& d, bool fixed, const T* input,
           int total_cols, int total_rows,
           const CudaDeviceArrayStruct<IntType>& output_scan,
           const CudaDeviceArrayStruct<T*>& output_ptr_data);
};

// Partial specialization for GPU
template <typename T, typename Tlen>
class SplitVOpGPU : public SplitVOpBase<GPUDevice, T, Tlen> {
 public:
  typedef SplitVOpBase<GPUDevice, T, Tlen> Base;
  explicit SplitVOpGPU(OpKernelConstruction* c) : Base(c) {}

  void Compute(OpKernelContext* context) override {
    bool done = false;
    std::vector<Tlen> split_sizes_vec;
    Base::ComputeEasyCases(context, &done, &split_sizes_vec);
    if (!context->status().ok() || done) {
      return;
    }
    const int32 num_split = Base::num_outputs();
    const Tensor& input = context->input(0);
    const TensorShape& input_shape = input.shape();
    const int32 split_dim_orig = context->input(2).flat<int32>()(0);
    const int32 split_dim =
        split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;
    OP_REQUIRES(
        context,
        FastBoundsCheck(input.NumElements(), std::numeric_limits<int32>::max()),
        errors::InvalidArgument("Split on GPU requires input size "
                                "< max int32"));

    int32 prefix_dim_size;
    int32 split_dim_size;
    int32 suffix_dim_size;
    std::tie(prefix_dim_size, split_dim_size, suffix_dim_size) =
        Base::template SetDims<int32>(input_shape, split_dim);

    // Use the same approach as concat (see the documentation there):
    // reshape to 2D.

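    // With many outputs, launch a single fused CUDA kernel that receives a
    // device-side array of output pointers together with the running offsets
    // (an exclusive scan of split_sizes_vec[i] * suffix_dim_size). With few
    // outputs, fall back to one Eigen slice per output below.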
    if (num_split > 16) {
      CudaDeviceArrayOnHost<T*> ptrs(context, num_split);
      OP_REQUIRES_OK(context, ptrs.Init());

      CudaDeviceArrayOnHost<Tlen> offsets(context, num_split + 1);
      OP_REQUIRES_OK(context, offsets.Init());

      Tlen offset = 0;
      Tlen entry = split_sizes_vec[0];
      bool fixed_size =
          std::all_of(split_sizes_vec.begin(), split_sizes_vec.end(),
                      [&entry](Tlen n) { return n == entry; });

      for (int i = 0; i < num_split; ++i) {
        TensorShape output_shape(input_shape);
        output_shape.set_dim(split_dim, split_sizes_vec[i]);
        Tensor* result = nullptr;
        OP_REQUIRES_OK(context,
                       context->allocate_output(i, output_shape, &result));
        ptrs.Set(i, result->flat<T>().data());
        offsets.Set(i, offset);
        offset += split_sizes_vec[i] * suffix_dim_size;
      }
      offsets.Set(num_split, offset);
      OP_REQUIRES_OK(context, ptrs.Finalize());
      OP_REQUIRES_OK(context, offsets.Finalize());

      if (input.NumElements() > 0) {
        SplitVOpGPULaunch<T, Tlen>().Run(
            context->eigen_device<GPUDevice>(), fixed_size,
            input.flat<T>().data(), prefix_dim_size,
            input.NumElements() / prefix_dim_size, offsets.data(), ptrs.data());
        OP_REQUIRES(
            context, context->op_device_context()->stream()->ok(),
            errors::Internal("Launch of gpu kernel for SplitVOp failed"));
      }
    } else {
      Eigen::DenseIndex prefix_dim_size;
      Eigen::DenseIndex split_dim_size;
      Eigen::DenseIndex suffix_dim_size;

      std::tie(prefix_dim_size, split_dim_size, suffix_dim_size) =
          Base::template SetDims<Eigen::DenseIndex>(input_shape, split_dim);
      auto input_reshaped = input.shaped<T, 2>(
          {prefix_dim_size, split_dim_size * suffix_dim_size});

      Eigen::DSizes<Eigen::DenseIndex, 2> indices{0, 0};

      for (int i = 0; i < num_split; ++i) {
        TensorShape output_shape(input_shape);
        output_shape.set_dim(split_dim, split_sizes_vec[i]);
        Tensor* result = nullptr;
        OP_REQUIRES_OK(context,
                       context->allocate_output(i, output_shape, &result));

        Eigen::DSizes<Eigen::DenseIndex, 2> sizes{
            prefix_dim_size, split_sizes_vec[i] * suffix_dim_size};

        if (sizes.TotalSize() > 0) {
          auto result_shaped = result->shaped<T, 2>(
              {prefix_dim_size, split_sizes_vec[i] * suffix_dim_size});

          functor::SplitCustom<GPUDevice, T>()(
              context->eigen_device<GPUDevice>(), result_shaped, input_reshaped,
              indices, sizes);
        }
        indices[1] += split_sizes_vec[i] * suffix_dim_size;
      }
    }
  }
};
#endif  // GOOGLE_CUDA

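// Registers the CPU kernel for every supported element type T and for both
// size_splits index types (int32 and int64).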
#define REGISTER_SPLIT(type, len_type)                          \
  REGISTER_KERNEL_BUILDER(Name("SplitV")                        \
                              .Device(DEVICE_CPU)               \
                              .TypeConstraint<len_type>("Tlen") \
                              .TypeConstraint<type>("T")        \
                              .HostMemory("size_splits")        \
                              .HostMemory("split_dim"),         \
                          SplitVOpCPU<type, len_type>);

#define REGISTER_SPLIT_LEN(type) \
  REGISTER_SPLIT(type, int32);   \
  REGISTER_SPLIT(type, int64);

TF_CALL_ALL_TYPES(REGISTER_SPLIT_LEN);

#undef REGISTER_SPLIT_LEN
#undef REGISTER_SPLIT

#if GOOGLE_CUDA

#define REGISTER_GPU(type, len_type)                            \
  REGISTER_KERNEL_BUILDER(Name("SplitV")                        \
                              .Device(DEVICE_GPU)               \
                              .TypeConstraint<len_type>("Tlen") \
                              .TypeConstraint<type>("T")        \
                              .HostMemory("size_splits")        \
                              .HostMemory("split_dim"),         \
                          SplitVOpGPU<type, len_type>);

#define REGISTER_GPU_LEN(type) \
  REGISTER_GPU(type, int32);   \
  REGISTER_GPU(type, int64);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_LEN);
TF_CALL_complex64(REGISTER_GPU_LEN);
TF_CALL_complex128(REGISTER_GPU_LEN);
REGISTER_GPU_LEN(bfloat16);
#undef REGISTER_GPU_LEN
#undef REGISTER_GPU

// special GPU kernel for int32
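//
// int32 tensors placed on a GPU device are kept in host memory, so this
// registration pins "value" and "output" to HostMemory and dispatches to the
// CPU implementation (SplitVOpCPU) even though the op is nominally on the GPU.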

#define REGISTER_GPU_int32(len_type)                            \
  REGISTER_KERNEL_BUILDER(Name("SplitV")                        \
                              .Device(DEVICE_GPU)               \
                              .TypeConstraint<int32>("T")       \
                              .TypeConstraint<len_type>("Tlen") \
                              .HostMemory("size_splits")        \
                              .HostMemory("split_dim")          \
                              .HostMemory("value")              \
                              .HostMemory("output"),            \
                          SplitVOpCPU<int32, len_type>);

REGISTER_GPU_int32(int32);
REGISTER_GPU_int32(int64);

#undef REGISTER_GPU_int32

#endif  // GOOGLE_CUDA

}  // end namespace tensorflow