/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "tensorflow/core/kernels/reverse_sequence_op.h"

#include <memory>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename Tlen>
void CheckErrors(OpKernelContext* context, int batch_dim, int seq_dim) {
  const Tensor& input = context->input(0);
  const Tensor& seq_lens = context->input(1);

  auto seq_lens_t = seq_lens.vec<Tlen>();

  std::vector<Tlen> seq_lens_vec(seq_lens_t.size());

  // Copy seq_lens back to the host for the per-element validity checks
  // below. On the CPU device this reduces to a plain memcpy.
  context->eigen_device<Device>().memcpyDeviceToHost(
      seq_lens_vec.data(), seq_lens_t.data(), sizeof(Tlen) * seq_lens_t.size());

  OP_REQUIRES(context, batch_dim != seq_dim,
              errors::InvalidArgument("batch_dim == seq_dim == ", seq_dim));
  OP_REQUIRES(context, seq_dim < input.dims(),
              errors::InvalidArgument("seq_dim must be < input.dims() (",
                                      seq_dim, " vs. ", input.dims(), ")"));
  OP_REQUIRES(context, batch_dim < input.dims(),
              errors::InvalidArgument("batch_dim must be < input.dims() (",
                                      batch_dim, " vs. ", input.dims(), ")"));
  OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(batch_dim),
              errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim,
                                      ") (", seq_lens.NumElements(), " vs. ",
                                      input.dim_size(batch_dim), ")"));

  for (size_t d = 0; d < seq_lens_vec.size(); ++d) {
    OP_REQUIRES(context, seq_lens_vec[d] >= 0,
                errors::InvalidArgument("seq_lens(", d, ") < 0"));
    OP_REQUIRES(context, seq_lens_vec[d] <= input.dim_size(seq_dim),
                errors::InvalidArgument("seq_lens(", d, ") > input.dims(",
                                        seq_dim, ")"));
  }
}

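// Note: the GPU variant below omits the per-element seq_lens bounds checks
// performed in the generic CheckErrors above, presumably because seq_lens
// resides in device memory there and validating it on the host would force
// a device-to-host copy and synchronization.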
void CheckErrorsGPU(OpKernelContext* context, int batch_dim, int seq_dim) {
  const Tensor& input = context->input(0);
  const Tensor& seq_lens = context->input(1);

  OP_REQUIRES(context, batch_dim != seq_dim,
              errors::InvalidArgument("batch_dim == seq_dim == ", seq_dim));
  OP_REQUIRES(context, seq_dim < input.dims(),
              errors::InvalidArgument("seq_dim must be < input.dims() (",
                                      seq_dim, " vs. ", input.dims(), ")"));
  OP_REQUIRES(context, batch_dim < input.dims(),
              errors::InvalidArgument("batch_dim must be < input.dims() (",
                                      batch_dim, " vs. ", input.dims(), ")"));

  OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(batch_dim),
              errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim,
                                      ") (", seq_lens.NumElements(), " vs. ",
                                      input.dim_size(batch_dim), ")"));
}

template <>
void CheckErrors<GPUDevice, int32>(OpKernelContext* context, int batch_dim,
                                   int seq_dim) {
  CheckErrorsGPU(context, batch_dim, seq_dim);
}

template <>
void CheckErrors<GPUDevice, int64>(OpKernelContext* context, int batch_dim,
                                   int seq_dim) {
  CheckErrorsGPU(context, batch_dim, seq_dim);
}

template <typename Device, typename T, typename Tlen>
class ReverseSequenceOp : public OpKernel {
 public:
  explicit ReverseSequenceOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("batch_dim", &batch_dim_));
    OP_REQUIRES_OK(context, context->GetAttr("seq_dim", &seq_dim_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& seq_lens = context->input(1);

    // Preliminary validation of sizes.
    OP_REQUIRES(context, TensorShapeUtils::IsVector(seq_lens.shape()),
                errors::InvalidArgument("seq_lens input must be 1-dim, not ",
                                        seq_lens.dims()));

    auto seq_lens_t = seq_lens.vec<Tlen>();

    CheckErrors<Device, Tlen>(context, batch_dim_, seq_dim_);

    const int input_dims = input.dims();

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &output));

#define HANDLE_DIM(NDIM)                                                      \
  case NDIM:                                                                  \
    functor::ReverseSequence<Device, T, Tlen, NDIM>::Compute(                 \
        context->eigen_device<Device>(), input.tensor<T, NDIM>(), batch_dim_, \
        seq_dim_, seq_lens_t, output->tensor<T, NDIM>());                     \
    break;

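    // For illustration, HANDLE_DIM(2) expands to roughly:
    //
    //   case 2:
    //     functor::ReverseSequence<Device, T, Tlen, 2>::Compute(
    //         context->eigen_device<Device>(), input.tensor<T, 2>(),
    //         batch_dim_, seq_dim_, seq_lens_t, output->tensor<T, 2>());
    //     break;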
    switch (input_dims) {
      HANDLE_DIM(2);
      HANDLE_DIM(3);
      HANDLE_DIM(4);
      HANDLE_DIM(5);

      default:
        OP_REQUIRES(context, false,
                    errors::InvalidArgument(
                        "ReverseSequenceOp : Unhandled input dimensions: ",
                        input_dims));
    }
#undef HANDLE_DIM
  }

 private:
  int32 batch_dim_;
  int32 seq_dim_;

  TF_DISALLOW_COPY_AND_ASSIGN(ReverseSequenceOp);
};
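
// Example: with batch_dim = 0 and seq_dim = 1, ReverseSequence reverses the
// first seq_lens[i] elements of row i and leaves the remainder untouched:
//
//   input  = [[1, 2, 3, 4],     seq_lens = [2, 3]
//             [5, 6, 7, 8]]
//   output = [[2, 1, 3, 4],
//             [7, 6, 5, 8]]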

#define REGISTER_REVERSE_SEQUENCE(type, len_type)                \
  REGISTER_KERNEL_BUILDER(Name("ReverseSequence")                \
                              .Device(DEVICE_CPU)                \
                              .TypeConstraint<type>("T")         \
                              .TypeConstraint<len_type>("Tlen"), \
                          ReverseSequenceOp<CPUDevice, type, len_type>);

#define REGISTER_REVERSE_SEQUENCE_LEN(type) \
  REGISTER_REVERSE_SEQUENCE(type, int32);   \
  REGISTER_REVERSE_SEQUENCE(type, int64);

TF_CALL_NUMBER_TYPES(REGISTER_REVERSE_SEQUENCE_LEN);
TF_CALL_bool(REGISTER_REVERSE_SEQUENCE_LEN);
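
// For illustration: TF_CALL_NUMBER_TYPES invokes the given macro once per
// standard numeric type, so the float case expands to roughly
//
//   REGISTER_REVERSE_SEQUENCE(float, int32);
//   REGISTER_REVERSE_SEQUENCE(float, int64);
//
// registering one DEVICE_CPU kernel per (T, Tlen) combination.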

#if GOOGLE_CUDA

// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T, Tlen, Dims)                                \
  template <>                                                          \
  void ReverseSequence<GPUDevice, T, Tlen, Dims>::Compute(             \
      const GPUDevice& d, typename TTypes<T, Dims>::ConstTensor input, \
      int32 batch_dim, int32 seq_dim,                                  \
      typename TTypes<Tlen>::ConstVec seq_lens,                        \
      typename TTypes<T, Dims>::Tensor output);                        \
  extern template struct ReverseSequence<GPUDevice, T, Tlen, Dims>;

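// Note: the "extern template struct" declaration above suppresses implicit
// instantiation in this translation unit; the specializations are expected
// to be explicitly instantiated in the separately compiled GPU (.cu.cc)
// object for this kernel.
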
#define DECLARE_GPU_SPEC_LEN(T, Dims) \
  DECLARE_GPU_SPEC(T, int32, Dims);   \
  DECLARE_GPU_SPEC(T, int64, Dims);

#define DECLARE_GPU_SPECS(T)  \
  DECLARE_GPU_SPEC_LEN(T, 2); \
  DECLARE_GPU_SPEC_LEN(T, 3); \
  DECLARE_GPU_SPEC_LEN(T, 4); \
  DECLARE_GPU_SPEC_LEN(T, 5);

TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
TF_CALL_bool(DECLARE_GPU_SPECS);

}  // namespace functor

// Registration of the GPU implementations.
#define REGISTER_REVERSE_SEQUENCE_GPU(type, len_type)            \
  REGISTER_KERNEL_BUILDER(Name("ReverseSequence")                \
                              .Device(DEVICE_GPU)                \
                              .TypeConstraint<type>("T")         \
                              .TypeConstraint<len_type>("Tlen"), \
                          ReverseSequenceOp<GPUDevice, type, len_type>);

#define REGISTER_REVERSE_SEQUENCE_GPU_LEN(type) \
  REGISTER_REVERSE_SEQUENCE_GPU(type, int32);   \
  REGISTER_REVERSE_SEQUENCE_GPU(type, int64);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_REVERSE_SEQUENCE_GPU_LEN);
TF_CALL_bool(REGISTER_REVERSE_SEQUENCE_GPU_LEN);

#undef REGISTER_REVERSE_SEQUENCE_GPU

#endif  // GOOGLE_CUDA

}  // namespace tensorflow