Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // See docs in ../ops/array_ops.cc.
     17 
     18 #include <limits>
     19 #include <vector>
     20 
     21 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     22 #include "tensorflow/core/framework/op_kernel.h"
     23 #include "tensorflow/core/framework/register_types.h"
     24 #include "tensorflow/core/framework/tensor.h"
     25 #include "tensorflow/core/framework/tensor_types.h"
     26 #include "tensorflow/core/framework/types.h"
     27 #include "tensorflow/core/kernels/bounds_check.h"
     28 #include "tensorflow/core/kernels/concat_lib.h"
     29 #include "tensorflow/core/lib/core/status.h"
     30 #include "tensorflow/core/platform/types.h"
     31 
     32 namespace tensorflow {
     33 
     34 typedef Eigen::ThreadPoolDevice CPUDevice;
     35 #if GOOGLE_CUDA
     36 typedef Eigen::GpuDevice GPUDevice;
     37 #endif  // GOOGLE_CUDA
     38 #ifdef TENSORFLOW_USE_SYCL
     39 typedef Eigen::SyclDevice SYCLDevice;
     40 #endif  // TENSORFLOW_USE_SYCL
     41 
     42 enum AxisArgumentName { NAME_IS_AXIS, NAME_IS_CONCAT_DIM };
     43 
     44 // --------------------------------------------------------------------------
     45 template <typename Device, typename T, AxisArgumentName AxisArgName>
     46 class ConcatBaseOp : public OpKernel {
     47  public:
     48   typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
     49       ConstMatrixVector;
     50 
     51   explicit ConcatBaseOp(OpKernelConstruction* c) : OpKernel(c) {}
     52 
     53   void Compute(OpKernelContext* c) override {
     54     const Tensor* concat_dim_tensor;
     55     const char* axis_attribute_name =
     56         AxisArgName == NAME_IS_AXIS
     57             ? "axis"
     58             : AxisArgName == NAME_IS_CONCAT_DIM ? "concat_dim" : "<invalid>";
     59     OP_REQUIRES_OK(c, c->input(axis_attribute_name, &concat_dim_tensor));
     60     OP_REQUIRES(c, IsLegacyScalar(concat_dim_tensor->shape()),
     61                 errors::InvalidArgument(
     62                     axis_attribute_name,
     63                     " tensor should be a scalar integer, but got shape ",
     64                     concat_dim_tensor->shape().DebugString()));
     65     const int32 concat_dim =
     66         internal::SubtleMustCopy(concat_dim_tensor->scalar<int32>()());
     67     OpInputList values;
     68     OP_REQUIRES_OK(c, c->input_list("values", &values));
     69     const int N = values.size();
     70     const int input_dims = values[0].dims();
     71     const TensorShape& input_shape = values[0].shape();
     72 
     73     int32 axis = concat_dim < 0 ? concat_dim + input_dims : concat_dim;
     74     OP_REQUIRES(c,
     75                 (0 <= axis && axis < input_dims) ||
     76                     (allow_legacy_scalars() && concat_dim == 0),
     77                 errors::InvalidArgument(
     78                     "ConcatOp : Expected concatenating dimensions in the range "
     79                     "[",
     80                     -input_dims, ", ", input_dims, "), but got ", concat_dim));
     81     // Note that we reduce the concat of n-dimensional tensors into a two
     82     // dimensional concat. Assuming the dimensions of any input/output
     83     // tensor are {x0, x1,...,xn-1, y0, y1,...,ym-1}, where the concat is along
     84     // the dimension indicated with size y0, we flatten it to {x, y}, where y =
     85     // Prod_i(yi) and x = ((n > 0) ? Prod_i(xi) : 1).
     86     ConstMatrixVector inputs_flat;
     87     inputs_flat.reserve(N);
     88     int64 inputs_flat_dim0 = 1;
     89     for (int d = 0; d < axis; ++d) {
     90       inputs_flat_dim0 *= input_shape.dim_size(d);
     91     }
     92     int64 output_concat_dim = 0;
     93     const bool input_is_scalar = IsLegacyScalar(input_shape);
     94     for (int i = 0; i < N; ++i) {
     95       const auto in = values[i];
     96       const bool in_is_scalar = IsLegacyScalar(in.shape());
     97       OP_REQUIRES(
     98           c, in.dims() == input_dims || (input_is_scalar && in_is_scalar),
     99           errors::InvalidArgument(
    100               "ConcatOp : Ranks of all input tensors should match: shape[0] = ",
    101               input_shape.DebugString(), " vs. shape[", i,
    102               "] = ", in.shape().DebugString()));
    103       for (int j = 0; j < input_dims; ++j) {
    104         if (j == axis) {
    105           continue;
    106         }
    107         OP_REQUIRES(
    108             c, in.dim_size(j) == input_shape.dim_size(j),
    109             errors::InvalidArgument(
    110                 "ConcatOp : Dimensions of inputs should match: shape[0] = ",
    111                 input_shape.DebugString(), " vs. shape[", i,
    112                 "] = ", in.shape().DebugString()));
    113       }
    114       if (in.NumElements() > 0) {
    115         int64 inputs_flat_dim1 = in.NumElements() / inputs_flat_dim0;
    116         inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
    117             in.shaped<T, 2>({inputs_flat_dim0, inputs_flat_dim1})));
    118       }
    119       // TODO(irving): Remove check once !allow_legacy_scalars().
    120       output_concat_dim += in.dims() > 0 ? in.dim_size(axis) : 1;
    121     }
    122 
    123     TensorShape output_shape(input_shape);
    124     // TODO(irving): Remove rank 0 case once !allow_legacy_scalars().
    125     if (output_shape.dims() == 0) {
    126       output_shape.AddDim(output_concat_dim);
    127     } else {
    128       output_shape.set_dim(axis, output_concat_dim);
    129     }
    130     Tensor* output = nullptr;
    131     OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output));
    132     if (output->NumElements() > 0) {
    133       int64 output_dim1 = output->NumElements() / inputs_flat_dim0;
    134       auto output_flat = output->shaped<T, 2>({inputs_flat_dim0, output_dim1});
    135 #if GOOGLE_CUDA
    136       if (std::is_same<Device, GPUDevice>::value) {
    137         ConcatGPU<T>(c, inputs_flat, output, &output_flat);
    138         return;
    139       }
    140 #endif  // GOOGLE_CUDA
    141 #ifdef TENSORFLOW_USE_SYCL
    142       if (std::is_same<Device, SYCLDevice>::value) {
    143         ConcatSYCL<T>(c->eigen_sycl_device(), inputs_flat, &output_flat);
    144         return;
    145       }
    146 #endif  // TENSORFLOW_USE_SYCL
    147       ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
    148     }
    149   }
    150 };
    151 
    152 template <typename Device, typename T>
    153 using ConcatOp = ConcatBaseOp<Device, T, NAME_IS_CONCAT_DIM>;
    154 template <typename Device, typename T>
    155 using ConcatV2Op = ConcatBaseOp<Device, T, NAME_IS_AXIS>;
    156 
    157 #define REGISTER_CONCAT(type)                                \
    158   REGISTER_KERNEL_BUILDER(Name("Concat")                     \
    159                               .Device(DEVICE_CPU)            \
    160                               .TypeConstraint<type>("T")     \
    161                               .HostMemory("concat_dim"),     \
    162                           ConcatOp<CPUDevice, type>)         \
    163   REGISTER_KERNEL_BUILDER(Name("ConcatV2")                   \
    164                               .Device(DEVICE_CPU)            \
    165                               .TypeConstraint<type>("T")     \
    166                               .TypeConstraint<int32>("Tidx") \
    167                               .HostMemory("axis"),           \
    168                           ConcatV2Op<CPUDevice, type>)
    169 
    170 TF_CALL_POD_STRING_TYPES(REGISTER_CONCAT);
    171 REGISTER_CONCAT(quint8);
    172 REGISTER_CONCAT(qint8);
    173 REGISTER_CONCAT(quint16);
    174 REGISTER_CONCAT(qint16);
    175 REGISTER_CONCAT(qint32);
    176 
    177 #undef REGISTER_CONCAT
    178 
    179 #if GOOGLE_CUDA
    180 
    181 #define REGISTER_GPU(type)                                   \
    182   REGISTER_KERNEL_BUILDER(Name("Concat")                     \
    183                               .Device(DEVICE_GPU)            \
    184                               .TypeConstraint<type>("T")     \
    185                               .HostMemory("concat_dim"),     \
    186                           ConcatOp<GPUDevice, type>)         \
    187   REGISTER_KERNEL_BUILDER(Name("ConcatV2")                   \
    188                               .Device(DEVICE_GPU)            \
    189                               .TypeConstraint<type>("T")     \
    190                               .TypeConstraint<int32>("Tidx") \
    191                               .HostMemory("axis"),           \
    192                           ConcatV2Op<GPUDevice, type>)
    193 
    194 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
    195 REGISTER_GPU(bfloat16);
    196 TF_CALL_complex64(REGISTER_GPU);
    197 TF_CALL_complex128(REGISTER_GPU);
    198 TF_CALL_int64(REGISTER_GPU);
    199 REGISTER_GPU(bool);
    200 #undef REGISTER_GPU
    201 
    202 // A special GPU kernel for int32.
    203 // TODO(b/25387198): Also enable int32 in device memory. This kernel
    204 // registration requires all int32 inputs and outputs to be in host memory.
    205 REGISTER_KERNEL_BUILDER(Name("Concat")
    206                             .Device(DEVICE_GPU)
    207                             .TypeConstraint<int32>("T")
    208                             .HostMemory("concat_dim")
    209                             .HostMemory("values")
    210                             .HostMemory("output"),
    211                         ConcatOp<CPUDevice, int32>);
    212 REGISTER_KERNEL_BUILDER(Name("ConcatV2")
    213                             .Device(DEVICE_GPU)
    214                             .TypeConstraint<int32>("T")
    215                             .TypeConstraint<int32>("Tidx")
    216                             .HostMemory("values")
    217                             .HostMemory("axis")
    218                             .HostMemory("output"),
    219                         ConcatV2Op<CPUDevice, int32>);
    220 
    221 #endif  // GOOGLE_CUDA
    222 
    223 #ifdef TENSORFLOW_USE_SYCL
    224 #define REGISTER_SYCL(type)                                  \
    225   REGISTER_KERNEL_BUILDER(Name("Concat")                     \
    226                               .Device(DEVICE_SYCL)           \
    227                               .TypeConstraint<type>("T")     \
    228                               .HostMemory("concat_dim"),     \
    229                           ConcatOp<SYCLDevice, type>)        \
    230   REGISTER_KERNEL_BUILDER(Name("ConcatV2")                   \
    231                               .Device(DEVICE_SYCL)           \
    232                               .TypeConstraint<type>("T")     \
    233                               .TypeConstraint<int32>("Tidx") \
    234                               .HostMemory("axis"),           \
    235                           ConcatV2Op<SYCLDevice, type>)
    236 
    237 TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL);
    238 
    239 REGISTER_KERNEL_BUILDER(Name("Concat")
    240                             .Device(DEVICE_SYCL)
    241                             .TypeConstraint<int32>("T")
    242                             .HostMemory("concat_dim")
    243                             .HostMemory("values")
    244                             .HostMemory("output"),
    245                         ConcatOp<CPUDevice, int32>);
    246 REGISTER_KERNEL_BUILDER(Name("ConcatV2")
    247                             .Device(DEVICE_SYCL)
    248                             .TypeConstraint<int32>("T")
    249                             .TypeConstraint<int32>("Tidx")
    250                             .HostMemory("values")
    251                             .HostMemory("axis")
    252                             .HostMemory("output"),
    253                         ConcatV2Op<CPUDevice, int32>);
    254 
    255 #undef REGISTER_SYCL
    256 #endif  // TENSORFLOW_USE_SYCL
    257 
    258 class ConcatOffsetOp : public OpKernel {
    259  public:
    260   explicit ConcatOffsetOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
    261 
    262   void Compute(OpKernelContext* ctx) override {
    263     const Tensor& concat_dim = ctx->input(0);
    264     OP_REQUIRES(
    265         ctx, IsLegacyScalar(concat_dim.shape()),
    266         errors::InvalidArgument(
    267             "Concat dim tensor should be a scalar integer, but got shape ",
    268             concat_dim.shape().DebugString()));
    269     for (int i = 1; i < ctx->num_inputs(); ++i) {
    270       const Tensor& inp = ctx->input(i);
    271       OP_REQUIRES(ctx, TensorShapeUtils::IsVector(inp.shape()),
    272                   errors::InvalidArgument("input ", i,
    273                                           " should be a vector, but got shape ",
    274                                           inp.shape().DebugString()));
    275     }
    276     // Suppose a Concat() op needs to Concatenate N tensors, each of
    277     // which has the same number of dimensions.  Their shapes match
    278     // except the concat dimension.
    279     //
    280     // E.g., say, we want to concatenate 3 tensors in the 2nd
    281     // dimension, and their shapes are:
    282     //
    283     //  [2, 2, 5, 7]
    284     //  [2, 3, 5, 7]
    285     //  [2, 4, 5, 7]
    286     //
    287     // Here, N=3, cdim=1, dims=4. The concatenated tensor has shape
    288     // [2,9,5,7]. We will compute the cumulative sum along the 2nd
    289     // dimension to figure out each input's offset in the concatenated
    290     // output:
    291     //  [0, 0, 0, 0]
    292     //  [0, 2, 0, 0]
    293     //  [0, 5, 0, 0]
    294     const int32 N = ctx->num_inputs() - 1;
    295     const Tensor& inp0 = ctx->input(1);
    296     auto inp0_vec = inp0.vec<int32>();
    297     const int64 cdim = internal::SubtleMustCopy(concat_dim.scalar<int32>()());
    298     const int64 dims = inp0.NumElements();
    299     int32 axis = cdim < 0 ? cdim + dims : cdim;
    300     OP_REQUIRES(ctx, FastBoundsCheck(axis, dims),
    301                 errors::InvalidArgument("Concat dim is out of range: ", cdim,
    302                                         " vs. ", dims));
    303     int32 offset = 0;
    304     for (int i = 0; i < N; ++i) {
    305       const Tensor& inp = ctx->input(1 + i);
    306       OP_REQUIRES(
    307           ctx, dims == inp.NumElements(),
    308           errors::InvalidArgument("input ", i, " should contain ", dims,
    309                                   " elements, but got ", inp.NumElements()));
    310       auto inp_vec = inp.vec<int32>();
    311       Tensor* out = nullptr;
    312       OP_REQUIRES_OK(ctx, ctx->allocate_output(i, {dims}, &out));
    313       auto out_vec = out->vec<int32>();
    314       for (int64 j = 0; j < dims; ++j) {
    315         if (j == axis) {
    316           out_vec(j) = offset;
    317           offset += inp_vec(j);
    318         } else {
    319           OP_REQUIRES(ctx, (inp0_vec(j) == inp_vec(j)),
    320                       errors::InvalidArgument(
    321                           "All dimensions except ", axis, " must match. Input ",
    322                           i, " has shape [", inp.SummarizeValue(10),
    323                           "] and doesn't match input 0 with shape [",
    324                           inp0.SummarizeValue(10), "]."));
    325           out_vec(j) = 0;
    326         }
    327       }
    328     }
    329   }
    330 
    331   bool IsExpensive() override { return false; }
    332 };
    333 
    334 REGISTER_KERNEL_BUILDER(Name("ConcatOffset").Device(DEVICE_CPU),
    335                         ConcatOffsetOp);
    336 
    337 REGISTER_KERNEL_BUILDER(Name("ConcatOffset")
    338                             .Device(DEVICE_GPU)
    339                             .HostMemory("concat_dim")
    340                             .HostMemory("shape")
    341                             .HostMemory("offset"),
    342                         ConcatOffsetOp);
    343 
    344 #ifdef TENSORFLOW_USE_SYCL
    345 REGISTER_KERNEL_BUILDER(Name("ConcatOffset")
    346                             .Device(DEVICE_SYCL)
    347                             .HostMemory("concat_dim")
    348                             .HostMemory("shape")
    349                             .HostMemory("offset"),
    350                         ConcatOffsetOp);
    351 #endif  // TENSORFLOW_USE_SYCL
    352 }  // namespace tensorflow
    353