/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_KERNELS_STRIDED_SLICE_OP_IMPL_H_
#define TENSORFLOW_KERNELS_STRIDED_SLICE_OP_IMPL_H_

// Functor definition for StridedSliceOp, must be compilable by nvcc.

#include "tensorflow/core/kernels/slice_op.h"
#include "tensorflow/core/kernels/strided_slice_op.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/register_types_traits.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/mem.h"

namespace tensorflow {

template <typename Device, typename T, int NDIM>
void HandleStridedSliceCase(OpKernelContext* context,
                            const gtl::ArraySlice<int64>& begin,
                            const gtl::ArraySlice<int64>& end,
                            const gtl::ArraySlice<int64>& strides,
                            const TensorShape& processing_shape,
                            bool is_simple_slice, Tensor* result);

template <typename Device, typename T, int NDIM>
void HandleStridedSliceGradCase(OpKernelContext* context,
                                const gtl::ArraySlice<int64>& begin,
                                const gtl::ArraySlice<int64>& end,
                                const gtl::ArraySlice<int64>& strides,
                                const TensorShape& processing_shape,
                                bool is_simple_slice, Tensor* result);

template <typename Device, typename T, int NDIM>
class HandleStridedSliceAssignCase {
 public:
  void operator()(OpKernelContext* context, const gtl::ArraySlice<int64>& begin,
                  const gtl::ArraySlice<int64>& end,
                  const gtl::ArraySlice<int64>& strides,
                  const TensorShape& processing_shape, bool is_simple_slice,
                  Tensor* result);
};
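
// A minimal sketch, assuming a caller such as strided_slice_op.cc that only
// knows the rank at runtime, of how these fixed-NDIM handlers might be
// reached (illustrative only, not the actual dispatch code):
//
//   switch (processing_shape.dims()) {
//     case 1:
//       HandleStridedSliceCase<Device, T, 1>(context, begin, end, strides,
//                                            processing_shape,
//                                            is_simple_slice, result);
//       break;
//     case 2:
//       HandleStridedSliceCase<Device, T, 2>(context, begin, end, strides,
//                                            processing_shape,
//                                            is_simple_slice, result);
//       break;
//     // ... and so on up to the maximum supported rank.
//   }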
}  // namespace tensorflow

// The actual implementation. This is designed so multiple
// translation units can include this file in the form
//
// #define STRIDED_SLICE_INSTANTIATE_DIM 1
// #include <thisfile>
// #undef STRIDED_SLICE_INSTANTIATE_DIM
//
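// For example, a hypothetical per-dimension instantiation translation unit
// (assuming this header lives at tensorflow/core/kernels/strided_slice_op_impl.h)
// would contain only:
//
//   #define STRIDED_SLICE_INSTANTIATE_DIM 3
//   #include "tensorflow/core/kernels/strided_slice_op_impl.h"
//   #undef STRIDED_SLICE_INSTANTIATE_DIM
//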
#ifdef STRIDED_SLICE_INSTANTIATE_DIM

namespace tensorflow {

template <typename Device, typename T, int NDIM>
void HandleStridedSliceCase(OpKernelContext* context,
                            const gtl::ArraySlice<int64>& begin,
                            const gtl::ArraySlice<int64>& end,
                            const gtl::ArraySlice<int64>& strides,
                            const TensorShape& processing_shape,
                            bool is_simple_slice, Tensor* result) {
  typedef typename proxy_type<Device, T>::type Proxy;

  gtl::InlinedVector<int64, 4> processing_dims = processing_shape.dim_sizes();
  if (is_simple_slice) {
    Eigen::DSizes<Eigen::DenseIndex, NDIM> begin_di;
    Eigen::DSizes<Eigen::DenseIndex, NDIM> sizes_di;
    for (int i = 0; i < NDIM; ++i) {
      begin_di[i] = begin[i];
      sizes_di[i] = end[i] - begin[i];
    }
    functor::Slice<Device, Proxy, NDIM>()(
        context->eigen_device<Device>(),
        result->bit_casted_shaped<Proxy, NDIM>(processing_dims),
        context->input(0).bit_casted_tensor<Proxy, NDIM>(), begin_di, sizes_di);
  } else {
    Eigen::DSizes<Eigen::DenseIndex, NDIM> begin_di;
    Eigen::DSizes<Eigen::DenseIndex, NDIM> end_di;
    Eigen::DSizes<Eigen::DenseIndex, NDIM> strides_di;
    for (int i = 0; i < NDIM; ++i) {
      begin_di[i] = begin[i];
      end_di[i] = end[i];
      strides_di[i] = strides[i];
    }
    functor::StridedSlice<Device, Proxy, NDIM>()(
        context->eigen_device<Device>(),
        result->bit_casted_shaped<Proxy, NDIM>(processing_dims),
        context->input(0).bit_casted_tensor<Proxy, NDIM>(), begin_di, end_di,
        strides_di);
  }
}
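
// Worked example (illustrative): for a 1-D input of length 10 with
// begin = {2}, end = {8}, strides = {1}, the slice is "simple" (unit
// strides), so the functor::Slice path above copies
// sizes = end - begin = {6} elements starting at index 2 (indices 2..7).
// With strides = {3} the functor::StridedSlice path is taken instead and
// selects indices 2 and 5, i.e. ceil((8 - 2) / 3) = 2 output elements.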

template <typename Device, typename T, int NDIM>
void HandleStridedSliceGradCase(OpKernelContext* context,
                                const gtl::ArraySlice<int64>& begin,
                                const gtl::ArraySlice<int64>& end,
                                const gtl::ArraySlice<int64>& strides,
                                const TensorShape& processing_shape,
                                bool is_simple_slice, Tensor* result) {
  gtl::InlinedVector<int64, 4> processing_dims = processing_shape.dim_sizes();

  Eigen::DSizes<Eigen::DenseIndex, NDIM> begin_di;
  Eigen::DSizes<Eigen::DenseIndex, NDIM> end_di;
  Eigen::DSizes<Eigen::DenseIndex, NDIM> strides_di;
  for (int i = 0; i < NDIM; ++i) {
    begin_di[i] = begin[i];
    end_di[i] = end[i];
    strides_di[i] = strides[i];
  }

  typedef typename proxy_type<Device, T>::type Proxy;
  functor::StridedSliceGrad<Device, Proxy, NDIM>()(
      context->eigen_device<Device>(), result->bit_casted_tensor<Proxy, NDIM>(),
      context->input(4).bit_casted_shaped<Proxy, NDIM>(processing_dims),
      begin_di, end_di, strides_di);
}
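
// Worked example (illustrative): if the forward StridedSlice took
// begin = {1}, end = {6}, strides = {2} from an input of shape [6]
// (selecting indices 1, 3, 5), then input(4) here (the incoming gradient
// for this op) has shape [3] and `result` has the original input shape [6];
// StridedSliceGrad scatters the three gradient values back to positions
// 1, 3 and 5, with zero gradient everywhere else.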

template <typename Device, typename T, int NDIM>
void HandleStridedSliceAssignCase<Device, T, NDIM>::operator()(
    OpKernelContext* context, const gtl::ArraySlice<int64>& begin,
    const gtl::ArraySlice<int64>& end, const gtl::ArraySlice<int64>& strides,
    const TensorShape& processing_shape, bool is_simple_slice, Tensor* result) {
  gtl::InlinedVector<int64, 4> processing_dims = processing_shape.dim_sizes();
  typedef typename proxy_type<Device, T>::type Proxy;
  Eigen::DSizes<Eigen::DenseIndex, NDIM> begin_di;
  Eigen::DSizes<Eigen::DenseIndex, NDIM> end_di;
  Eigen::DSizes<Eigen::DenseIndex, NDIM> strides_di;
  for (int i = 0; i < NDIM; ++i) {
    begin_di[i] = begin[i];
    end_di[i] = end[i];
    strides_di[i] = strides[i];
  }
  functor::StridedSliceAssign<Device, Proxy, NDIM>()(
      context->eigen_device<Device>(), result->bit_casted_tensor<Proxy, NDIM>(),
      context->input(4).bit_casted_shaped<Proxy, NDIM>(processing_dims),
      begin_di, end_di, strides_di);
}

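// Rank-0 case: when the processing shape has no dimensions (e.g. the slice
// degenerates to a single element), there is no NDIM >= 1 functor to
// dispatch to, so the specialization below copies the single value through
// a rank-1 view of size 1 using StridedSliceAssignScalar.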
template <typename Device, typename T>
class HandleStridedSliceAssignCase<Device, T, 0> {
 public:
  enum { NDIM_PROXY = 1 };
  void operator()(OpKernelContext* context, const gtl::ArraySlice<int64>& begin,
                  const gtl::ArraySlice<int64>& end,
                  const gtl::ArraySlice<int64>& strides,
                  const TensorShape& processing_shape, bool is_simple_slice,
                  Tensor* result) {
    gtl::InlinedVector<int64, 1> processing_dims(1);
    processing_dims[0] = 1;

    typedef typename proxy_type<Device, T>::type Proxy;
    functor::StridedSliceAssignScalar<Device, Proxy>()(
        context->eigen_device<Device>(),
        result->bit_casted_shaped<Proxy, 1>(processing_dims),
        context->input(4).bit_casted_shaped<Proxy, 1>(processing_dims));
  }
};

// NOTE(aselle): according to bsteiner, we need this because otherwise
// nvcc instantiates templates that are invalid. strided_slice_op_gpu.cu
// handles instantiation externally. It is important that this is done
// before the HandleXXCase's are instantiated to avoid duplicate
// specialization errors.

#define PREVENT_INSTANTIATE_DIM1_AND_UP(T, NDIM)                   \
  namespace functor {                                              \
  template <>                                                      \
  void StridedSlice<GPUDevice, T, NDIM>::operator()(               \
      const GPUDevice& d, typename TTypes<T, NDIM>::Tensor output, \
      typename TTypes<T, NDIM>::ConstTensor input,                 \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& start,         \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& stop,          \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& strides);      \
  extern template struct StridedSlice<GPUDevice, T, NDIM>;         \
  template <>                                                      \
  void Slice<GPUDevice, T, NDIM>::operator()(                      \
      const GPUDevice& d, typename TTypes<T, NDIM>::Tensor output, \
      typename TTypes<T, NDIM>::ConstTensor input,                 \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& indices,       \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& sizes);        \
  extern template struct Slice<GPUDevice, T, NDIM>;                \
  template <>                                                      \
  void StridedSliceGrad<GPUDevice, T, NDIM>::operator()(           \
      const GPUDevice& d, typename TTypes<T, NDIM>::Tensor output, \
      typename TTypes<T, NDIM>::ConstTensor input,                 \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& start,         \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& stop,          \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& strides);      \
  extern template struct StridedSliceGrad<GPUDevice, T, NDIM>;     \
  template <>                                                      \
  void StridedSliceAssign<GPUDevice, T, NDIM>::operator()(         \
      const GPUDevice& d, typename TTypes<T, NDIM>::Tensor output, \
      typename TTypes<T, NDIM>::ConstTensor input,                 \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& start,         \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& stop,          \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& strides);      \
  extern template struct StridedSliceAssign<GPUDevice, T, NDIM>;   \
  }  // namespace functor
#define PREVENT_INSTANTIATE_DIM0_ONLY(T, NDIM)                   \
  namespace functor {                                            \
  template <>                                                    \
  void StridedSliceAssignScalar<GPUDevice, T>::operator()(       \
      const GPUDevice& d, typename TTypes<T, 1>::Tensor output,  \
      typename TTypes<T, 1>::ConstTensor input);                 \
  extern template struct StridedSliceAssignScalar<GPUDevice, T>; \
  }  // namespace functor
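
// For illustration, PREVENT_INSTANTIATE_DIM0_ONLY(float, 0) expands to a
// declaration of the StridedSliceAssignScalar<GPUDevice, float> operator()
// specialization plus
//
//   extern template struct StridedSliceAssignScalar<GPUDevice, float>;
//
// The `extern template` line tells the compiler that the explicit
// instantiation already exists in another translation unit (the .cu file
// mentioned above), so including this header must not instantiate it again.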

// Dimension 0 only instantiates some functors, so we only need to prevent
// the ones declared by PREVENT_INSTANTIATE_DIM0_ONLY.
#if GOOGLE_CUDA
#if STRIDED_SLICE_INSTANTIATE_DIM == 0
#define PREVENT_INSTANTIATE(T, NDIM) PREVENT_INSTANTIATE_DIM0_ONLY(T, NDIM)
#else
#define PREVENT_INSTANTIATE(T, NDIM) PREVENT_INSTANTIATE_DIM1_AND_UP(T, NDIM)
#endif
#else
#define PREVENT_INSTANTIATE(T, NDIM)
#endif

#define INSTANTIATE_DIM1_AND_UP_HANDLERS(DEVICE, T, DIM)              \
  template void HandleStridedSliceCase<DEVICE, T, DIM>(               \
      OpKernelContext * context, const gtl::ArraySlice<int64>& begin, \
      const gtl::ArraySlice<int64>& end,                              \
      const gtl::ArraySlice<int64>& strides,                          \
      const TensorShape& processing_shape, bool is_simple_slice,      \
      Tensor* result);                                                \
  template void HandleStridedSliceGradCase<DEVICE, T, DIM>(           \
      OpKernelContext * context, const gtl::ArraySlice<int64>& begin, \
      const gtl::ArraySlice<int64>& end,                              \
      const gtl::ArraySlice<int64>& strides,                          \
      const TensorShape& processing_shape, bool is_simple_slice,      \
      Tensor* result);

#define INSTANTIATE_DIM0_AND_UP_HANDLERS(DEVICE, T, DIM) \
  template class HandleStridedSliceAssignCase<DEVICE, T, DIM>;

// Only some kernels need to be instantiated on dim 0.
#if STRIDED_SLICE_INSTANTIATE_DIM == 0
#define INSTANTIATE(DEVICE, T, DIM) \
  INSTANTIATE_DIM0_AND_UP_HANDLERS(DEVICE, T, DIM)
#else
#define INSTANTIATE(DEVICE, T, DIM)                \
  INSTANTIATE_DIM0_AND_UP_HANDLERS(DEVICE, T, DIM) \
  INSTANTIATE_DIM1_AND_UP_HANDLERS(DEVICE, T, DIM)
#endif
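
// For illustration, when STRIDED_SLICE_INSTANTIATE_DIM is 2,
// INSTANTIATE(CPUDevice, float, 2) explicitly instantiates
// HandleStridedSliceAssignCase<CPUDevice, float, 2> (via the DIM0_AND_UP
// macro) as well as HandleStridedSliceCase<CPUDevice, float, 2> and
// HandleStridedSliceGradCase<CPUDevice, float, 2> (via the DIM1_AND_UP
// macro); when it is 0, only the assign handler is instantiated.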

#define DECLARE_FOR_N_CPU(T) \
  INSTANTIATE(CPUDevice, T, STRIDED_SLICE_INSTANTIATE_DIM)

#define PREVENT_FOR_N_GPU(T) \
  PREVENT_INSTANTIATE(T, STRIDED_SLICE_INSTANTIATE_DIM)

#define DECLARE_FOR_N_GPU(T) \
  INSTANTIATE(GPUDevice, T, STRIDED_SLICE_INSTANTIATE_DIM)

#if GOOGLE_CUDA
TF_CALL_GPU_PROXY_TYPES(PREVENT_FOR_N_GPU);
TF_CALL_complex64(PREVENT_FOR_N_GPU);
TF_CALL_complex128(PREVENT_FOR_N_GPU);

TF_CALL_GPU_NUMBER_TYPES(DECLARE_FOR_N_GPU);
TF_CALL_complex64(DECLARE_FOR_N_GPU);
TF_CALL_complex128(DECLARE_FOR_N_GPU);
DECLARE_FOR_N_GPU(int32);
DECLARE_FOR_N_GPU(int64);
#endif  // END GOOGLE_CUDA

TF_CALL_ALL_TYPES(DECLARE_FOR_N_CPU);

#ifdef TENSORFLOW_USE_SYCL
#define PREVENT_FOR_N_SYCL(T) \
  PREVENT_INSTANTIATE(T, STRIDED_SLICE_INSTANTIATE_DIM)

#define DECLARE_FOR_N_SYCL(T) \
  INSTANTIATE(SYCLDevice, T, STRIDED_SLICE_INSTANTIATE_DIM)

TF_CALL_SYCL_PROXY_TYPES(PREVENT_FOR_N_SYCL);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(DECLARE_FOR_N_SYCL);
DECLARE_FOR_N_SYCL(int32);
DECLARE_FOR_N_SYCL(int64);

#undef DECLARE_FOR_N_SYCL
#endif  // TENSORFLOW_USE_SYCL

#undef INSTANTIATE
#undef DECLARE_FOR_N_CPU
#undef DECLARE_FOR_N_GPU

}  // end namespace tensorflow

#endif  // END STRIDED_SLICE_INSTANTIATE_DIM
#endif  // TENSORFLOW_KERNELS_STRIDED_SLICE_OP_IMPL_H_