/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "tensorflow/core/kernels/strided_slice_op.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/slice_op.h"
#include "tensorflow/core/kernels/strided_slice_op_impl.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/kernels/variable_ops.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/prefetch.h"
#include "tensorflow/core/util/strided_slice_op.h"

namespace tensorflow {
namespace {

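// MemCpyFunctor implements a fast path for 2-D slices: when T can be copied
// with memcpy, each output row is copied from the input with a single memcpy
// call. Copy() returns false for other types so callers fall back to the
// generic (Eigen) path.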
template <typename T>
struct MemCpyFunctor {
  // Returns true if the copy was made with memcpy, false otherwise.
  bool Copy(const Tensor& input, const gtl::InlinedVector<int64, 4>& begin,
            const gtl::InlinedVector<int64, 4>& end, Tensor* result) {
    if (DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
      auto in = input.tensor<T, 2>();
      auto output = result->tensor<T, 2>();
      // TODO(agarwal): Consider multi-threading if size[0] is large
      for (int row_in = begin[0], row_out = 0; row_in < end[0];
           ++row_in, ++row_out) {
        if (row_in + 1 < end[0]) {
          port::prefetch<port::PREFETCH_HINT_T0>(&output(row_out + 1, 0));
          port::prefetch<port::PREFETCH_HINT_T0>(&in(row_in + 1, begin[1]));
        }
        memcpy(&output(row_out, 0), &in(row_in, begin[1]),
               (end[1] - begin[1]) * sizeof(T));
      }
      return true;
    }
    return false;
  }
};

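// ResourceHandle is not memcpy-friendly, so this specialization never takes
// the fast path.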
template <>
struct MemCpyFunctor<ResourceHandle> {
  bool Copy(const Tensor& input, const gtl::InlinedVector<int64, 4>& begin,
            const gtl::InlinedVector<int64, 4>& end, Tensor* result) {
    return false;
  }
};

}  // namespace

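// StridedSliceOp computes a strided slice of its input. Compute() first
// canonicalizes begin/end/strides with ValidateStridedSliceOp, then takes one
// of three fast paths (identity, contiguous dim-0 slice, or 2-D memcpy)
// before falling back to the rank-specialized implementation below.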
template <typename Device, typename T>
class StridedSliceOp : public OpKernel {
 public:
  explicit StridedSliceOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("begin_mask", &begin_mask));
    OP_REQUIRES_OK(context, context->GetAttr("end_mask", &end_mask));
    OP_REQUIRES_OK(context, context->GetAttr("ellipsis_mask", &ellipsis_mask));
    OP_REQUIRES_OK(context, context->GetAttr("new_axis_mask", &new_axis_mask));
    OP_REQUIRES_OK(context,
                   context->GetAttr("shrink_axis_mask", &shrink_axis_mask));
  }

  void Compute(OpKernelContext* context) override {
    TensorShape processing_shape, final_shape;
    bool is_identity = true;
    bool slice_dim0 = true;
    bool is_simple_slice = true;
    gtl::InlinedVector<int64, 4> begin;
    gtl::InlinedVector<int64, 4> end;
    gtl::InlinedVector<int64, 4> strides;

    OP_REQUIRES_OK(
        context, ValidateStridedSliceOp(
                     &context->input(1), &context->input(2), context->input(3),
                     context->input(0).shape(), begin_mask, end_mask,
                     ellipsis_mask, new_axis_mask, shrink_axis_mask,
                     &processing_shape, &final_shape, &is_identity,
                     &is_simple_slice, &slice_dim0, &begin, &end, &strides));
    const Tensor& input = context->input(0);

    // Optimization #1, slice is a no-op plus reshape
    if (is_identity) {
      VLOG(1) << "Strided slice identity ";
      Tensor tmp;
      CHECK(tmp.CopyFrom(input, final_shape));
      context->set_output(0, tmp);
      return;
    }

    // Optimization #2, slice is memory contiguous (only occurs in dim 0)
    if (slice_dim0 && IsDim0SliceAligned<T>(input.shape(), begin[0], end[0])) {
      CHECK_GE(input.dims(), 1);  // Otherwise, is_identity should be true.
      VLOG(1) << "Strided slice dim 0: " << input.shape().DebugString();
      Tensor tmp;
      CHECK(tmp.CopyFrom(input.Slice(begin[0], end[0]), final_shape));
      context->set_output(0, tmp);
      return;
    }

    Tensor* result = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, final_shape, &result));
    const int input_dims = input.dims();
    const int processing_dims = processing_shape.dims();

    if (processing_shape.num_elements() > 0) {
      // Optimization #3, slice has stride 1 in all dimensions
      // Optimization #3A, slice has only two dimensions
      // TODO(aselle): Here we are restricting to processing_shape and
      // final_shape being 2D. This isn't strictly necessary, but it avoids
      // blowing up code-gen size, because shape<> needs a static NDIM and T.
      if (is_simple_slice && std::is_same<Device, CPUDevice>::value &&
          input_dims == 2 && processing_shape.dims() == 2 &&
          final_shape.dims() == 2) {
        MemCpyFunctor<T> functor;
        if (functor.Copy(input, begin, end, result)) {
          return;
        }
      }

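// Dispatch to an implementation instantiated for the compile-time rank NDIM,
// since the underlying slicing kernels need the rank as a template parameter.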
#define HANDLE_DIM(NDIM)                                                       \
  if (processing_dims == NDIM) {                                               \
    HandleStridedSliceCase<Device, T, NDIM>(context, begin, end, strides,      \
                                            processing_shape, is_simple_slice, \
                                            result);                           \
    return;                                                                    \
  }

      HANDLE_DIM(1);
      HANDLE_DIM(2);
      HANDLE_DIM(3);
      HANDLE_DIM(4);
      HANDLE_DIM(5);
      HANDLE_DIM(6);
      HANDLE_DIM(7);

#undef HANDLE_DIM

      OP_REQUIRES(
          context, false,
          errors::Unimplemented("Unhandled input dimensions ", input_dims));
    }
  }

 private:
  int32 begin_mask, end_mask;
  int32 ellipsis_mask, new_axis_mask, shrink_axis_mask;
};

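// StridedSliceGradOp computes the gradient of StridedSlice: the incoming
// gradient `dy` (input 4) is scattered back into a tensor with the original
// input shape (input 0), with zeros everywhere the slice did not select.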
template <typename Device, typename T>
class StridedSliceGradOp : public OpKernel {
 public:
  explicit StridedSliceGradOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("begin_mask", &begin_mask));
    OP_REQUIRES_OK(context, context->GetAttr("end_mask", &end_mask));
    OP_REQUIRES_OK(context, context->GetAttr("ellipsis_mask", &ellipsis_mask));
    OP_REQUIRES_OK(context, context->GetAttr("new_axis_mask", &new_axis_mask));
    OP_REQUIRES_OK(context,
                   context->GetAttr("shrink_axis_mask", &shrink_axis_mask));
  }

  void Compute(OpKernelContext* context) override {
    TensorShape processing_shape, final_shape;
    bool is_identity = true;
    bool slice_dim0 = true;
    bool is_simple_slice = true;
    gtl::InlinedVector<int64, 4> begin;
    gtl::InlinedVector<int64, 4> end;
    gtl::InlinedVector<int64, 4> strides;

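    // Input 0 holds the shape of the original (forward) input as a 1-D
    // int32 or int64 vector; rebuild a TensorShape from it.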
    TensorShape input_shape;
    const Tensor& input_shape_tensor = context->input(0);
    OP_REQUIRES(
        context, input_shape_tensor.dims() == 1,
        errors::InvalidArgument("shape must be 1-D, got shape.shape = ",
                                input_shape_tensor.shape().DebugString()));
    if (input_shape_tensor.dtype() == DT_INT32) {
      OP_REQUIRES_OK(
          context, TensorShapeUtils::MakeShape(input_shape_tensor.vec<int32>(),
                                               &input_shape));
    } else if (input_shape_tensor.dtype() == DT_INT64) {
      OP_REQUIRES_OK(
          context, TensorShapeUtils::MakeShape(input_shape_tensor.vec<int64>(),
                                               &input_shape));
    } else {
      LOG(FATAL) << "shape must have type int32 or int64.";
    }

    OP_REQUIRES_OK(
        context,
        ValidateStridedSliceOp(
            &context->input(1), &context->input(2), context->input(3),
            input_shape, begin_mask, end_mask, ellipsis_mask, new_axis_mask,
            shrink_axis_mask, &processing_shape, &final_shape, &is_identity,
            &is_simple_slice, &slice_dim0, &begin, &end, &strides));

    // Check to make sure dy is consistent with the original slice
    TensorShape dy_shape = context->input(4).shape();
    OP_REQUIRES(
        context, final_shape == dy_shape,
        errors::InvalidArgument("shape of dy was ", dy_shape.DebugString(),
                                " instead of ", final_shape.DebugString()));

    if (!context->status().ok()) return;

    // const int input_dims = input.dims();
    const int processing_dims = processing_shape.dims();
    Tensor* result = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_shape, &result));

    if (processing_shape.dims() == 0) {
      auto in = context->input(4);
      CHECK(result->CopyFrom(in, processing_shape));
      return;
    }

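// As in the forward op, dispatch on the compile-time rank NDIM of
// processing_shape so the gradient kernel can be instantiated statically.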
#define HANDLE_DIM(NDIM)                                                      \
  if (processing_dims == NDIM) {                                              \
    HandleStridedSliceGradCase<Device, T, NDIM>(context, begin, end, strides, \
                                                processing_shape,             \
                                                is_simple_slice, result);     \
    return;                                                                   \
  }

    HANDLE_DIM(1);
    HANDLE_DIM(2);
    HANDLE_DIM(3);
    HANDLE_DIM(4);
    HANDLE_DIM(5);
    HANDLE_DIM(6);
    HANDLE_DIM(7);

#undef HANDLE_DIM
  }

 private:
  int32 begin_mask, end_mask;
  int32 ellipsis_mask, new_axis_mask, shrink_axis_mask;
};

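// StridedSliceAssignOp assigns an r-value tensor into the selected slice of
// an l-value in place. The l-value is either a resource variable (input 0 is
// a DT_RESOURCE handle) or a ref input, which is forwarded to output 0.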
template <typename Device, typename T>
class StridedSliceAssignOp : public OpKernel {
 public:
  explicit StridedSliceAssignOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("begin_mask", &begin_mask));
    OP_REQUIRES_OK(context, context->GetAttr("end_mask", &end_mask));
    OP_REQUIRES_OK(context, context->GetAttr("ellipsis_mask", &ellipsis_mask));
    OP_REQUIRES_OK(context, context->GetAttr("new_axis_mask", &new_axis_mask));
    OP_REQUIRES_OK(context,
                   context->GetAttr("shrink_axis_mask", &shrink_axis_mask));
  }

  void Compute(OpKernelContext* context) override {
    TensorShape processing_shape, final_shape;
    bool is_identity = true;
    bool slice_dim0 = true;
    bool is_simple_slice = true;
    gtl::InlinedVector<int64, 4> begin;
    gtl::InlinedVector<int64, 4> end;
    gtl::InlinedVector<int64, 4> strides;

    Tensor old_lhs;
    if (context->input_dtype(0) == DT_RESOURCE) {
      Var* v;
      OP_REQUIRES_OK(context,
                     LookupResource(context, HandleFromInput(context, 0), &v));
      old_lhs = *v->tensor();
      OP_REQUIRES(context, old_lhs.dtype() == DataTypeToEnum<T>::value,
                  errors::InvalidArgument(
                      "l-value dtype ", DataTypeString(old_lhs.dtype()),
                      " does not match r-value dtype ",
                      DataTypeString(DataTypeToEnum<T>::value)));
    } else {
      context->forward_ref_input_to_ref_output(0, 0);
      old_lhs = context->mutable_input(0, true);
    }

    OP_REQUIRES_OK(
        context,
        ValidateStridedSliceOp(
            &context->input(1), &context->input(2), context->input(3),
            old_lhs.shape(), begin_mask, end_mask, ellipsis_mask, new_axis_mask,
            shrink_axis_mask, &processing_shape, &final_shape, &is_identity,
            &is_simple_slice, &slice_dim0, &begin, &end, &strides));

    if (processing_shape.num_elements()) {
      const Tensor& input = context->input(4);
      TensorShape input_shape = input.shape();
      TensorShape original_shape = old_lhs.shape();
      // TODO(aselle): This check is too strong; we should only need
      // input_shape to be broadcastable to final_shape.
      OP_REQUIRES(
          context, final_shape == input_shape,
          errors::Unimplemented(
              "sliced l-value shape ", final_shape.DebugString(),
              " does not match r-value shape ", input_shape.DebugString(),
              ". Automatic broadcasting not yet implemented."));
      const int processing_dims = processing_shape.dims();

      // 0-dimensional case implies the left and right are exactly the same
      // scalar shape

// Handle general dimensions
#define HANDLE_DIM(NDIM)                                                 \
  if (processing_dims == NDIM) {                                         \
    HandleStridedSliceAssignCase<Device, T, NDIM>()(                     \
        context, begin, end, strides, processing_shape, is_simple_slice, \
        &old_lhs);                                                       \
    return;                                                              \
  }
      HANDLE_DIM(0);
      HANDLE_DIM(1);
      HANDLE_DIM(2);
      HANDLE_DIM(3);
      HANDLE_DIM(4);
      HANDLE_DIM(5);
      HANDLE_DIM(6);
      HANDLE_DIM(7);
#undef HANDLE_DIM

      OP_REQUIRES(context, false,
                  errors::Unimplemented("Unhandled input dimensions ",
                                        processing_dims));
    }
  }

 private:
  int32 begin_mask, end_mask;
  int32 ellipsis_mask, new_axis_mask, shrink_axis_mask;
};

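// Kernel registrations. The begin/end/strides (and, for the grad op, shape)
// inputs are always pinned to host memory, since the kernel reads them on the
// CPU to compute the slice bounds.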
#define REGISTER_STRIDED_SLICE(type)                             \
  REGISTER_KERNEL_BUILDER(Name("StridedSlice")                   \
                              .Device(DEVICE_CPU)                \
                              .TypeConstraint<type>("T")         \
                              .HostMemory("begin")               \
                              .HostMemory("end")                 \
                              .HostMemory("strides"),            \
                          StridedSliceOp<CPUDevice, type>)       \
  REGISTER_KERNEL_BUILDER(Name("StridedSliceGrad")               \
                              .Device(DEVICE_CPU)                \
                              .TypeConstraint<type>("T")         \
                              .HostMemory("shape")               \
                              .HostMemory("begin")               \
                              .HostMemory("end")                 \
                              .HostMemory("strides"),            \
                          StridedSliceGradOp<CPUDevice, type>)   \
  REGISTER_KERNEL_BUILDER(Name("StridedSliceAssign")             \
                              .Device(DEVICE_CPU)                \
                              .TypeConstraint<type>("T")         \
                              .HostMemory("begin")               \
                              .HostMemory("end")                 \
                              .HostMemory("strides"),            \
                          StridedSliceAssignOp<CPUDevice, type>) \
  REGISTER_KERNEL_BUILDER(Name("ResourceStridedSliceAssign")     \
                              .Device(DEVICE_CPU)                \
                              .TypeConstraint<type>("T")         \
                              .HostMemory("ref")                 \
                              .HostMemory("begin")               \
                              .HostMemory("end")                 \
                              .HostMemory("strides"),            \
                          StridedSliceAssignOp<CPUDevice, type>)

TF_CALL_ALL_TYPES(REGISTER_STRIDED_SLICE);

#undef REGISTER_STRIDED_SLICE

#if GOOGLE_CUDA

#define REGISTER_GPU(type)                                       \
  REGISTER_KERNEL_BUILDER(Name("StridedSlice")                   \
                              .Device(DEVICE_GPU)                \
                              .TypeConstraint<type>("T")         \
                              .HostMemory("begin")               \
                              .HostMemory("end")                 \
                              .HostMemory("strides"),            \
                          StridedSliceOp<GPUDevice, type>)       \
  REGISTER_KERNEL_BUILDER(Name("StridedSliceGrad")               \
                              .Device(DEVICE_GPU)                \
                              .TypeConstraint<type>("T")         \
                              .HostMemory("shape")               \
                              .HostMemory("begin")               \
                              .HostMemory("end")                 \
                              .HostMemory("strides"),            \
                          StridedSliceGradOp<GPUDevice, type>)   \
  REGISTER_KERNEL_BUILDER(Name("StridedSliceAssign")             \
                              .Device(DEVICE_GPU)                \
                              .TypeConstraint<type>("T")         \
                              .HostMemory("begin")               \
                              .HostMemory("end")                 \
                              .HostMemory("strides"),            \
                          StridedSliceAssignOp<GPUDevice, type>) \
  REGISTER_KERNEL_BUILDER(Name("ResourceStridedSliceAssign")     \
                              .Device(DEVICE_GPU)                \
                              .TypeConstraint<type>("T")         \
                              .HostMemory("ref")                 \
                              .HostMemory("begin")               \
                              .HostMemory("end")                 \
                              .HostMemory("strides"),            \
                          StridedSliceAssignOp<GPUDevice, type>)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
TF_CALL_complex64(REGISTER_GPU);
TF_CALL_complex128(REGISTER_GPU);
TF_CALL_int64(REGISTER_GPU);

// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
REGISTER_KERNEL_BUILDER(Name("StridedSlice")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("input")
                            .HostMemory("begin")
                            .HostMemory("end")
                            .HostMemory("strides")
                            .HostMemory("output"),
                        StridedSliceOp<CPUDevice, int32>);
REGISTER_KERNEL_BUILDER(Name("StridedSliceGrad")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("shape")
                            .HostMemory("begin")
                            .HostMemory("end")
                            .HostMemory("strides")
                            .HostMemory("dy")
                            .HostMemory("output"),
                        StridedSliceGradOp<CPUDevice, int32>);
REGISTER_KERNEL_BUILDER(Name("StridedSliceAssign")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("ref")
                            .HostMemory("begin")
                            .HostMemory("end")
                            .HostMemory("strides"),
                        StridedSliceAssignOp<CPUDevice, int32>);
REGISTER_KERNEL_BUILDER(Name("ResourceStridedSliceAssign")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("ref")
                            .HostMemory("begin")
                            .HostMemory("end")
                            .HostMemory("strides"),
                        StridedSliceAssignOp<CPUDevice, int32>);
#undef REGISTER_GPU

#endif  // GOOGLE_CUDA

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL(type)                                       \
  REGISTER_KERNEL_BUILDER(Name("StridedSlice")                    \
                              .Device(DEVICE_SYCL)                \
                              .TypeConstraint<type>("T")          \
                              .HostMemory("begin")                \
                              .HostMemory("end")                  \
                              .HostMemory("strides"),             \
                          StridedSliceOp<SYCLDevice, type>)       \
  REGISTER_KERNEL_BUILDER(Name("StridedSliceGrad")                \
                              .Device(DEVICE_SYCL)                \
                              .TypeConstraint<type>("T")          \
                              .HostMemory("shape")                \
                              .HostMemory("begin")                \
                              .HostMemory("end")                  \
                              .HostMemory("strides"),             \
                          StridedSliceGradOp<SYCLDevice, type>)   \
  REGISTER_KERNEL_BUILDER(Name("StridedSliceAssign")              \
                              .Device(DEVICE_SYCL)                \
                              .TypeConstraint<type>("T")          \
                              .HostMemory("begin")                \
                              .HostMemory("end")                  \
                              .HostMemory("strides"),             \
                          StridedSliceAssignOp<SYCLDevice, type>) \
  REGISTER_KERNEL_BUILDER(Name("ResourceStridedSliceAssign")      \
                              .Device(DEVICE_SYCL)                \
                              .TypeConstraint<type>("T")          \
                              .HostMemory("ref")                  \
                              .HostMemory("begin")                \
                              .HostMemory("end")                  \
                              .HostMemory("strides"),             \
                          StridedSliceAssignOp<SYCLDevice, type>)

TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL);

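// As with CUDA, int32 needs special registrations: all int32 inputs and
// outputs are kept in host memory and the CPU kernel implementation is used.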
REGISTER_KERNEL_BUILDER(Name("StridedSlice")
                            .Device(DEVICE_SYCL)
                            .TypeConstraint<int32>("T")
                            .HostMemory("input")
                            .HostMemory("begin")
                            .HostMemory("end")
                            .HostMemory("strides")
                            .HostMemory("output"),
                        StridedSliceOp<CPUDevice, int32>);
REGISTER_KERNEL_BUILDER(Name("StridedSliceGrad")
                            .Device(DEVICE_SYCL)
                            .TypeConstraint<int32>("T")
                            .HostMemory("shape")
                            .HostMemory("begin")
                            .HostMemory("end")
                            .HostMemory("strides")
                            .HostMemory("dy")
                            .HostMemory("output"),
                        StridedSliceGradOp<CPUDevice, int32>);
REGISTER_KERNEL_BUILDER(Name("StridedSliceAssign")
                            .Device(DEVICE_SYCL)
                            .TypeConstraint<int32>("T")
                            .HostMemory("ref")
                            .HostMemory("begin")
                            .HostMemory("end")
                            .HostMemory("strides"),
                        StridedSliceAssignOp<CPUDevice, int32>);
REGISTER_KERNEL_BUILDER(Name("ResourceStridedSliceAssign")
                            .Device(DEVICE_SYCL)
                            .TypeConstraint<int32>("T")
                            .HostMemory("ref")
                            .HostMemory("begin")
                            .HostMemory("end")
                            .HostMemory("strides"),
                        StridedSliceAssignOp<CPUDevice, int32>);
#undef REGISTER_SYCL
#endif  // TENSORFLOW_USE_SYCL
}  // namespace tensorflow