/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "tensorflow/core/kernels/strided_slice_op.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/slice_op.h"
#include "tensorflow/core/kernels/strided_slice_op_impl.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/kernels/variable_ops.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/prefetch.h"
#include "tensorflow/core/util/strided_slice_op.h"

namespace tensorflow {
namespace {

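// MemCpyFunctor implements the 2-D, unit-stride fast path used by
// StridedSliceOp below (optimization #3A): each output row is a contiguous
// run of one input row, so the copy reduces to a single memcpy per row.
// For example (illustrative values only), slicing a float tensor of shape
// [6, 8] with begin = {2, 3} and end = {5, 7} copies three rows of
// (7 - 3) * sizeof(float) = 16 bytes each.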
template <typename T>
struct MemCpyFunctor {
  // Returns true if the copy was made with memcpy, false otherwise.
  bool Copy(const Tensor& input, const gtl::InlinedVector<int64, 4>& begin,
            const gtl::InlinedVector<int64, 4>& end, Tensor* result) {
    if (DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
      auto in = input.tensor<T, 2>();
      auto output = result->tensor<T, 2>();
      // TODO(agarwal): Consider multi-threading if size[0] is large
      for (int row_in = begin[0], row_out = 0; row_in < end[0];
           ++row_in, ++row_out) {
        if (row_in + 1 < end[0]) {
          port::prefetch<port::PREFETCH_HINT_T0>(&output(row_in + 1, 0));
          port::prefetch<port::PREFETCH_HINT_T0>(&in(row_in + 1, begin[1]));
        }
        memcpy(&output(row_out, 0), &in(row_in, begin[1]),
               (end[1] - begin[1]) * sizeof(T));
      }
      return true;
    }
    return false;
  }
};

template <>
struct MemCpyFunctor<ResourceHandle> {
  bool Copy(const Tensor& input, const gtl::InlinedVector<int64, 4>& begin,
            const gtl::InlinedVector<int64, 4>& end, Tensor* result) {
    return false;
  }
};

}  // namespace

template <typename Device, typename T>
class StridedSliceOp : public OpKernel {
 public:
  explicit StridedSliceOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("begin_mask", &begin_mask));
    OP_REQUIRES_OK(context, context->GetAttr("end_mask", &end_mask));
    OP_REQUIRES_OK(context, context->GetAttr("ellipsis_mask", &ellipsis_mask));
    OP_REQUIRES_OK(context, context->GetAttr("new_axis_mask", &new_axis_mask));
    OP_REQUIRES_OK(context,
                   context->GetAttr("shrink_axis_mask", &shrink_axis_mask));
  }

  void Compute(OpKernelContext* context) override {
    TensorShape processing_shape, final_shape;
    bool is_identity = true;
    bool slice_dim0 = true;
    bool is_simple_slice = true;
    gtl::InlinedVector<int64, 4> begin;
    gtl::InlinedVector<int64, 4> end;
    gtl::InlinedVector<int64, 4> strides;

    OP_REQUIRES_OK(
        context, ValidateStridedSliceOp(
                     &context->input(1), &context->input(2), context->input(3),
                     context->input(0).shape(), begin_mask, end_mask,
                     ellipsis_mask, new_axis_mask, shrink_axis_mask,
                     &processing_shape, &final_shape, &is_identity,
                     &is_simple_slice, &slice_dim0, &begin, &end, &strides));
    const Tensor& input = context->input(0);

    // Optimization #1, slice is a no-op plus reshape
    if (is_identity) {
      VLOG(1) << "Strided slice identity ";
      Tensor tmp;
      CHECK(tmp.CopyFrom(input, final_shape));
      context->set_output(0, tmp);
      return;
    }

    // Optimization #2, slice is memory contiguous (only occurs in dim 0)
    if (slice_dim0 && IsDim0SliceAligned<T>(input.shape(), begin[0], end[0])) {
      CHECK_GE(input.dims(), 1);  // Otherwise, is_identity should be true.
      VLOG(1) << "Strided slice dim 0: " << input.shape().DebugString();
      Tensor tmp;
      CHECK(tmp.CopyFrom(input.Slice(begin[0], end[0]), final_shape));
      context->set_output(0, tmp);
      return;
    }
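
    // Neither fast path above copies any data: Tensor::CopyFrom (and
    // Tensor::Slice) return a tensor that shares the input's buffer under a
    // new shape. The general path below allocates a fresh output and
    // dispatches on the processing rank.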
    Tensor* result = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, final_shape, &result));
    const int input_dims = input.dims();
    const int processing_dims = processing_shape.dims();

    if (processing_shape.num_elements() > 0) {
      // Optimization #3, slice has stride 1 in all dimensions
      // Optimization #3A, slice has only two dimensions
      // TODO(aselle): Here we are restricting to processing_shape and
      // final_shape being 2D. This isn't strictly necessary, but I don't
      // want to blow up code gen size, because to shape<> you need static
      // NDIM and T
      if (is_simple_slice && std::is_same<Device, CPUDevice>::value &&
          input_dims == 2 && processing_shape.dims() == 2 &&
          final_shape.dims() == 2) {
        MemCpyFunctor<T> functor;
        if (functor.Copy(input, begin, end, result)) {
          return;
        }
      }

#define HANDLE_DIM(NDIM)                                                     \
  if (processing_dims == NDIM) {                                             \
    HandleStridedSliceCase<Device, T, NDIM>(context, begin, end, strides,    \
                                            processing_shape,                \
                                            is_simple_slice, result);        \
    return;                                                                  \
  }

      HANDLE_DIM(1);
      HANDLE_DIM(2);
      HANDLE_DIM(3);
      HANDLE_DIM(4);
      HANDLE_DIM(5);
      HANDLE_DIM(6);
      HANDLE_DIM(7);

#undef HANDLE_DIM

      OP_REQUIRES(
          context, false,
          errors::Unimplemented("Unhandled input dimensions ", input_dims));
    }
  }

 private:
  int32 begin_mask, end_mask;
  int32 ellipsis_mask, new_axis_mask, shrink_axis_mask;
};

template <typename Device, typename T>
class StridedSliceGradOp : public OpKernel {
 public:
  explicit StridedSliceGradOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("begin_mask", &begin_mask));
    OP_REQUIRES_OK(context, context->GetAttr("end_mask", &end_mask));
    OP_REQUIRES_OK(context, context->GetAttr("ellipsis_mask", &ellipsis_mask));
    OP_REQUIRES_OK(context, context->GetAttr("new_axis_mask", &new_axis_mask));
    OP_REQUIRES_OK(context,
                   context->GetAttr("shrink_axis_mask", &shrink_axis_mask));
  }

  void Compute(OpKernelContext* context) override {
    TensorShape processing_shape, final_shape;
    bool is_identity = true;
    bool slice_dim0 = true;
    bool is_simple_slice = true;
    gtl::InlinedVector<int64, 4> begin;
    gtl::InlinedVector<int64, 4> end;
    gtl::InlinedVector<int64, 4> strides;

    TensorShape input_shape;
    const Tensor& input_shape_tensor = context->input(0);
    OP_REQUIRES(
        context, input_shape_tensor.dims() == 1,
        errors::InvalidArgument("shape must be 1-D, got shape.shape = ",
                                input_shape_tensor.shape().DebugString()));
    if (input_shape_tensor.dtype() == DT_INT32) {
      OP_REQUIRES_OK(
          context, TensorShapeUtils::MakeShape(input_shape_tensor.vec<int32>(),
                                               &input_shape));
    } else if (input_shape_tensor.dtype() == DT_INT64) {
      OP_REQUIRES_OK(
          context, TensorShapeUtils::MakeShape(input_shape_tensor.vec<int64>(),
                                               &input_shape));
    } else {
      LOG(FATAL) << "shape must have type int32 or int64.";
    }

    OP_REQUIRES_OK(
        context,
        ValidateStridedSliceOp(
            &context->input(1), &context->input(2), context->input(3),
            input_shape, begin_mask, end_mask, ellipsis_mask, new_axis_mask,
            shrink_axis_mask, &processing_shape, &final_shape, &is_identity,
            &is_simple_slice, &slice_dim0, &begin, &end, &strides));

    // Check to make sure dy is consistent with the original slice
    TensorShape dy_shape = context->input(4).shape();
    OP_REQUIRES(
        context, final_shape == dy_shape,
        errors::InvalidArgument("shape of dy was ", dy_shape.DebugString(),
                                " instead of ", final_shape.DebugString()));

    if (!context->status().ok()) return;

    // const int input_dims = input.dims();
    const int processing_dims = processing_shape.dims();
    Tensor* result = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_shape, &result));

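    // Rank-0 processing shape: the slice produced a scalar, so dy itself is
    // forwarded as the gradient without further processing.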
    if (processing_shape.dims() == 0) {
      auto in = context->input(4);
      CHECK(result->CopyFrom(in, processing_shape));
      return;
    }

#define HANDLE_DIM(NDIM)                                                      \
  if (processing_dims == NDIM) {                                              \
    HandleStridedSliceGradCase<Device, T, NDIM>(context, begin, end, strides, \
                                                processing_shape,             \
                                                is_simple_slice, result);     \
    return;                                                                   \
  }

    HANDLE_DIM(1);
    HANDLE_DIM(2);
    HANDLE_DIM(3);
    HANDLE_DIM(4);
    HANDLE_DIM(5);
    HANDLE_DIM(6);
    HANDLE_DIM(7);

#undef HANDLE_DIM
  }

 private:
  int32 begin_mask, end_mask;
  int32 ellipsis_mask, new_axis_mask, shrink_axis_mask;
};

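// StridedSliceAssignOp assigns a value into a strided slice of an existing
// variable, in place; roughly the kernel behind Python-level assignments of
// the form `ref[begin:end:stride] = value` (illustrative notation, not from
// this file). The l-value may arrive either as a ref-typed input or as a
// DT_RESOURCE handle; the sliced l-value shape must currently match the
// r-value shape exactly (no broadcasting yet, see the TODO below).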
template <typename Device, typename T>
class StridedSliceAssignOp : public OpKernel {
 public:
  explicit StridedSliceAssignOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("begin_mask", &begin_mask));
    OP_REQUIRES_OK(context, context->GetAttr("end_mask", &end_mask));
    OP_REQUIRES_OK(context, context->GetAttr("ellipsis_mask", &ellipsis_mask));
    OP_REQUIRES_OK(context, context->GetAttr("new_axis_mask", &new_axis_mask));
    OP_REQUIRES_OK(context,
                   context->GetAttr("shrink_axis_mask", &shrink_axis_mask));
  }

  void Compute(OpKernelContext* context) override {
    TensorShape processing_shape, final_shape;
    bool is_identity = true;
    bool slice_dim0 = true;
    bool is_simple_slice = true;
    gtl::InlinedVector<int64, 4> begin;
    gtl::InlinedVector<int64, 4> end;
    gtl::InlinedVector<int64, 4> strides;

    Tensor old_lhs;
    if (context->input_dtype(0) == DT_RESOURCE) {
      Var* v;
      OP_REQUIRES_OK(context,
                     LookupResource(context, HandleFromInput(context, 0), &v));
      old_lhs = *v->tensor();
      OP_REQUIRES(context, old_lhs.dtype() == DataTypeToEnum<T>::value,
                  errors::InvalidArgument(
                      "l-value dtype ", DataTypeString(old_lhs.dtype()),
                      " does not match r-value dtype ",
                      DataTypeString(DataTypeToEnum<T>::value)));
    } else {
      context->forward_ref_input_to_ref_output(0, 0);
      old_lhs = context->mutable_input(0, true);
    }

    OP_REQUIRES_OK(
        context,
        ValidateStridedSliceOp(
            &context->input(1), &context->input(2), context->input(3),
            old_lhs.shape(), begin_mask, end_mask, ellipsis_mask,
            new_axis_mask, shrink_axis_mask, &processing_shape, &final_shape,
            &is_identity, &is_simple_slice, &slice_dim0, &begin, &end,
            &strides));

    if (processing_shape.num_elements()) {
      const Tensor& input = context->input(4);
      TensorShape input_shape = input.shape();
      TensorShape original_shape = old_lhs.shape();
      // TODO(aselle): This check is too strong, we only should need
      // input_shape to be broadcastable to final_shape
      OP_REQUIRES(
          context, final_shape == input_shape,
          errors::Unimplemented(
              "sliced l-value shape ", final_shape.DebugString(),
              " does not match r-value shape ", input_shape.DebugString(),
              ". Automatic broadcasting not ", "yet implemented."));
      const int processing_dims = processing_shape.dims();

      // 0-dimensional case implies the left and right are exactly the same
      // scalar shape

// Handle general dimensions
#define HANDLE_DIM(NDIM)                                                   \
  if (processing_dims == NDIM) {                                           \
    HandleStridedSliceAssignCase<Device, T, NDIM>()(                       \
        context, begin, end, strides, processing_shape, is_simple_slice,   \
        &old_lhs);                                                         \
    return;                                                                \
  }
      HANDLE_DIM(0);
      HANDLE_DIM(1);
      HANDLE_DIM(2);
      HANDLE_DIM(3);
      HANDLE_DIM(4);
      HANDLE_DIM(5);
      HANDLE_DIM(6);
      HANDLE_DIM(7);
#undef HANDLE_DIM

      OP_REQUIRES(context, false,
                  errors::Unimplemented("Unhandled input dimensions ",
                                        processing_dims));
    }
  }

 private:
  int32 begin_mask, end_mask;
  int32 ellipsis_mask, new_axis_mask, shrink_axis_mask;
};

#define REGISTER_STRIDED_SLICE(type)                             \
  REGISTER_KERNEL_BUILDER(Name("StridedSlice")                   \
                              .Device(DEVICE_CPU)                \
                              .TypeConstraint<type>("T")         \
                              .HostMemory("begin")               \
                              .HostMemory("end")                 \
                              .HostMemory("strides"),            \
                          StridedSliceOp<CPUDevice, type>)       \
  REGISTER_KERNEL_BUILDER(Name("StridedSliceGrad")               \
                              .Device(DEVICE_CPU)                \
                              .TypeConstraint<type>("T")         \
                              .HostMemory("shape")               \
                              .HostMemory("begin")               \
                              .HostMemory("end")                 \
                              .HostMemory("strides"),            \
                          StridedSliceGradOp<CPUDevice, type>)   \
  REGISTER_KERNEL_BUILDER(Name("StridedSliceAssign")             \
                              .Device(DEVICE_CPU)                \
                              .TypeConstraint<type>("T")         \
                              .HostMemory("begin")               \
                              .HostMemory("end")                 \
                              .HostMemory("strides"),            \
                          StridedSliceAssignOp<CPUDevice, type>) \
  REGISTER_KERNEL_BUILDER(Name("ResourceStridedSliceAssign")     \
                              .Device(DEVICE_CPU)                \
                              .TypeConstraint<type>("T")         \
                              .HostMemory("ref")                 \
                              .HostMemory("begin")               \
                              .HostMemory("end")                 \
                              .HostMemory("strides"),            \
                          StridedSliceAssignOp<CPUDevice, type>)

TF_CALL_ALL_TYPES(REGISTER_STRIDED_SLICE);

#undef REGISTER_STRIDED_SLICE

#if GOOGLE_CUDA

#define REGISTER_GPU(type)                                       \
  REGISTER_KERNEL_BUILDER(Name("StridedSlice")                   \
                              .Device(DEVICE_GPU)                \
                              .TypeConstraint<type>("T")         \
                              .HostMemory("begin")               \
                              .HostMemory("end")                 \
                              .HostMemory("strides"),            \
                          StridedSliceOp<GPUDevice, type>)       \
  REGISTER_KERNEL_BUILDER(Name("StridedSliceGrad")               \
                              .Device(DEVICE_GPU)                \
                              .TypeConstraint<type>("T")         \
                              .HostMemory("shape")               \
                              .HostMemory("begin")               \
                              .HostMemory("end")                 \
                              .HostMemory("strides"),            \
                          StridedSliceGradOp<GPUDevice, type>)   \
  REGISTER_KERNEL_BUILDER(Name("StridedSliceAssign")             \
                              .Device(DEVICE_GPU)                \
                              .TypeConstraint<type>("T")         \
                              .HostMemory("begin")               \
                              .HostMemory("end")                 \
                              .HostMemory("strides"),            \
                          StridedSliceAssignOp<GPUDevice, type>) \
  REGISTER_KERNEL_BUILDER(Name("ResourceStridedSliceAssign")     \
                              .Device(DEVICE_GPU)                \
                              .TypeConstraint<type>("T")         \
                              .HostMemory("ref")                 \
                              .HostMemory("begin")               \
                              .HostMemory("end")                 \
                              .HostMemory("strides"),            \
                          StridedSliceAssignOp<GPUDevice, type>)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
TF_CALL_complex64(REGISTER_GPU);
TF_CALL_complex128(REGISTER_GPU);
TF_CALL_int64(REGISTER_GPU);

// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
REGISTER_KERNEL_BUILDER(Name("StridedSlice")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("input")
                            .HostMemory("begin")
                            .HostMemory("end")
                            .HostMemory("strides")
                            .HostMemory("output"),
                        StridedSliceOp<CPUDevice, int32>);
REGISTER_KERNEL_BUILDER(Name("StridedSliceGrad")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("shape")
                            .HostMemory("begin")
                            .HostMemory("end")
                            .HostMemory("strides")
                            .HostMemory("dy")
                            .HostMemory("output"),
                        StridedSliceGradOp<CPUDevice, int32>);
REGISTER_KERNEL_BUILDER(Name("StridedSliceAssign")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("ref")
                            .HostMemory("begin")
                            .HostMemory("end")
                            .HostMemory("strides"),
                        StridedSliceAssignOp<CPUDevice, int32>)
REGISTER_KERNEL_BUILDER(Name("ResourceStridedSliceAssign")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("ref")
                            .HostMemory("begin")
                            .HostMemory("end")
                            .HostMemory("strides"),
                        StridedSliceAssignOp<CPUDevice, int32>)
#undef REGISTER_GPU

#endif  // GOOGLE_CUDA

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL(type)                                       \
  REGISTER_KERNEL_BUILDER(Name("StridedSlice")                    \
                              .Device(DEVICE_SYCL)                \
                              .TypeConstraint<type>("T")          \
                              .HostMemory("begin")                \
                              .HostMemory("end")                  \
                              .HostMemory("strides"),             \
                          StridedSliceOp<SYCLDevice, type>)       \
  REGISTER_KERNEL_BUILDER(Name("StridedSliceGrad")                \
                              .Device(DEVICE_SYCL)                \
                              .TypeConstraint<type>("T")          \
                              .HostMemory("shape")                \
                              .HostMemory("begin")                \
                              .HostMemory("end")                  \
                              .HostMemory("strides"),             \
                          StridedSliceGradOp<SYCLDevice, type>)   \
  REGISTER_KERNEL_BUILDER(Name("StridedSliceAssign")              \
                              .Device(DEVICE_SYCL)                \
                              .TypeConstraint<type>("T")          \
                              .HostMemory("begin")                \
                              .HostMemory("end")                  \
                              .HostMemory("strides"),             \
                          StridedSliceAssignOp<SYCLDevice, type>) \
  REGISTER_KERNEL_BUILDER(Name("ResourceStridedSliceAssign")      \
                              .Device(DEVICE_SYCL)                \
                              .TypeConstraint<type>("T")          \
                              .HostMemory("ref")                  \
                              .HostMemory("begin")                \
                              .HostMemory("end")                  \
                              .HostMemory("strides"),             \
                          StridedSliceAssignOp<SYCLDevice, type>)

TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL);

REGISTER_KERNEL_BUILDER(Name("StridedSlice")
                            .Device(DEVICE_SYCL)
                            .TypeConstraint<int32>("T")
                            .HostMemory("input")
                            .HostMemory("begin")
                            .HostMemory("end")
                            .HostMemory("strides")
                            .HostMemory("output"),
                        StridedSliceOp<CPUDevice, int32>);
REGISTER_KERNEL_BUILDER(Name("StridedSliceGrad")
                            .Device(DEVICE_SYCL)
                            .TypeConstraint<int32>("T")
                            .HostMemory("shape")
                            .HostMemory("begin")
                            .HostMemory("end")
                            .HostMemory("strides")
                            .HostMemory("dy")
                            .HostMemory("output"),
                        StridedSliceGradOp<CPUDevice, int32>);
REGISTER_KERNEL_BUILDER(Name("StridedSliceAssign")
                            .Device(DEVICE_SYCL)
                            .TypeConstraint<int32>("T")
                            .HostMemory("ref")
                            .HostMemory("begin")
                            .HostMemory("end")
                            .HostMemory("strides"),
                        StridedSliceAssignOp<CPUDevice, int32>)
REGISTER_KERNEL_BUILDER(Name("ResourceStridedSliceAssign")
                            .Device(DEVICE_SYCL)
                            .TypeConstraint<int32>("T")
                            .HostMemory("ref")
                            .HostMemory("begin")
                            .HostMemory("end")
                            .HostMemory("strides"),
                        StridedSliceAssignOp<CPUDevice, int32>)
#undef REGISTER_SYCL
#endif  // TENSORFLOW_USE_SYCL
}  // namespace tensorflow