Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // See docs in ../ops/array_ops.cc.
     17 
     18 #include <vector>
     19 
     20 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     21 #include "tensorflow/core/framework/op_kernel.h"
     22 #include "tensorflow/core/framework/register_types.h"
     23 #include "tensorflow/core/framework/tensor.h"
     24 #include "tensorflow/core/framework/tensor_types.h"
     25 #include "tensorflow/core/framework/types.h"
     26 
     27 #if GOOGLE_CUDA
     28 
     29 #include "tensorflow/core/kernels/cuda_device_array.h"
     30 
     31 namespace tensorflow {
     32 
// Declaration only; the definition lives in another translation unit
// (presumably the CUDA kernel source for concat — TODO confirm).
// Presumably concatenates the 2-D `inputs_flat` views into `*output` on
// `gpu_device`. IntType selects the index width; ConcatGPU in this file
// instantiates it with int32 or int64 depending on the output element count,
// and uses this entry point when there are fewer than 16 inputs.
template <typename T, typename IntType>
void ConcatGPUSlice(
    const Eigen::GpuDevice& gpu_device,
    const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>&
        inputs_flat,
    typename TTypes<T, 2>::Matrix* output);
     39 
// Declaration only; the definition lives in another translation unit
// (presumably alongside the CUDA concat kernel — TODO confirm).
// ConcatGPUCall in this file calls it with:
//   d           - the Eigen GPU device of the op kernel context;
//   input_ptrs  - the data() pointer of every input matrix;
//   ptr_offsets - exclusive prefix sum of the inputs' dimension(1) sizes
//                 (ConcatGPUCall only Finalize()s it when widths differ);
//   same_size   - true when every input has the same dimension(1);
//   slice_size  - dimension(1) of the first input;
//   output      - destination matrix.
template <typename T, typename IntType>
void ConcatGPUImpl(const Eigen::GpuDevice& d,
                   const CudaDeviceArrayStruct<const T*>& input_ptrs,
                   const CudaDeviceArrayStruct<IntType>& ptr_offsets,
                   bool same_size, int slice_size,
                   typename TTypes<T, 2>::Matrix* output);
     46 
     47 namespace {
     48 
     49 template <typename T, typename IntType>
     50 void ConcatGPUCall(
     51     OpKernelContext* c,
     52     const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>&
     53         inputs_flat,
     54     typename TTypes<T, 2>::Tensor* output_flat) {
     55   CudaDeviceArrayOnHost<const T*> input_ptrs(c, inputs_flat.size());
     56   OP_REQUIRES_OK(c, input_ptrs.Init());
     57   for (int i = 0; i < inputs_flat.size(); ++i) {
     58     input_ptrs.Set(i, inputs_flat[i]->data());
     59   }
     60   OP_REQUIRES_OK(c, input_ptrs.Finalize());
     61 
     62   CudaDeviceArrayOnHost<IntType> output_scan(c, inputs_flat.size() + 1);
     63   OP_REQUIRES_OK(c, output_scan.Init());
     64   IntType scan = 0;
     65   output_scan.Set(0, scan);
     66   bool one_size_input = true;
     67   for (int i = 0; i < inputs_flat.size(); ++i) {
     68     if (one_size_input && i < inputs_flat.size() - 1 &&
     69         inputs_flat[i]->dimension(1) != inputs_flat[i + 1]->dimension(1)) {
     70       one_size_input = false;
     71     }
     72     scan += inputs_flat[i]->dimension(1);
     73     output_scan.Set(i + 1, scan);
     74   }
     75   if (!one_size_input) OP_REQUIRES_OK(c, output_scan.Finalize());
     76 
     77   ConcatGPUImpl<T, IntType>(c->eigen_gpu_device(), input_ptrs.data(),
     78                             output_scan.data(), one_size_input,
     79                             inputs_flat[0]->dimension(1), output_flat);
     80 }
     81 
     82 }  // end namespace
     83 
     84 template <typename T>
     85 void ConcatGPU(
     86     OpKernelContext* c,
     87     const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>&
     88         inputs_flat,
     89     Tensor* output, typename TTypes<T, 2>::Tensor* output_flat) {
     90   if (inputs_flat.size() < 16) {
     91     if (output->NumElements() < std::numeric_limits<int32>::max()) {
     92       ConcatGPUSlice<T, int32>(c->eigen_gpu_device(), inputs_flat, output_flat);
     93     } else {
     94       ConcatGPUSlice<T, int64>(c->eigen_gpu_device(), inputs_flat, output_flat);
     95     }
     96   } else {
     97     // Switching indexing to int64 might cause performance issues.
     98     // Hence, we keep int32 indexing in the GPU kernel unless we need to
     99     // switch to int64.
    100     if (output->NumElements() < std::numeric_limits<int32>::max()) {
    101       ConcatGPUCall<T, int32>(c, inputs_flat, output_flat);
    102     } else {
    103       ConcatGPUCall<T, int64>(c, inputs_flat, output_flat);
    104     }
    105   }
    106 }
    107 
    108 #define REGISTER(T)                                                           \
    109   template void ConcatGPU<T>(                                                 \
    110       OpKernelContext * c,                                                    \
    111       const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>& \
    112           inputs_flat,                                                        \
    113       Tensor* output, typename TTypes<T, 2>::Tensor* output_flat);
    114 
    115 TF_CALL_GPU_NUMBER_TYPES(REGISTER);
    116 TF_CALL_complex64(REGISTER);
    117 TF_CALL_complex128(REGISTER);
    118 TF_CALL_int64(REGISTER);
    119 TF_CALL_bfloat16(REGISTER);
    120 TF_CALL_bool(REGISTER);
    121 
    122 #undef REGISTER
    123 
    124 }  // namespace tensorflow
    125 
    126 #endif  // GOOGLE_CUDA
    127