/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc
#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/reverse_op.h"

#include <memory>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL

namespace {

// Reverse rows (middle dimension) of a three dimensional tensor.
// NUM_CHANNELS can be <= 0 to compute it dynamically from <input>.
// Otherwise, it must equal input.dim_size(2) and is used as a compile-time
// constant.
template <typename T, int NUM_CHANNELS>
void ReverseRows(OpKernelContext* context, const Tensor& input,
                 Tensor* result) {
  auto work = [&input, result](int64 start, int64 end) {
    const int64 inner_size =
        NUM_CHANNELS > 0 ? NUM_CHANNELS : input.dim_size(2);
    const int64 middle_size = input.dim_size(1);
    const int64 row_size = inner_size * middle_size;
    DCHECK_EQ(input.dim_size(2), inner_size);

    const T* in_ptr = input.bit_casted_tensor<T, 3>().data();
    T* out_ptr = result->bit_casted_tensor<T, 3>().data();

    in_ptr += start * row_size;
    out_ptr += start * row_size;

    // Advance out_ptr to the end of each outer slice, then copy rows
    // back-to-front while reading the input front-to-back.
    for (int outer_dim = start; outer_dim < end; ++outer_dim) {
      out_ptr += row_size;
      int remaining = middle_size;
      while (remaining > 0) {
        out_ptr -= inner_size;
        memcpy(out_ptr, in_ptr, inner_size * sizeof(T));
        in_ptr += inner_size;
        --remaining;
      }

      out_ptr += row_size;
    }
  };

  // Shard across outer dimension.
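  // Each shard handles outer slices [start, end); the cost estimate below is
  // the number of elements per outer slice, which lets Shard balance the work
  // across the CPU worker threads.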
  const int64 N = input.dim_size(0);
  const int64 cost_per_unit = input.NumElements() / N;
  auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
  Shard(worker_threads->num_threads, worker_threads->workers, N, cost_per_unit,
        std::move(work));
}

template <typename T>
struct data_type_can_memcpy {
  static constexpr bool value =
      std::is_same<T, uint8>::value || std::is_same<T, int8>::value ||
      std::is_same<T, bool>::value || std::is_same<T, uint16>::value ||
      std::is_same<T, int16>::value || std::is_same<T, Eigen::half>::value ||
      std::is_same<T, int32>::value || std::is_same<T, float>::value ||
      std::is_same<T, int64>::value || std::is_same<T, double>::value ||
      std::is_same<T, complex64>::value || std::is_same<T, complex128>::value;
};

template <typename T, int NUM_CHANNELS>
typename std::enable_if<data_type_can_memcpy<T>::value>::type
DoHandleReverseCase(OpKernelContext* context, const Tensor& input,
                    Tensor* result) {
  if (sizeof(T) == 1) {
    static_assert(sizeof(uint8) == 1, "uint8 must be 1 byte");
    ReverseRows<uint8, NUM_CHANNELS>(context, input, result);
  } else if (sizeof(T) == 2) {
    static_assert(sizeof(uint16) == 2, "uint16 must be 2 bytes");
    ReverseRows<uint16, NUM_CHANNELS>(context, input, result);
  } else if (sizeof(T) == 4) {
    static_assert(sizeof(uint32) == 4, "uint32 must be 4 bytes");
    ReverseRows<uint32, NUM_CHANNELS>(context, input, result);
  } else if (sizeof(T) == 8) {
    static_assert(sizeof(uint64) == 8, "uint64 must be 8 bytes");
    ReverseRows<uint64, NUM_CHANNELS>(context, input, result);
  } else if (sizeof(T) == 16) {
    static_assert(sizeof(complex128) == 16, "complex128 must be 16 bytes");
    ReverseRows<complex128, NUM_CHANNELS>(context, input, result);
  } else {
    // errors::InvalidArgument concatenates its arguments rather than
    // interpreting printf-style placeholders, so build the message by
    // concatenation.
    context->CtxFailure(errors::InvalidArgument(
        DataTypeString(input.dtype()), " has unexpected size of ", sizeof(T),
        " bytes"));
  }
}

// No-op overload for element types that cannot be reversed with memcpy; such
// types fall through to the generic Eigen path.
template <typename T, int NUM_CHANNELS>
typename std::enable_if<!data_type_can_memcpy<T>::value>::type
DoHandleReverseCase(OpKernelContext* context, const Tensor& input,
                    Tensor* result) {}

}  // namespace

template <typename Device, typename T, int NDIMS>
void HandleReverseCase(OpKernelContext* context,
                       typename TTypes<bool, 1>::ConstTensor dims,
                       Tensor* result) {
  const Tensor& input = context->input(0);

  // Use optimized reverse if possible.
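  // The memcpy fast path applies when the input is rank-3, runs on the CPU,
  // has a memcpy-safe element type, and only the middle (row) dimension is
  // reversed. An inner size of 3 (e.g. the RGB channels of an HWC image) gets
  // a compile-time inner dimension; otherwise it is read from the shape.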
  if (NDIMS == 3 && std::is_same<Device, CPUDevice>::value &&
      data_type_can_memcpy<T>::value && (!dims(0) && dims(1) && !dims(2))) {
    if (input.dim_size(2) == 3) {
      DoHandleReverseCase<T, 3>(context, input, result);
    } else {
      DoHandleReverseCase<T, -1>(context, input, result);
    }
    return;
  }
  typename Eigen::array<bool, NDIMS> axes_di;
  for (int i = 0; i < NDIMS; i++) {
    axes_di[i] = dims(i);
  }
  functor::Reverse<Device, T, NDIMS>()(context->eigen_device<Device>(),
                                       input.tensor<T, NDIMS>(), axes_di,
                                       result->tensor<T, NDIMS>());
}

template <typename Device, typename T>
class ReverseOp : public OpKernel {
 public:
  explicit ReverseOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& dims = context->input(1);

    if (TensorShapeUtils::IsScalar(input.shape())) {
      context->set_output(0, input);
    } else {
      const int input_dims = input.dims();
      OP_REQUIRES(context, TensorShapeUtils::IsVector(dims.shape()),
                  errors::InvalidArgument("'dims' must be 1-dimensional, not ",
                                          dims.dims()));

      OP_REQUIRES(
          context, input_dims == dims.dim_size(0),
          errors::InvalidArgument(
              "'dims' must have the same number of values as 'input' has "
              "dimensions. 'input' has ",
              input_dims, ", 'dims' has ", dims.dim_size(0), " values"));
      OP_REQUIRES(context, input_dims <= 8,
                  errors::Unimplemented(
                      "reverse is not implemented for tensors of rank > 8."));

      Tensor* output = nullptr;
      OP_REQUIRES_OK(context,
                     context->allocate_output(0, input.shape(), &output));

#define HANDLE_REVERSE(NDIMS)                                               \
  case NDIMS:                                                               \
    HandleReverseCase<Device, T, NDIMS>(context, dims.vec<bool>(), output); \
    return;

      switch (input_dims) {
        HANDLE_REVERSE(0);
        HANDLE_REVERSE(1);
        HANDLE_REVERSE(2);
        HANDLE_REVERSE(3);
        HANDLE_REVERSE(4);
        HANDLE_REVERSE(5);
        HANDLE_REVERSE(6);
        HANDLE_REVERSE(7);
        HANDLE_REVERSE(8);
      }
#undef HANDLE_REVERSE
    }
  }
};

template <typename Device, typename T, int NDIMS>
void HandleReverseV2Case(OpKernelContext* context,
                         const gtl::ArraySlice<bool>& axes, Tensor* result) {
  const Tensor& input = context->input(0);

  // Use optimized reverse if possible.
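  // Same fast path as HandleReverseCase above, with the axes supplied as a
  // dense bool vector rather than a bool tensor.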
  if (NDIMS == 3 && std::is_same<Device, CPUDevice>::value &&
      data_type_can_memcpy<T>::value && (!axes[0] && axes[1] && !axes[2])) {
    if (input.dim_size(2) == 3) {
      DoHandleReverseCase<T, 3>(context, input, result);
    } else {
      DoHandleReverseCase<T, -1>(context, input, result);
    }
    return;
  }

  typename Eigen::array<bool, NDIMS> axes_di;
  for (int i = 0; i < NDIMS; i++) {
    axes_di[i] = axes[i];
  }
  functor::Reverse<Device, T, NDIMS>()(context->eigen_device<Device>(),
                                       input.tensor<T, NDIMS>(), axes_di,
                                       result->tensor<T, NDIMS>());
}

template <typename Device, typename T, typename Tidx>
class ReverseV2Op : public OpKernel {
 public:
  explicit ReverseV2Op(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& sparse_dims = context->input(1);

    if (TensorShapeUtils::IsScalar(input.shape())) {
      context->set_output(0, input);
    } else {
      const int input_dims = input.dims();
      const TensorShape& sparse_dims_shape = sparse_dims.shape();
      const auto& axes_sparse_flat = sparse_dims.flat<Tidx>();

      OP_REQUIRES(context, TensorShapeUtils::IsVector(sparse_dims_shape),
                  errors::InvalidArgument("'dims' must be 1-dimensional, not ",
                                          sparse_dims.dims()));
      gtl::InlinedVector<bool, 8> axes_dense(input_dims, false);
      for (int dummy = 0; dummy < axes_sparse_flat.size(); dummy++) {
        Tidx axis = internal::SubtleMustCopy<Tidx>(axes_sparse_flat(dummy));
        Tidx canonical_axis = axis < 0 ? input_dims + axis : axis;
        OP_REQUIRES(context,
                    canonical_axis >= 0 && canonical_axis < input_dims,
                    errors::InvalidArgument("'axis'[", dummy, "] = ", axis,
                                            " is out of valid range [", 0,
                                            ", ", input_dims - 1, "]"));
        OP_REQUIRES(context, !axes_dense[canonical_axis],
                    errors::InvalidArgument("axis ", canonical_axis,
                                            " specified more than once."));
        axes_dense[canonical_axis] = true;
      }

      OP_REQUIRES(context, input_dims <= 8,
                  errors::Unimplemented(
                      "reverse is not implemented for tensors of rank > 8."));

      Tensor* output = nullptr;
      OP_REQUIRES_OK(context,
                     context->allocate_output(0, input.shape(), &output));

      // TODO(cwhipkey): we can do dimension folding to reduce, e.g., a
      // reverse of a single dimension to the dims=3 or dims=2 case,
      // regardless of the number of dimensions in the tensor. This would let
      // some ops use faster lower-dimension code (and use optimized
      // versions).
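      // For example, reversing only axis 2 of a [2, 3, 5, 7] tensor could be
      // treated as reversing the middle dimension of a [6, 5, 7] tensor: fold
      // axes 0-1 into the outer dimension and keep axis 3 as the inner
      // dimension. (Sketch only; this folding is not implemented here.)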

#define HANDLE_REVERSE(NDIMS)                                             \
  case NDIMS:                                                             \
    HandleReverseV2Case<Device, T, NDIMS>(context, axes_dense, output);   \
    return;

      switch (input_dims) {
        HANDLE_REVERSE(0);
        HANDLE_REVERSE(1);
        HANDLE_REVERSE(2);
        HANDLE_REVERSE(3);
        HANDLE_REVERSE(4);
        HANDLE_REVERSE(5);
        HANDLE_REVERSE(6);
        HANDLE_REVERSE(7);
        HANDLE_REVERSE(8);
      }
#undef HANDLE_REVERSE
    }
  }
};

#define REGISTER_KERNELS(T)                                  \
  REGISTER_KERNEL_BUILDER(Name("Reverse")                    \
                              .Device(DEVICE_CPU)            \
                              .TypeConstraint<T>("T")        \
                              .HostMemory("dims"),           \
                          ReverseOp<CPUDevice, T>)           \
  REGISTER_KERNEL_BUILDER(Name("ReverseV2")                  \
                              .Device(DEVICE_CPU)            \
                              .TypeConstraint<T>("T")        \
                              .TypeConstraint<int32>("Tidx") \
                              .HostMemory("axis"),           \
                          ReverseV2Op<CPUDevice, T, int32>)  \
  REGISTER_KERNEL_BUILDER(Name("ReverseV2")                  \
                              .Device(DEVICE_CPU)            \
                              .TypeConstraint<T>("T")        \
                              .TypeConstraint<int64>("Tidx") \
                              .HostMemory("axis"),           \
                          ReverseV2Op<CPUDevice, T, int64>)
TF_CALL_POD_TYPES(REGISTER_KERNELS);
TF_CALL_string(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#if GOOGLE_CUDA

// Forward declarations of the function specializations for GPU (to prevent
// building the GPU versions here, they will be built compiling _gpu.cu.cc).
namespace functor {
#define DECLARE_GPU_SPEC_DIM(T, DIM)                                  \
  template <>                                                         \
  void Reverse<GPUDevice, T, DIM>::operator()(                        \
      const GPUDevice& d, typename TTypes<T, DIM>::ConstTensor input, \
      const Eigen::array<bool, DIM>& reverse_dims,                    \
      typename TTypes<T, DIM>::Tensor output);                        \
  extern template struct Reverse<GPUDevice, T, DIM>;
#define DECLARE_GPU_SPEC(T)  \
  DECLARE_GPU_SPEC_DIM(T, 0) \
  DECLARE_GPU_SPEC_DIM(T, 1) \
  DECLARE_GPU_SPEC_DIM(T, 2) \
  DECLARE_GPU_SPEC_DIM(T, 3) \
  DECLARE_GPU_SPEC_DIM(T, 4) \
  DECLARE_GPU_SPEC_DIM(T, 5) \
  DECLARE_GPU_SPEC_DIM(T, 6) \
  DECLARE_GPU_SPEC_DIM(T, 7) \
  DECLARE_GPU_SPEC_DIM(T, 8)

TF_CALL_uint8(DECLARE_GPU_SPEC);
TF_CALL_int8(DECLARE_GPU_SPEC);
TF_CALL_bool(DECLARE_GPU_SPEC);
TF_CALL_half(DECLARE_GPU_SPEC);
TF_CALL_float(DECLARE_GPU_SPEC);
TF_CALL_double(DECLARE_GPU_SPEC);
TF_CALL_complex64(DECLARE_GPU_SPEC);
TF_CALL_complex128(DECLARE_GPU_SPEC);
#undef DECLARE_GPU_SPEC
#undef DECLARE_GPU_SPEC_DIM
}  // namespace functor

// Registration of the GPU implementations.
#define REGISTER_GPU_KERNELS(T)                              \
  REGISTER_KERNEL_BUILDER(Name("Reverse")                    \
                              .Device(DEVICE_GPU)            \
                              .TypeConstraint<T>("T")        \
                              .HostMemory("dims"),           \
                          ReverseOp<GPUDevice, T>)           \
  REGISTER_KERNEL_BUILDER(Name("ReverseV2")                  \
                              .Device(DEVICE_GPU)            \
                              .TypeConstraint<T>("T")        \
                              .TypeConstraint<int32>("Tidx") \
                              .HostMemory("axis"),           \
                          ReverseV2Op<GPUDevice, T, int32>)  \
  REGISTER_KERNEL_BUILDER(Name("ReverseV2")                  \
                              .Device(DEVICE_GPU)            \
                              .TypeConstraint<T>("T")        \
                              .TypeConstraint<int64>("Tidx") \
                              .HostMemory("axis"),           \
                          ReverseV2Op<GPUDevice, T, int64>)
TF_CALL_uint8(REGISTER_GPU_KERNELS);
TF_CALL_int8(REGISTER_GPU_KERNELS);
// TODO: decide whether we want to enable the bool kernel.
// TF_CALL_bool(REGISTER_GPU_KERNELS);
TF_CALL_half(REGISTER_GPU_KERNELS);
TF_CALL_float(REGISTER_GPU_KERNELS);
TF_CALL_double(REGISTER_GPU_KERNELS);
TF_CALL_complex64(REGISTER_GPU_KERNELS);
TF_CALL_complex128(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS

// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
REGISTER_KERNEL_BUILDER(Name("Reverse")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("tensor")
                            .HostMemory("dims")
                            .HostMemory("output"),
                        ReverseOp<CPUDevice, int32>);
REGISTER_KERNEL_BUILDER(Name("ReverseV2")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .TypeConstraint<int32>("Tidx")
                            .HostMemory("tensor")
                            .HostMemory("axis")
                            .HostMemory("output"),
                        ReverseV2Op<CPUDevice, int32, int32>);
REGISTER_KERNEL_BUILDER(Name("ReverseV2")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .TypeConstraint<int64>("Tidx")
                            .HostMemory("tensor")
                            .HostMemory("axis")
                            .HostMemory("output"),
                        ReverseV2Op<CPUDevice, int32, int64>);
#endif  // GOOGLE_CUDA

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_KERNELS(T)                             \
  REGISTER_KERNEL_BUILDER(Name("Reverse")                    \
                              .Device(DEVICE_SYCL)           \
                              .TypeConstraint<T>("T")        \
                              .HostMemory("dims"),           \
                          ReverseOp<SYCLDevice, T>)          \
  REGISTER_KERNEL_BUILDER(Name("ReverseV2")                  \
                              .Device(DEVICE_SYCL)           \
                              .TypeConstraint<T>("T")        \
                              .TypeConstraint<int32>("Tidx") \
                              .HostMemory("axis"),           \
                          ReverseV2Op<SYCLDevice, T, int32>) \
  REGISTER_KERNEL_BUILDER(Name("ReverseV2")                  \
                              .Device(DEVICE_SYCL)           \
                              .TypeConstraint<T>("T")        \
                              .TypeConstraint<int64>("Tidx") \
                              .HostMemory("axis"),           \
                          ReverseV2Op<SYCLDevice, T, int64>)
TF_CALL_uint8(REGISTER_SYCL_KERNELS);
TF_CALL_int8(REGISTER_SYCL_KERNELS);
TF_CALL_float(REGISTER_SYCL_KERNELS);
TF_CALL_double(REGISTER_SYCL_KERNELS);
#undef REGISTER_SYCL_KERNELS

// A special SYCL kernel for int32; as in the GPU case above, int32 inputs and
// outputs must live in host memory.
REGISTER_KERNEL_BUILDER(Name("Reverse")
                            .Device(DEVICE_SYCL)
                            .TypeConstraint<int32>("T")
                            .HostMemory("tensor")
                            .HostMemory("dims")
                            .HostMemory("output"),
                        ReverseOp<CPUDevice, int32>);
REGISTER_KERNEL_BUILDER(Name("ReverseV2")
                            .Device(DEVICE_SYCL)
                            .TypeConstraint<int32>("T")
                            .TypeConstraint<int32>("Tidx")
                            .HostMemory("tensor")
                            .HostMemory("axis")
                            .HostMemory("output"),
                        ReverseV2Op<CPUDevice, int32, int32>);
REGISTER_KERNEL_BUILDER(Name("ReverseV2")
                            .Device(DEVICE_SYCL)
                            .TypeConstraint<int32>("T")
                            .TypeConstraint<int64>("Tidx")
                            .HostMemory("tensor")
                            .HostMemory("axis")
                            .HostMemory("output"),
                        ReverseV2Op<CPUDevice, int32, int64>);
#endif  // TENSORFLOW_USE_SYCL

}  // namespace tensorflow