/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_KERNELS_STRIDED_SLICE_OP_IMPL_H_
#define TENSORFLOW_KERNELS_STRIDED_SLICE_OP_IMPL_H_

// Functor definition for StridedSliceOp, must be compilable by nvcc.

#include "tensorflow/core/kernels/slice_op.h"
#include "tensorflow/core/kernels/strided_slice_op.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/register_types_traits.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/mem.h"

namespace tensorflow {

template <typename Device, typename T, int NDIM>
void HandleStridedSliceCase(OpKernelContext* context,
                            const gtl::ArraySlice<int64>& begin,
                            const gtl::ArraySlice<int64>& end,
                            const gtl::ArraySlice<int64>& strides,
                            const TensorShape& processing_shape,
                            bool is_simple_slice, Tensor* result);

template <typename Device, typename T, int NDIM>
void HandleStridedSliceGradCase(OpKernelContext* context,
                                const gtl::ArraySlice<int64>& begin,
                                const gtl::ArraySlice<int64>& end,
                                const gtl::ArraySlice<int64>& strides,
                                const TensorShape& processing_shape,
                                bool is_simple_slice, Tensor* result);

template <typename Device, typename T, int NDIM>
class HandleStridedSliceAssignCase {
 public:
  void operator()(OpKernelContext* context,
                  const gtl::ArraySlice<int64>& begin,
                  const gtl::ArraySlice<int64>& end,
                  const gtl::ArraySlice<int64>& strides,
                  const TensorShape& processing_shape, bool is_simple_slice,
                  Tensor* result);
};
}  // namespace tensorflow

// The actual implementation. This is designed so multiple
// translation units can include this file in the form
//
// #define STRIDED_SLICE_INSTANTIATE_DIM 1
// #include <thisfile>
// #undef STRIDED_SLICE_INSTANTIATE_DIM
//
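// A hypothetical instantiation unit following that pattern (an illustrative
// sketch, not a verbatim file from this directory):
//
//   // strided_slice_op_inst_1.cc (hypothetical name)
//   #define STRIDED_SLICE_INSTANTIATE_DIM 1
//   #include "tensorflow/core/kernels/strided_slice_op_impl.h"
//   #undef STRIDED_SLICE_INSTANTIATE_DIM
//
// Splitting the per-rank instantiations across translation units presumably
// keeps each object file's compile time and memory use manageable.
//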
#ifdef STRIDED_SLICE_INSTANTIATE_DIM

namespace tensorflow {

template <typename Device, typename T, int NDIM>
void HandleStridedSliceCase(OpKernelContext* context,
                            const gtl::ArraySlice<int64>& begin,
                            const gtl::ArraySlice<int64>& end,
                            const gtl::ArraySlice<int64>& strides,
                            const TensorShape& processing_shape,
                            bool is_simple_slice, Tensor* result) {
  typedef typename proxy_type<Device, T>::type Proxy;

  gtl::InlinedVector<int64, 4> processing_dims = processing_shape.dim_sizes();
  if (is_simple_slice) {
    // A simple slice has unit strides, so extents reduce to end - begin and
    // the plain Slice functor applies.
    Eigen::DSizes<Eigen::DenseIndex, NDIM> begin_di;
    Eigen::DSizes<Eigen::DenseIndex, NDIM> sizes_di;
    for (int i = 0; i < NDIM; ++i) {
      begin_di[i] = begin[i];
      sizes_di[i] = end[i] - begin[i];
    }
    functor::Slice<Device, Proxy, NDIM>()(
        context->eigen_device<Device>(),
        result->bit_casted_shaped<Proxy, NDIM>(processing_dims),
        context->input(0).bit_casted_tensor<Proxy, NDIM>(), begin_di,
        sizes_di);
  } else {
    Eigen::DSizes<Eigen::DenseIndex, NDIM> begin_di;
    Eigen::DSizes<Eigen::DenseIndex, NDIM> end_di;
    Eigen::DSizes<Eigen::DenseIndex, NDIM> strides_di;
    for (int i = 0; i < NDIM; ++i) {
      begin_di[i] = begin[i];
      end_di[i] = end[i];
      strides_di[i] = strides[i];
    }
    functor::StridedSlice<Device, Proxy, NDIM>()(
        context->eigen_device<Device>(),
        result->bit_casted_shaped<Proxy, NDIM>(processing_dims),
        context->input(0).bit_casted_tensor<Proxy, NDIM>(), begin_di, end_di,
        strides_di);
  }
}

template <typename Device, typename T, int NDIM>
void HandleStridedSliceGradCase(OpKernelContext* context,
                                const gtl::ArraySlice<int64>& begin,
                                const gtl::ArraySlice<int64>& end,
                                const gtl::ArraySlice<int64>& strides,
                                const TensorShape& processing_shape,
                                bool is_simple_slice, Tensor* result) {
  gtl::InlinedVector<int64, 4> processing_dims = processing_shape.dim_sizes();

  Eigen::DSizes<Eigen::DenseIndex, NDIM> begin_di;
  Eigen::DSizes<Eigen::DenseIndex, NDIM> end_di;
  Eigen::DSizes<Eigen::DenseIndex, NDIM> strides_di;
  for (int i = 0; i < NDIM; ++i) {
    begin_di[i] = begin[i];
    end_di[i] = end[i];
    strides_di[i] = strides[i];
  }

  typedef typename proxy_type<Device, T>::type Proxy;
  functor::StridedSliceGrad<Device, Proxy, NDIM>()(
      context->eigen_device<Device>(),
      result->bit_casted_tensor<Proxy, NDIM>(),
      context->input(4).bit_casted_shaped<Proxy, NDIM>(processing_dims),
      begin_di, end_di, strides_di);
}
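// Callers select the NDIM instantiation by dispatching on the runtime rank.
// A minimal sketch of such a dispatch (hypothetical; the real kernel in
// strided_slice_op.cc has its own dispatch machinery):
//
//   switch (processing_shape.dims()) {
//     case 1:
//       HandleStridedSliceCase<Device, T, 1>(
//           context, begin, end, strides, processing_shape, is_simple_slice,
//           result);
//       break;
//     case 2:
//       HandleStridedSliceCase<Device, T, 2>(
//           context, begin, end, strides, processing_shape, is_simple_slice,
//           result);
//       break;
//     // ... and so on, one case per supported rank.
//   }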
template <typename Device, typename T, int NDIM>
void HandleStridedSliceAssignCase<Device, T, NDIM>::operator()(
    OpKernelContext* context, const gtl::ArraySlice<int64>& begin,
    const gtl::ArraySlice<int64>& end, const gtl::ArraySlice<int64>& strides,
    const TensorShape& processing_shape, bool is_simple_slice,
    Tensor* result) {
  gtl::InlinedVector<int64, 4> processing_dims = processing_shape.dim_sizes();
  typedef typename proxy_type<Device, T>::type Proxy;
  Eigen::DSizes<Eigen::DenseIndex, NDIM> begin_di;
  Eigen::DSizes<Eigen::DenseIndex, NDIM> end_di;
  Eigen::DSizes<Eigen::DenseIndex, NDIM> strides_di;
  for (int i = 0; i < NDIM; ++i) {
    begin_di[i] = begin[i];
    end_di[i] = end[i];
    strides_di[i] = strides[i];
  }
  functor::StridedSliceAssign<Device, Proxy, NDIM>()(
      context->eigen_device<Device>(),
      result->bit_casted_tensor<Proxy, NDIM>(),
      context->input(4).bit_casted_shaped<Proxy, NDIM>(processing_dims),
      begin_di, end_di, strides_di);
}

// Rank-0 (scalar) assignment is handled by viewing both tensors as
// one-element vectors.
template <typename Device, typename T>
class HandleStridedSliceAssignCase<Device, T, 0> {
 public:
  enum { NDIM_PROXY = 1 };
  void operator()(OpKernelContext* context,
                  const gtl::ArraySlice<int64>& begin,
                  const gtl::ArraySlice<int64>& end,
                  const gtl::ArraySlice<int64>& strides,
                  const TensorShape& processing_shape, bool is_simple_slice,
                  Tensor* result) {
    gtl::InlinedVector<int64, 1> processing_dims(1);
    processing_dims[0] = 1;

    typedef typename proxy_type<Device, T>::type Proxy;
    functor::StridedSliceAssignScalar<Device, Proxy>()(
        context->eigen_device<Device>(),
        result->bit_casted_shaped<Proxy, 1>(processing_dims),
        context->input(4).bit_casted_shaped<Proxy, 1>(processing_dims));
  }
};

// NOTE(aselle): according to bsteiner, we need this because otherwise
// nvcc instantiates templates that are invalid. strided_slice_op_gpu.cu
// handles these instantiations externally. It is important that this is done
// before the HandleXXCase's are instantiated to avoid duplicate
// specialization errors.

#define PREVENT_INSTANTIATE_DIM1_AND_UP(T, NDIM)                   \
  namespace functor {                                              \
  template <>                                                      \
  void StridedSlice<GPUDevice, T, NDIM>::operator()(               \
      const GPUDevice& d, typename TTypes<T, NDIM>::Tensor output, \
      typename TTypes<T, NDIM>::ConstTensor input,                 \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& start,         \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& stop,          \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& strides);      \
  extern template struct StridedSlice<GPUDevice, T, NDIM>;         \
  template <>                                                      \
  void Slice<GPUDevice, T, NDIM>::operator()(                      \
      const GPUDevice& d, typename TTypes<T, NDIM>::Tensor output, \
      typename TTypes<T, NDIM>::ConstTensor input,                 \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& indices,       \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& sizes);        \
  extern template struct Slice<GPUDevice, T, NDIM>;                \
  template <>                                                      \
  void StridedSliceGrad<GPUDevice, T, NDIM>::operator()(           \
      const GPUDevice& d, typename TTypes<T, NDIM>::Tensor output, \
      typename TTypes<T, NDIM>::ConstTensor input,                 \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& start,         \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& stop,          \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& strides);      \
  extern template struct StridedSliceGrad<GPUDevice, T, NDIM>;     \
  template <>                                                      \
  void StridedSliceAssign<GPUDevice, T, NDIM>::operator()(         \
      const GPUDevice& d, typename TTypes<T, NDIM>::Tensor output, \
      typename TTypes<T, NDIM>::ConstTensor input,                 \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& start,         \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& stop,          \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& strides);      \
  extern template struct StridedSliceAssign<GPUDevice, T, NDIM>;   \
  }  // namespace functor

#define PREVENT_INSTANTIATE_DIM0_ONLY(T, NDIM)                     \
  namespace functor {                                              \
  template <>                                                      \
  void StridedSliceAssignScalar<GPUDevice, T>::operator()(         \
      const GPUDevice& d, typename TTypes<T, 1>::Tensor output,    \
      typename TTypes<T, 1>::ConstTensor input);                   \
  extern template struct StridedSliceAssignScalar<GPUDevice, T>;   \
  }  // namespace functor

// Dimension 0 only instantiates some of the functors, so at dim 0 we only
// need to prevent the instantiations declared by
// PREVENT_INSTANTIATE_DIM0_ONLY.
#if GOOGLE_CUDA
#if STRIDED_SLICE_INSTANTIATE_DIM == 0
#define PREVENT_INSTANTIATE(T, NDIM) PREVENT_INSTANTIATE_DIM0_ONLY(T, NDIM)
#else
#define PREVENT_INSTANTIATE(T, NDIM) PREVENT_INSTANTIATE_DIM1_AND_UP(T, NDIM)
#endif
#else
#define PREVENT_INSTANTIATE(T, NDIM)
#endif
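// For concreteness, PREVENT_INSTANTIATE(float, 2) under GOOGLE_CUDA expands
// (showing only the StridedSlice functor; the others are analogous) to
// roughly:
//
//   namespace functor {
//   template <>
//   void StridedSlice<GPUDevice, float, 2>::operator()(
//       const GPUDevice& d, typename TTypes<float, 2>::Tensor output,
//       typename TTypes<float, 2>::ConstTensor input,
//       const Eigen::DSizes<Eigen::DenseIndex, 2>& start,
//       const Eigen::DSizes<Eigen::DenseIndex, 2>& stop,
//       const Eigen::DSizes<Eigen::DenseIndex, 2>& strides);
//   extern template struct StridedSlice<GPUDevice, float, 2>;
//   }  // namespace functor
//
// The specialization declaration plus `extern template` tells the compiler
// not to instantiate the functor in this translation unit; the definitions
// come from the separately compiled strided_slice_op_gpu.cu.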
#define INSTANTIATE_DIM1_AND_UP_HANDLERS(DEVICE, T, DIM)               \
  template void HandleStridedSliceCase<DEVICE, T, DIM>(                \
      OpKernelContext * context, const gtl::ArraySlice<int64>& begin,  \
      const gtl::ArraySlice<int64>& end,                               \
      const gtl::ArraySlice<int64>& strides,                           \
      const TensorShape& processing_shape, bool is_simple_slice,       \
      Tensor* result);                                                 \
  template void HandleStridedSliceGradCase<DEVICE, T, DIM>(            \
      OpKernelContext * context, const gtl::ArraySlice<int64>& begin,  \
      const gtl::ArraySlice<int64>& end,                               \
      const gtl::ArraySlice<int64>& strides,                           \
      const TensorShape& processing_shape, bool is_simple_slice,       \
      Tensor* result);

#define INSTANTIATE_DIM0_AND_UP_HANDLERS(DEVICE, T, DIM) \
  template class HandleStridedSliceAssignCase<DEVICE, T, DIM>;

// Only some kernels need to be instantiated on dim 0.
#if STRIDED_SLICE_INSTANTIATE_DIM == 0
#define INSTANTIATE(DEVICE, T, DIM) \
  INSTANTIATE_DIM0_AND_UP_HANDLERS(DEVICE, T, DIM)
#else
#define INSTANTIATE(DEVICE, T, DIM)                \
  INSTANTIATE_DIM0_AND_UP_HANDLERS(DEVICE, T, DIM) \
  INSTANTIATE_DIM1_AND_UP_HANDLERS(DEVICE, T, DIM)
#endif

#define DECLARE_FOR_N_CPU(T) \
  INSTANTIATE(CPUDevice, T, STRIDED_SLICE_INSTANTIATE_DIM)

#define PREVENT_FOR_N_GPU(T) \
  PREVENT_INSTANTIATE(T, STRIDED_SLICE_INSTANTIATE_DIM)

#define DECLARE_FOR_N_GPU(T) \
  INSTANTIATE(GPUDevice, T, STRIDED_SLICE_INSTANTIATE_DIM)

#if GOOGLE_CUDA
TF_CALL_GPU_PROXY_TYPES(PREVENT_FOR_N_GPU);
TF_CALL_complex64(PREVENT_FOR_N_GPU);
TF_CALL_complex128(PREVENT_FOR_N_GPU);

TF_CALL_GPU_NUMBER_TYPES(DECLARE_FOR_N_GPU);
TF_CALL_complex64(DECLARE_FOR_N_GPU);
TF_CALL_complex128(DECLARE_FOR_N_GPU);
DECLARE_FOR_N_GPU(int32);
DECLARE_FOR_N_GPU(int64);
#endif  // END GOOGLE_CUDA

TF_CALL_ALL_TYPES(DECLARE_FOR_N_CPU);

#ifdef TENSORFLOW_USE_SYCL
#define PREVENT_FOR_N_SYCL(T) \
  PREVENT_INSTANTIATE(T, STRIDED_SLICE_INSTANTIATE_DIM)

#define DECLARE_FOR_N_SYCL(T) \
  INSTANTIATE(SYCLDevice, T, STRIDED_SLICE_INSTANTIATE_DIM)

TF_CALL_SYCL_PROXY_TYPES(PREVENT_FOR_N_SYCL);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(DECLARE_FOR_N_SYCL);
DECLARE_FOR_N_SYCL(int32);
DECLARE_FOR_N_SYCL(int64);

#undef DECLARE_FOR_N_SYCL
#endif  // TENSORFLOW_USE_SYCL

#undef INSTANTIATE
#undef DECLARE_FOR_N_CPU
#undef DECLARE_FOR_N_GPU

}  // end namespace tensorflow

#endif  // END STRIDED_SLICE_INSTANTIATE_DIM
#endif  // TENSORFLOW_KERNELS_STRIDED_SLICE_OP_IMPL_H_