// (code-browser navigation header, not part of the source) Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // See docs in ../ops/array_ops.cc.
     17 
     18 #include <limits>
     19 #include <vector>
     20 
     21 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     22 #include "tensorflow/core/framework/op_kernel.h"
     23 #include "tensorflow/core/framework/register_types.h"
     24 #include "tensorflow/core/framework/tensor.h"
     25 #include "tensorflow/core/framework/tensor_types.h"
     26 #include "tensorflow/core/framework/types.h"
     27 #include "tensorflow/core/kernels/concat_lib.h"
     28 #include "tensorflow/core/lib/core/status.h"
     29 #include "tensorflow/core/platform/types.h"
     30 
     31 namespace tensorflow {
     32 
// Shorthand aliases for the Eigen device types this kernel is templated on.
// The GPU and SYCL aliases exist only when the corresponding backend is
// compiled in.
typedef Eigen::ThreadPoolDevice CPUDevice;
#if GOOGLE_CUDA
typedef Eigen::GpuDevice GPUDevice;
#endif  // GOOGLE_CUDA
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL
     40 
     41 // --------------------------------------------------------------------------
// Implements the "Pack" op: stacks N input tensors, all of identical shape,
// into one output tensor whose rank is one greater than the inputs', with the
// new dimension of size N inserted at attribute `axis`.  The data movement is
// delegated to the device-specific concat kernels (ConcatCPU/GPU/SYCL), since
// pack is a special case of concat once the tensors are viewed as 2-D
// matrices.
template <typename Device, typename T>
class PackOp : public OpKernel {
 public:
  // A vector of 2-D read-only views over the flattened inputs.  The
  // unique_ptrs own only the view objects, not the underlying tensor buffers.
  typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
      ConstMatrixVector;

  // Reads the "axis" attribute.  It may be negative (counting from the end of
  // the expanded output rank); normalization happens in Compute().
  explicit PackOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("axis", &axis_));
  }

  void Compute(OpKernelContext* c) override {
    OpInputList values;
    OP_REQUIRES_OK(c, c->input_list("values", &values));
    const int num = values.size();

    // Verify that all input shapes match
    for (int i = 1; i < num; i++) {
      OP_REQUIRES(c, values[0].shape().IsSameSize(values[i].shape()),
                  errors::InvalidArgument(
                      "Shapes of all inputs must match: values[0].shape = ",
                      values[0].shape().DebugString(), " != values[", i,
                      "].shape = ", values[i].shape().DebugString()));
    }

    // The output has one more dimension than each input.  A negative axis is
    // interpreted relative to the expanded (output) rank.
    int expanded_num_dims = values[0].dims() + 1;
    int axis = axis_;
    if (axis < 0) axis += expanded_num_dims;

    // Report the *original* (possibly negative) attribute value in the error,
    // since that is what the user supplied.
    OP_REQUIRES(c, 0 <= axis && axis < expanded_num_dims,
                errors::InvalidArgument("axis = ", axis_, " not in [",
                                        -expanded_num_dims, ", ",
                                        expanded_num_dims, ")"));

    TensorShape output_shape(values[0].shape());
    output_shape.InsertDim(axis, num);

    // In the num = 1 case, just reshape the input.  NOTE(review): CopyFrom is
    // presumably a buffer-sharing reshape (no data copy) and cannot fail here
    // because output_shape has exactly the same element count as values[0]
    // (the inserted dimension has size 1) — confirm against Tensor::CopyFrom.
    if (num == 1) {
      Tensor output;
      CHECK(output.CopyFrom(values[0], output_shape));
      c->set_output(0, output);
      return;
    }

    // Allocate output
    Tensor* output;
    OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output));

    // Collapse the output shape to [before_dim, axis_dim * after_dim] so the
    // pack becomes a 2-D concat along the second dimension: before_dim is the
    // product of dims preceding `axis`, after_dim the product of dims
    // following it.
    int64 before_dim = 1;
    for (int i = 0; i < axis; ++i) {
      before_dim *= output_shape.dim_size(i);
    }

    int64 after_dim = 1;
    for (int i = axis + 1; i < output_shape.dims(); ++i) {
      after_dim *= output_shape.dim_size(i);
    }

    const int64 axis_dim = output_shape.dim_size(axis);

    // Nothing to do for an empty output (some input dimension is 0).
    const int64 output_size = output->NumElements();
    if (output_size > 0) {
      auto output_flat =
          output->shaped<T, 2>({before_dim, after_dim * axis_dim});

      // Except for shapes, pack is a special case of concat, so we reuse the
      // same computational kernels.
      ConstMatrixVector inputs_flat;
      inputs_flat.reserve(num);
      for (int i = 0; i < num; ++i) {
        inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
            values[i].shaped<T, 2>({before_dim, after_dim})));
      }
      // Device dispatch: compile-time device tag selects the concat kernel.
      // The GPU/SYCL branches return early; CPU is the fall-through default.
#if GOOGLE_CUDA
      if (std::is_same<Device, GPUDevice>::value) {
        ConcatGPU<T>(c, inputs_flat, output, &output_flat);
        return;
      }
#endif  // GOOGLE_CUDA
#ifdef TENSORFLOW_USE_SYCL
      if (std::is_same<Device, SYCLDevice>::value) {
        ConcatSYCL<T>(c->eigen_sycl_device(), inputs_flat, &output_flat);
        return;
      }
#endif  // TENSORFLOW_USE_SYCL
      ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
    }
  }

 private:
  int axis_;  // Raw "axis" attribute as supplied; may be negative.
};
    134 
// Registers the CPU kernel for every element type T that Pack supports.
#define REGISTER_PACK(type)                                      \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("Pack").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      PackOp<CPUDevice, type>)

TF_CALL_ALL_TYPES(REGISTER_PACK);
TF_CALL_QUANTIZED_TYPES(REGISTER_PACK);

#if defined(IS_MOBILE_PLATFORM) && !defined(SUPPORT_SELECTIVE_REGISTRATION)
// Primarily used for SavedModel support on mobile.
REGISTER_PACK(string);
#endif  // defined(IS_MOBILE_PLATFORM) &&
        // !defined(SUPPORT_SELECTIVE_REGISTRATION)

#undef REGISTER_PACK
    150 
#if GOOGLE_CUDA

// Registers the GPU kernel for the GPU-supported element types.
#define REGISTER_GPU(type)                                       \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("Pack").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      PackOp<GPUDevice, type>)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
TF_CALL_bfloat16(REGISTER_GPU);
TF_CALL_int64(REGISTER_GPU);
REGISTER_GPU(bool);
#undef REGISTER_GPU

// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
// Note the kernel instantiated is the CPU one, operating on host tensors.
REGISTER_KERNEL_BUILDER(Name("Pack")
                            .Device(DEVICE_GPU)
                            .HostMemory("values")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        PackOp<CPUDevice, int32>);

#endif  // GOOGLE_CUDA
    175 
#ifdef TENSORFLOW_USE_SYCL
// Registers the SYCL kernel for the SYCL-supported element types.
#define REGISTER_SYCL(type)                                       \
  REGISTER_KERNEL_BUILDER(                                        \
      Name("Pack").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
      PackOp<SYCLDevice, type>)

TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL);
// int32 on SYCL devices stays in host memory and runs the CPU kernel,
// mirroring the GPU int32 registration above.
REGISTER_KERNEL_BUILDER(Name("Pack")
                            .Device(DEVICE_SYCL)
                            .HostMemory("values")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        PackOp<CPUDevice, int32>);
#undef REGISTER_SYCL
#endif  // TENSORFLOW_USE_SYCL
    191 }  // namespace tensorflow
    192