Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // See docs in ../ops/array_ops.cc
     17 #define EIGEN_USE_THREADS
     18 
     19 #include "tensorflow/core/kernels/reverse_op.h"
     20 #include <memory>
     21 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     22 #include "tensorflow/core/framework/op_kernel.h"
     23 #include "tensorflow/core/framework/register_types.h"
     24 #include "tensorflow/core/framework/tensor.h"
     25 #include "tensorflow/core/framework/tensor_shape.h"
     26 #include "tensorflow/core/framework/type_traits.h"
     27 #include "tensorflow/core/framework/types.h"
     28 #include "tensorflow/core/kernels/bounds_check.h"
     29 #include "tensorflow/core/lib/core/status.h"
     30 #include "tensorflow/core/platform/logging.h"
     31 #include "tensorflow/core/util/work_sharder.h"
     32 
     33 namespace tensorflow {
     34 
     35 typedef Eigen::ThreadPoolDevice CPUDevice;
     36 typedef Eigen::GpuDevice GPUDevice;
     37 #ifdef TENSORFLOW_USE_SYCL
     38 typedef Eigen::SyclDevice SYCLDevice;
     39 #endif  // TENSORFLOW_USE_SYCL
     40 
     41 namespace {
     42 
     43 // Reverse rows (middle dimension) of a three dimensional tensor.
     44 // NUM_CHANNELS can be <= 0 to compute it dynamically from <input>
     45 // Otherwise, it must equal input.dim_size(2) and is used as a compile-time
     46 // constant.
     47 template <typename T, int NUM_CHANNELS>
     48 void ReverseRows(OpKernelContext* context, const Tensor& input,
     49                  Tensor* result) {
     50   auto work = [&input, result](int64 start, int64 end) {
     51     const int64 inner_size =
     52         NUM_CHANNELS > 0 ? NUM_CHANNELS : input.dim_size(2);
     53     const int64 middle_size = input.dim_size(1);
     54     const int64 row_size = inner_size * middle_size;
     55     DCHECK_EQ(input.dim_size(2), inner_size);
     56 
     57     const T* in_ptr = input.bit_casted_tensor<T, 3>().data();
     58     T* out_ptr = result->bit_casted_tensor<T, 3>().data();
     59 
     60     in_ptr += start * row_size;
     61     out_ptr += start * row_size;
     62 
     63     for (int outer_dim = start; outer_dim < end; ++outer_dim) {
     64       out_ptr += row_size;
     65       int remaining = middle_size;
     66       while (remaining > 0) {
     67         out_ptr -= inner_size;
     68         memcpy(out_ptr, in_ptr, inner_size * sizeof(T));
     69         in_ptr += inner_size;
     70         --remaining;
     71       }
     72 
     73       out_ptr += row_size;
     74     }
     75   };
     76 
     77   // Shard across outer dimension.
     78   const int64 N = input.dim_size(0);
     79   const int64 cost_per_unit = input.NumElements() / N;
     80   auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
     81   Shard(worker_threads->num_threads, worker_threads->workers, N, cost_per_unit,
     82         std::move(work));
     83 }
     84 
     85 template <typename T>
     86 struct data_type_can_memcpy {
     87   static constexpr bool value =
     88       std::is_same<T, uint8>::value || std::is_same<T, int8>::value ||
     89       std::is_same<T, bool>::value || std::is_same<T, uint16>::value ||
     90       std::is_same<T, int16>::value || std::is_same<T, Eigen::half>::value ||
     91       std::is_same<T, int32>::value || std::is_same<T, float>::value ||
     92       std::is_same<T, int64>::value || std::is_same<T, double>::value ||
     93       std::is_same<T, complex64>::value || std::is_same<T, complex128>::value;
     94 };
     95 
     96 template <typename T, int NUM_CHANNELS>
     97 typename std::enable_if<data_type_can_memcpy<T>::value>::type
     98 DoHandleReverseCase(OpKernelContext* context, const Tensor& input,
     99                     Tensor* result) {
    100   if (sizeof(T) == 1) {
    101     static_assert(sizeof(uint8) == 1, "uint8 must be 1 byte.");
    102     ReverseRows<uint8, NUM_CHANNELS>(context, input, result);
    103   } else if (sizeof(T) == 2) {
    104     static_assert(sizeof(uint16) == 2, "uint16 must be 2 bytes");
    105     ReverseRows<uint16, NUM_CHANNELS>(context, input, result);
    106   } else if (sizeof(T) == 4) {
    107     static_assert(sizeof(uint32) == 4, "uint32 must be 4 bytes");
    108     ReverseRows<uint32, NUM_CHANNELS>(context, input, result);
    109   } else if (sizeof(T) == 8) {
    110     static_assert(sizeof(uint64) == 8, "uint64 must be 8 bytes");
    111     ReverseRows<uint64, NUM_CHANNELS>(context, input, result);
    112   } else if (sizeof(T) == 16) {
    113     static_assert(sizeof(complex128) == 16, "complex128 must be 16 bytes");
    114     ReverseRows<complex128, NUM_CHANNELS>(context, input, result);
    115   } else {
    116     context->CtxFailure(
    117         errors::InvalidArgument("%s has unexpected size of %d bytes",
    118                                 DataTypeString(input.dtype()), sizeof(T)));
    119   }
    120 }
    121 
    122 template <typename T, int NUM_CHANNELS>
    123 typename std::enable_if<!data_type_can_memcpy<T>::value>::type
    124 DoHandleReverseCase(OpKernelContext* context, const Tensor& input,
    125                     Tensor* result) {}
    126 
    127 }  // namespace
    128 
    129 template <typename Device, typename T, int NDIMS>
    130 void HandleReverseCase(OpKernelContext* context,
    131                        typename TTypes<bool, 1>::ConstTensor dims,
    132                        Tensor* result) {
    133   const Tensor& input = context->input(0);
    134 
    135   // Use optimized reverse if possible.
    136   if (NDIMS == 3 && std::is_same<Device, CPUDevice>::value &&
    137       data_type_can_memcpy<T>::value && (!dims(0) && dims(1) && !dims(2))) {
    138     if (input.dim_size(2) == 3) {
    139       DoHandleReverseCase<T, 3>(context, input, result);
    140     } else {
    141       DoHandleReverseCase<T, -1>(context, input, result);
    142     }
    143     return;
    144   }
    145   typename Eigen::array<bool, NDIMS> axes_di;
    146   for (int i = 0; i < NDIMS; i++) {
    147     axes_di[i] = dims(i);
    148   }
    149   functor::Reverse<Device, T, NDIMS>()(context->eigen_device<Device>(),
    150                                        input.tensor<T, NDIMS>(), axes_di,
    151                                        result->tensor<T, NDIMS>());
    152 }
    153 
    154 template <typename Device, typename T>
    155 class ReverseOp : public OpKernel {
    156  public:
    157   explicit ReverseOp(OpKernelConstruction* context) : OpKernel(context) {}
    158 
    159   void Compute(OpKernelContext* context) override {
    160     const Tensor& input = context->input(0);
    161     const Tensor& dims = context->input(1);
    162 
    163     if (TensorShapeUtils::IsScalar(input.shape())) {
    164       context->set_output(0, input);
    165     } else {
    166       const int input_dims = input.dims();
    167       OP_REQUIRES(context, TensorShapeUtils::IsVector(dims.shape()),
    168                   errors::InvalidArgument("'dims' must be 1-dimension, not ",
    169                                           dims.dims()));
    170 
    171       OP_REQUIRES(
    172           context, input_dims == dims.dim_size(0),
    173           errors::InvalidArgument(
    174               "'dims' must have the same number of values as 'input' has "
    175               "dimensions. 'input' has ",
    176               input_dims, "'dims' has ", dims.dim_size(0), " values"));
    177       OP_REQUIRES(context, input_dims <= 8,
    178                   errors::Unimplemented(
    179                       "reverse is not implemented for tensors of rank > 8."));
    180 
    181       Tensor* output = nullptr;
    182       OP_REQUIRES_OK(context,
    183                      context->allocate_output(0, input.shape(), &output));
    184 
    185 #define HANDLE_REVERSE(NDIMS)                                               \
    186   case NDIMS:                                                               \
    187     HandleReverseCase<Device, T, NDIMS>(context, dims.vec<bool>(), output); \
    188     return;
    189 
    190       switch (input_dims) {
    191         HANDLE_REVERSE(0);
    192         HANDLE_REVERSE(1);
    193         HANDLE_REVERSE(2);
    194         HANDLE_REVERSE(3);
    195         HANDLE_REVERSE(4);
    196         HANDLE_REVERSE(5);
    197         HANDLE_REVERSE(6);
    198         HANDLE_REVERSE(7);
    199         HANDLE_REVERSE(8);
    200       }
    201 #undef HANDLE_REVERSE
    202     }
    203   }
    204 };
    205 
    206 template <typename Device, typename T, int NDIMS>
    207 void HandleReverseV2Case(OpKernelContext* context,
    208                          const gtl::ArraySlice<bool>& axes, Tensor* result) {
    209   const Tensor& input = context->input(0);
    210 
    211   // Use optimized reverse if possible.
    212   if (NDIMS == 3 && std::is_same<Device, CPUDevice>::value &&
    213       data_type_can_memcpy<T>::value && (!axes[0] && axes[1] && !axes[2])) {
    214     if (input.dim_size(2) == 3) {
    215       DoHandleReverseCase<T, 3>(context, input, result);
    216     } else {
    217       DoHandleReverseCase<T, -1>(context, input, result);
    218     }
    219     return;
    220   }
    221 
    222   typename Eigen::array<bool, NDIMS> axes_di;
    223   for (int i = 0; i < NDIMS; i++) {
    224     axes_di[i] = axes[i];
    225   }
    226   functor::Reverse<Device, T, NDIMS>()(context->eigen_device<Device>(),
    227                                        input.tensor<T, NDIMS>(), axes_di,
    228                                        result->tensor<T, NDIMS>());
    229 }
    230 
    231 template <typename Device, typename T, typename Tidx>
    232 class ReverseV2Op : public OpKernel {
    233  public:
    234   explicit ReverseV2Op(OpKernelConstruction* context) : OpKernel(context) {}
    235 
    236   void Compute(OpKernelContext* context) override {
    237     const Tensor& input = context->input(0);
    238     const Tensor& sparse_dims = context->input(1);
    239 
    240     if (TensorShapeUtils::IsScalar(input.shape())) {
    241       context->set_output(0, input);
    242     } else {
    243       const int input_dims = input.dims();
    244       const TensorShape& sparse_dims_shape = sparse_dims.shape();
    245       const auto& axes_sparse_flat = sparse_dims.flat<Tidx>();
    246 
    247       OP_REQUIRES(context, TensorShapeUtils::IsVector(sparse_dims_shape),
    248                   errors::InvalidArgument("'dims' must be 1-dimension, not ",
    249                                           sparse_dims.dims()));
    250       gtl::InlinedVector<bool, 8> axes_dense(input_dims, false);
    251       for (int dummy = 0; dummy < axes_sparse_flat.size(); dummy++) {
    252         Tidx axis = internal::SubtleMustCopy<Tidx>(axes_sparse_flat(dummy));
    253         Tidx canonical_axis = axis < 0 ? input_dims + axis : axis;
    254         OP_REQUIRES(context, canonical_axis >= 0 && canonical_axis < input_dims,
    255                     errors::InvalidArgument("'axis'[", dummy, "] = ", axis,
    256                                             " is out of valid range [", 0, ", ",
    257                                             input_dims - 1));
    258         OP_REQUIRES(context, !axes_dense[canonical_axis],
    259                     errors::InvalidArgument("axis ", canonical_axis,
    260                                             " specified more than once."));
    261         axes_dense[canonical_axis] = true;
    262       }
    263 
    264       OP_REQUIRES(context, input_dims <= 8,
    265                   errors::Unimplemented(
    266                       "reverse is not implemented for tensors of rank > 8."));
    267 
    268       Tensor* output = nullptr;
    269       OP_REQUIRES_OK(context,
    270                      context->allocate_output(0, input.shape(), &output));
    271 
    272       // TODO(cwhipkey): we can do dimension folding to reduce, e.g., a reverse
    273       // of a single dimension to the dims=3 or dims=2 case, regardless of the
    274       // number of dimensions in the tensor. This would let some ops use faster
    275       // lower-dimension code (and use optimized versions).
    276 
    277 #define HANDLE_REVERSE(NDIMS)                                           \
    278   case NDIMS:                                                           \
    279     HandleReverseV2Case<Device, T, NDIMS>(context, axes_dense, output); \
    280     return;
    281 
    282       switch (input_dims) {
    283         HANDLE_REVERSE(0);
    284         HANDLE_REVERSE(1);
    285         HANDLE_REVERSE(2);
    286         HANDLE_REVERSE(3);
    287         HANDLE_REVERSE(4);
    288         HANDLE_REVERSE(5);
    289         HANDLE_REVERSE(6);
    290         HANDLE_REVERSE(7);
    291         HANDLE_REVERSE(8);
    292       }
    293 #undef HANDLE_REVERSE
    294     }
    295   }
    296 };
    297 
    298 #define REGISTER_KERNELS(T)                                  \
    299   REGISTER_KERNEL_BUILDER(Name("Reverse")                    \
    300                               .Device(DEVICE_CPU)            \
    301                               .TypeConstraint<T>("T")        \
    302                               .HostMemory("dims"),           \
    303                           ReverseOp<CPUDevice, T>)           \
    304   REGISTER_KERNEL_BUILDER(Name("ReverseV2")                  \
    305                               .Device(DEVICE_CPU)            \
    306                               .TypeConstraint<T>("T")        \
    307                               .TypeConstraint<int32>("Tidx") \
    308                               .HostMemory("axis"),           \
    309                           ReverseV2Op<CPUDevice, T, int32>)  \
    310   REGISTER_KERNEL_BUILDER(Name("ReverseV2")                  \
    311                               .Device(DEVICE_CPU)            \
    312                               .TypeConstraint<T>("T")        \
    313                               .TypeConstraint<int64>("Tidx") \
    314                               .HostMemory("axis"),           \
    315                           ReverseV2Op<CPUDevice, T, int64>)
    316 TF_CALL_POD_TYPES(REGISTER_KERNELS);
    317 TF_CALL_string(REGISTER_KERNELS);
    318 #undef REGISTER_KERNELS
    319 
    320 #if GOOGLE_CUDA
    321 
// Forward declarations of the function specializations for GPU. They prevent
// the GPU versions from being built here; those are built when compiling
// _gpu.cu.cc.
namespace functor {
// For one element type T and rank DIM, this declares (without defining) the
// explicit specialization of Reverse<GPUDevice, T, DIM>::operator(). It also
// declares the class instantiation `extern`, so it is not instantiated in this
// translation unit.
#define DECLARE_GPU_SPEC_DIM(T, DIM)                                  \
  template <>                                                         \
  void Reverse<GPUDevice, T, DIM>::operator()(                        \
      const GPUDevice& d, typename TTypes<T, DIM>::ConstTensor input, \
      const Eigen::array<bool, DIM>& reverse_dims,                    \
      typename TTypes<T, DIM>::Tensor output);                        \
  extern template struct Reverse<GPUDevice, T, DIM>;
// Applies DECLARE_GPU_SPEC_DIM to every rank from 0 to 8, matching the ranks
// the kernels dispatch on.
#define DECLARE_GPU_SPEC(T)  \
  DECLARE_GPU_SPEC_DIM(T, 0) \
  DECLARE_GPU_SPEC_DIM(T, 1) \
  DECLARE_GPU_SPEC_DIM(T, 2) \
  DECLARE_GPU_SPEC_DIM(T, 3) \
  DECLARE_GPU_SPEC_DIM(T, 4) \
  DECLARE_GPU_SPEC_DIM(T, 5) \
  DECLARE_GPU_SPEC_DIM(T, 6) \
  DECLARE_GPU_SPEC_DIM(T, 7) \
  DECLARE_GPU_SPEC_DIM(T, 8)

// Element types with GPU specializations. bool is declared here even though
// its GPU kernel registration is currently commented out in this file.
TF_CALL_uint8(DECLARE_GPU_SPEC);
TF_CALL_int8(DECLARE_GPU_SPEC);
TF_CALL_bool(DECLARE_GPU_SPEC);
TF_CALL_half(DECLARE_GPU_SPEC);
TF_CALL_float(DECLARE_GPU_SPEC);
TF_CALL_double(DECLARE_GPU_SPEC);
TF_CALL_complex64(DECLARE_GPU_SPEC);
TF_CALL_complex128(DECLARE_GPU_SPEC);
#undef DECLARE_GPU_SPEC
#undef DECLARE_GPU_SPEC_DIM
}  // namespace functor
    354 
    355 // Registration of the GPU implementations.
    356 #define REGISTER_GPU_KERNELS(T)                              \
    357   REGISTER_KERNEL_BUILDER(Name("Reverse")                    \
    358                               .Device(DEVICE_GPU)            \
    359                               .TypeConstraint<T>("T")        \
    360                               .HostMemory("dims"),           \
    361                           ReverseOp<GPUDevice, T>)           \
    362   REGISTER_KERNEL_BUILDER(Name("ReverseV2")                  \
    363                               .Device(DEVICE_GPU)            \
    364                               .TypeConstraint<T>("T")        \
    365                               .TypeConstraint<int32>("Tidx") \
    366                               .HostMemory("axis"),           \
    367                           ReverseV2Op<GPUDevice, T, int32>)  \
    368   REGISTER_KERNEL_BUILDER(Name("ReverseV2")                  \
    369                               .Device(DEVICE_GPU)            \
    370                               .TypeConstraint<T>("T")        \
    371                               .TypeConstraint<int64>("Tidx") \
    372                               .HostMemory("axis"),           \
    373                           ReverseV2Op<GPUDevice, T, int64>)
    374 TF_CALL_uint8(REGISTER_GPU_KERNELS);
    375 TF_CALL_int8(REGISTER_GPU_KERNELS);
    376 // TODO decide whether we want to enable the bool kernel.
    377 // TF_CALL_bool(REGISTER_GPU_KERNELS);
    378 TF_CALL_half(REGISTER_GPU_KERNELS);
    379 TF_CALL_float(REGISTER_GPU_KERNELS);
    380 TF_CALL_double(REGISTER_GPU_KERNELS);
    381 TF_CALL_complex64(REGISTER_GPU_KERNELS);
    382 TF_CALL_complex128(REGISTER_GPU_KERNELS);
    383 #undef REGISTER_GPU_KERNEL
    384 
    385 // A special GPU kernel for int32.
    386 // TODO(b/25387198): Also enable int32 in device memory. This kernel
    387 // registration requires all int32 inputs and outputs to be in host memory.
    388 REGISTER_KERNEL_BUILDER(Name("Reverse")
    389                             .Device(DEVICE_GPU)
    390                             .TypeConstraint<int32>("T")
    391                             .HostMemory("tensor")
    392                             .HostMemory("dims")
    393                             .HostMemory("output"),
    394                         ReverseOp<CPUDevice, int32>);
    395 REGISTER_KERNEL_BUILDER(Name("ReverseV2")
    396                             .Device(DEVICE_GPU)
    397                             .TypeConstraint<int32>("T")
    398                             .TypeConstraint<int32>("Tidx")
    399                             .HostMemory("tensor")
    400                             .HostMemory("axis")
    401                             .HostMemory("output"),
    402                         ReverseV2Op<CPUDevice, int32, int32>);
    403 REGISTER_KERNEL_BUILDER(Name("ReverseV2")
    404                             .Device(DEVICE_GPU)
    405                             .TypeConstraint<int32>("T")
    406                             .TypeConstraint<int64>("Tidx")
    407                             .HostMemory("tensor")
    408                             .HostMemory("axis")
    409                             .HostMemory("output"),
    410                         ReverseV2Op<CPUDevice, int32, int64>);
    411 #endif  // GOOGLE_CUDA
    412 
    413 #ifdef TENSORFLOW_USE_SYCL
    414 #define REGISTER_SYCL_KERNELS(T)                             \
    415   REGISTER_KERNEL_BUILDER(Name("Reverse")                    \
    416                               .Device(DEVICE_SYCL)           \
    417                               .TypeConstraint<T>("T")        \
    418                               .HostMemory("dims"),           \
    419                           ReverseOp<SYCLDevice, T>)          \
    420   REGISTER_KERNEL_BUILDER(Name("ReverseV2")                  \
    421                               .Device(DEVICE_SYCL)           \
    422                               .TypeConstraint<T>("T")        \
    423                               .TypeConstraint<int32>("Tidx") \
    424                               .HostMemory("axis"),           \
    425                           ReverseV2Op<SYCLDevice, T, int32>) \
    426   REGISTER_KERNEL_BUILDER(Name("ReverseV2")                  \
    427                               .Device(DEVICE_SYCL)           \
    428                               .TypeConstraint<T>("T")        \
    429                               .TypeConstraint<int64>("Tidx") \
    430                               .HostMemory("axis"),           \
    431                           ReverseV2Op<SYCLDevice, T, int64>)
    432 TF_CALL_uint8(REGISTER_SYCL_KERNELS);
    433 TF_CALL_int8(REGISTER_SYCL_KERNELS);
    434 TF_CALL_float(REGISTER_SYCL_KERNELS);
    435 TF_CALL_double(REGISTER_SYCL_KERNELS);
    436 
    437 REGISTER_KERNEL_BUILDER(Name("Reverse")
    438                             .Device(DEVICE_SYCL)
    439                             .TypeConstraint<int32>("T")
    440                             .HostMemory("tensor")
    441                             .HostMemory("dims")
    442                             .HostMemory("output"),
    443                         ReverseOp<CPUDevice, int32>);
    444 REGISTER_KERNEL_BUILDER(Name("ReverseV2")
    445                             .Device(DEVICE_SYCL)
    446                             .TypeConstraint<int32>("T")
    447                             .TypeConstraint<int32>("Tidx")
    448                             .HostMemory("tensor")
    449                             .HostMemory("axis")
    450                             .HostMemory("output"),
    451                         ReverseV2Op<CPUDevice, int32, int32>);
    452 REGISTER_KERNEL_BUILDER(Name("ReverseV2")
    453                             .Device(DEVICE_SYCL)
    454                             .TypeConstraint<int32>("T")
    455                             .TypeConstraint<int64>("Tidx")
    456                             .HostMemory("tensor")
    457                             .HostMemory("axis")
    458                             .HostMemory("output"),
    459                         ReverseV2Op<CPUDevice, int32, int64>);
    460 #endif  // TENSORFLOW_USE_SYCL
    461 }  // namespace tensorflow
    462