/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.
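//
// In brief, Slice extracts a contiguous sub-tensor: for each dimension i the
// output holds elements [begin[i], begin[i] + size[i]) of the input, and a
// size[i] of -1 means "everything from begin[i] to the end of dimension i".
// Illustrative example (not from this file): slicing a [3, 4] matrix with
// begin = {1, 0} and size = {2, -1} yields its last two rows.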

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "tensorflow/core/kernels/slice_op.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/prefetch.h"

namespace tensorflow {

namespace {

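// Copies a 1-D int32 or int64 tensor into an InlinedVector<int64, 4>, so the
// rest of the kernel can treat the "begin" and "size" inputs uniformly.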
gtl::InlinedVector<int64, 4> IntTensorToInt64Vec(const Tensor& tensor) {
  gtl::InlinedVector<int64, 4> out;
  if (tensor.dtype() == DT_INT32) {
    for (int64 i = 0; i < tensor.NumElements(); ++i) {
      out.push_back(tensor.flat<int32>()(i));
    }
  } else if (tensor.dtype() == DT_INT64) {
    for (int64 i = 0; i < tensor.NumElements(); ++i) {
      out.push_back(tensor.flat<int64>()(i));
    }
  } else {
    LOG(FATAL) << "begin must be either int32 or int64";
  }
  return out;
}

}  // namespace

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL

// Shared validation code that does not depend on the element type T. Keeping
// it out of the templated kernel avoids duplicating it for every T (float,
// double, int32, etc.) and reduces code size.
static void SharedValidation(OpKernelContext* context,
                             TensorShape* output_shape, bool* is_identity,
                             bool* slice_dim0,
                             gtl::InlinedVector<int64, 4>* begin,
                             gtl::InlinedVector<int64, 4>* size) {
  const Tensor& input = context->input(0);
  const Tensor& begin_tensor = context->input(1);
  const Tensor& size_tensor = context->input(2);

  OP_REQUIRES(
      context,
      context->op_kernel().IsLegacyVector(begin_tensor.shape()) &&
          context->op_kernel().IsLegacyVector(size_tensor.shape()) &&
          begin_tensor.NumElements() == input.dims() &&
          size_tensor.NumElements() == input.dims(),
      errors::InvalidArgument(
          "Expected begin and size arguments to be 1-D tensors of size ",
          input.dims(), ", but got shapes ", begin_tensor.shape().DebugString(),
          " and ", size_tensor.shape().DebugString(), " instead."));

  const int input_dims = input.dims();
  *begin = IntTensorToInt64Vec(begin_tensor);
  *size = IntTensorToInt64Vec(size_tensor);
  for (int i = 0; i < input_dims; ++i) {
    if ((*size)[i] == -1) {
      // A size[i] of -1 means "all elements from begin[i] to dim_size(i)".
      (*size)[i] = input.dim_size(i) - (*begin)[i];
    }
  }

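  // is_identity: the slice covers the whole input, so the output can simply
  // alias the input tensor.
  // slice_dim0: only dimension 0 is actually sliced, so the output can share
  // the input buffer via Tensor::Slice().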
  *is_identity = true;
  *slice_dim0 = true;
  for (int i = 0; i < input_dims; ++i) {
    int64 b = (*begin)[i];
    int64 s = (*size)[i];
    if (input.dim_size(i) == 0) {
      OP_REQUIRES(
          context, b == 0 && s == 0,
          errors::InvalidArgument("Expected begin[", i, "] == 0 (got ", b,
                                  ") and size[", i, "] == 0 ", "(got ", s,
                                  ") when ", "input.dim_size(", i, ") == 0"));
    } else {
      OP_REQUIRES(context, 0 <= b && b <= input.dim_size(i),
                  errors::InvalidArgument("Expected begin[", i, "] in [0, ",
                                          input.dim_size(i), "], but got ", b));
      OP_REQUIRES(
          context, 0 <= s && b + s <= input.dim_size(i),
          errors::InvalidArgument("Expected size[", i, "] in [0, ",
                                  input.dim_size(i) - b, "], but ", "got ", s));
    }
    output_shape->AddDim(s);
    const bool take_all = (b == 0) && (s == input.dim_size(i));
    (*is_identity) &= take_all;
    (*slice_dim0) &= (i == 0) || take_all;
  }
}

// Common-case handling extracted from SliceOp::Compute so that MklSliceOp can
// reuse this generic code.
template <typename T>
static void SharedSliceCommonCases(OpKernelContext* context,
                                   TensorShape* output_shape,
                                   gtl::InlinedVector<int64, 4>* begin,
                                   gtl::InlinedVector<int64, 4>* size,
                                   Tensor** result, bool* done) {
  bool is_identity = true;
  bool slice_dim0 = true;
  *done = false;

  SharedValidation(context, output_shape, &is_identity, &slice_dim0, begin,
                   size);
  if (!context->status().ok()) return;
  const Tensor& input = context->input(0);
  if (is_identity) {
    VLOG(1) << "Slice identity";
    context->set_output(0, input);
    *done = true;
    return;
  }

  if (slice_dim0 &&
      IsDim0SliceAligned<T>(input.shape(), (*begin)[0], (*size)[0])) {
    VLOG(1) << "Slice dim 0: " << input.shape().DebugString();
    CHECK_GE(input.dims(), 1);  // Otherwise, is_identity should be true.
    context->set_output(0, input.Slice((*begin)[0], (*begin)[0] + (*size)[0]));
    *done = true;
    return;
  }

  OP_REQUIRES_OK(context, context->allocate_output(0, *output_shape, result));
}

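// The generic Slice kernel. SharedSliceCommonCases handles the identity and
// dim-0 alias fast paths; the remaining cases copy the requested region into
// a freshly allocated output tensor.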
template <typename Device, typename T>
class SliceOp : public OpKernel {
 public:
  explicit SliceOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    TensorShape output_shape;
    gtl::InlinedVector<int64, 4> begin;
    gtl::InlinedVector<int64, 4> size;
    Tensor* result = nullptr;
    bool done = false;
    SharedSliceCommonCases<T>(context, &output_shape, &begin, &size, &result,
                              &done);
    if (!context->status().ok() || done == true) return;

    const Tensor& input = context->input(0);
    const int input_dims = input.dims();

    if (output_shape.num_elements() > 0) {
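      // Fast path: a rank-2 slice of a memcpy-friendly type on the CPU is
      // copied row by row with memcpy, prefetching the next source and
      // destination rows to hide memory latency.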
      if (std::is_same<Device, CPUDevice>::value && input_dims == 2 &&
          DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
        auto input = context->input(0).tensor<T, 2>();
        auto output = result->tensor<T, 2>();
        // TODO(agarwal): Consider multi-threading this loop for cases where
        // size[0] is very large.
        for (int i = 0; i < size[0]; ++i) {
          const int64 row = begin[0] + i;
          if (i + 1 < size[0]) {
            port::prefetch<port::PREFETCH_HINT_T0>(&output(i + 1, 0));
            port::prefetch<port::PREFETCH_HINT_T0>(&input(row + 1, begin[1]));
          }
          memcpy(&output(i, 0), &input(row, begin[1]), size[1] * sizeof(T));
        }
        return;
      }
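      // Dispatch to a rank-specialized Eigen copy for ranks 1 through 7.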
#define HANDLE_DIM(NDIM)                            \
  if (input_dims == NDIM) {                         \
    HandleCase<NDIM>(context, begin, size, result); \
    return;                                         \
  }

      HANDLE_DIM(1);
      HANDLE_DIM(2);
      HANDLE_DIM(3);
      HANDLE_DIM(4);
      HANDLE_DIM(5);
      HANDLE_DIM(6);
      HANDLE_DIM(7);

#undef HANDLE_DIM

      OP_REQUIRES(
          context, false,
          errors::Unimplemented("SliceOp : Unhandled input dimensions"));
    }
  }

 private:
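  // Copies the requested region using the rank-NDIM Eigen Slice functor.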
  template <int NDIM>
  void HandleCase(OpKernelContext* context, const gtl::ArraySlice<int64>& begin,
                  const gtl::ArraySlice<int64>& size, Tensor* result) {
    Eigen::DSizes<Eigen::DenseIndex, NDIM> indices;
    Eigen::DSizes<Eigen::DenseIndex, NDIM> sizes;
    for (int i = 0; i < NDIM; ++i) {
      indices[i] = begin[i];
      sizes[i] = size[i];
    }

    functor::Slice<Device, T, NDIM>()(
        context->eigen_device<Device>(), result->tensor<T, NDIM>(),
        context->input(0).tensor<T, NDIM>(), indices, sizes);
  }
};

#ifdef INTEL_MKL
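// A CPU Slice kernel used when TensorFlow is built with Intel MKL support.
// It adds an OpenMP-parallel memcpy path for 4-D tensors that are sliced
// along a single dimension (e.g. the channel dimension of NCHW or NHWC data)
// and otherwise falls back to the same Eigen-based copy as SliceOp.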
template <typename Device, typename T>
class MklSliceOp : public OpKernel {
 public:
  explicit MklSliceOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    TensorShape output_shape;
    gtl::InlinedVector<int64, 4> begin;
    gtl::InlinedVector<int64, 4> size;
    Tensor* result = nullptr;
    bool done = false;
    SharedSliceCommonCases<T>(context, &output_shape, &begin, &size, &result,
                              &done);
    if (!context->status().ok() || done == true) return;

    const Tensor& input = context->input(0);
    const int input_dims = input.dims();

    if (output_shape.num_elements() > 0) {
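      // Same rank-2 memcpy fast path as in SliceOp::Compute above.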
      if (std::is_same<Device, CPUDevice>::value && input_dims == 2 &&
          DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
        auto input = context->input(0).tensor<T, 2>();
        auto output = result->tensor<T, 2>();
        // TODO(agarwal): Consider multi-threading this loop for cases where
        // size[0] is very large.
        for (int i = 0; i < size[0]; ++i) {
          const int64 row = begin[0] + i;
          if (i + 1 < size[0]) {
            port::prefetch<port::PREFETCH_HINT_T0>(&output(i + 1, 0));
            port::prefetch<port::PREFETCH_HINT_T0>(&input(row + 1, begin[1]));
          }
          memcpy(&output(i, 0), &input(row, begin[1]), size[1] * sizeof(T));
        }
        return;
      }
#define HANDLE_DIM(NDIM)                            \
  if (input_dims == NDIM) {                         \
    HandleCase<NDIM>(context, begin, size, result); \
    return;                                         \
  }

      HANDLE_DIM(1);
      HANDLE_DIM(2);
      HANDLE_DIM(3);
      HANDLE_DIM(4);
      HANDLE_DIM(5);
      HANDLE_DIM(6);
      HANDLE_DIM(7);

#undef HANDLE_DIM

      OP_REQUIRES(
          context, false,
          errors::Unimplemented("SliceOp : Unhandled input dimensions"));
    }
  }

 private:
  // Helper for DoesSliceShapeDifferInOnly1D. Returns true iff the slice leaves
  // every dimension other than slice_dim untouched, i.e. begin is 0 and the
  // slice size equals the input size in every dimension except slice_dim.
  // Returns false otherwise.
  bool DoesSliceShapeDifferInOnly1DHelper(const TensorShape& input_shape,
                                          const gtl::ArraySlice<int64>& begin,
                                          const gtl::ArraySlice<int64>& size,
                                          int slice_dim) {
    for (int dim = 0; dim < 4; dim++) {
      if (dim != slice_dim &&
          (begin[dim] != 0 || size[dim] != input_shape.dim_size(dim))) {
        return false;
      }
    }
    return true;
  }

  // Is the 'input' tensor being sliced over exactly one of its 4 dimensions?
  //
  // This check targets slices of a 4-D tensor in NHWC or NCHW format over the
  // channel dimension.
  //
  // If begin is 0 in every dimension except one, and the slice size equals the
  // input size in every dimension except that one, then the tensor is being
  // sliced over a single dimension.
  //
  // Returns true if slicing over a single dimension, and sets *slice_dim to
  // the index of the dimension that satisfies the criterion.
  bool DoesSliceShapeDifferInOnly1D(const TensorShape& input_shape,
                                    const gtl::ArraySlice<int64>& begin,
                                    const gtl::ArraySlice<int64>& size,
                                    int* slice_dim) {
    for (int dim = 0; dim < 4; dim++) {
      if (DoesSliceShapeDifferInOnly1DHelper(input_shape, begin, size, dim)) {
        *slice_dim = dim;
        return true;
      }
    }
    return false;
  }

  template <int NDIM>
  void HandleCase(OpKernelContext* context, const gtl::ArraySlice<int64>& begin,
                  const gtl::ArraySlice<int64>& size, Tensor* result) {
    int slice_dim = -1;
    TensorShape in_shape = context->input(0).shape();
    // Special case: a 4-D tensor slice whose shape differs from the input in
    // only one of the four dimensions. This arises when slicing a 4-D tensor
    // in NHWC or NCHW format over the channel dimension.
    if (NDIM == 4 &&
        DoesSliceShapeDifferInOnly1D(in_shape, begin, size, &slice_dim)) {
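      // Row-major strides, in elements, of the 4-D input and output tensors.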
      size_t in_strides[4] = {
          (size_t)in_shape.dim_size(1) * in_shape.dim_size(2) *
              in_shape.dim_size(3),
          (size_t)in_shape.dim_size(2) * in_shape.dim_size(3),
          (size_t)in_shape.dim_size(3), (size_t)1};

      size_t out_strides[4] = {(size_t)size[1] * size[2] * size[3],
                               (size_t)size[2] * size[3], (size_t)size[3],
                               (size_t)1};

      T* in_buf = const_cast<T*>(
          const_cast<const T*>(context->input(0).flat<T>().data()));
      T* op_buf = result->flat<T>().data();

      if (slice_dim == 1) {
        /* data format = NCHW */

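        // Note: OpenMP disables nested parallelism by default, so the inner
        // '#pragma omp parallel for' loops below run sequentially within each
        // outer thread unless nesting is explicitly enabled.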
#pragma omp parallel for
        for (ssize_t d0 = begin[0]; d0 < begin[0] + size[0]; d0++) {
          T* ip = in_buf + (d0 * in_strides[0]);
          T* op = op_buf + ((d0 - begin[0]) * out_strides[0]);
#pragma omp parallel for
          for (ssize_t d1 = begin[1]; d1 < begin[1] + size[1]; d1++) {
            T* ip1 = ip + (d1 * in_strides[1]);
            T* op1 = op + ((d1 - begin[1]) * out_strides[1]);
            // For NCHW, H and W will be contiguous. So we can copy
            // both with one memcpy.
            memcpy(static_cast<void*>(op1), static_cast<void*>(ip1),
                   sizeof(T) * in_strides[1]);
          }
        }
        return;
      } else if (slice_dim == 3) {
        /* data_format = NHWC */

#pragma omp parallel for
        for (ssize_t d0 = begin[0]; d0 < begin[0] + size[0]; d0++) {
          T* ip = in_buf + (d0 * in_strides[0]);
          T* op = op_buf + ((d0 - begin[0]) * out_strides[0]);
#pragma omp parallel for
          for (ssize_t d1 = begin[1]; d1 < begin[1] + size[1]; d1++) {
            T* ip1 = ip + (d1 * in_strides[1]);
            T* op1 = op + ((d1 - begin[1]) * out_strides[1]);
#pragma omp parallel for
            for (ssize_t d2 = begin[2]; d2 < begin[2] + size[2]; d2++) {
              T* ip2 = ip1 + (d2 * in_strides[2]);
              T* ip3 = ip2 + begin[3];
              T* op2 = op1 + ((d2 - begin[2]) * out_strides[2]);
              T* op3 = op2;
              memcpy(static_cast<void*>(op3), static_cast<void*>(ip3),
                     sizeof(T) * size[3]);
            }
          }
        }
        return;
      }
      // slice_dim is neither 1 nor 3, so fall back to the Eigen implementation.
    }

    Eigen::DSizes<Eigen::DenseIndex, NDIM> indices;
    Eigen::DSizes<Eigen::DenseIndex, NDIM> sizes;
    for (int i = 0; i < NDIM; ++i) {
      indices[i] = begin[i];
      sizes[i] = size[i];
    }

    functor::Slice<Device, T, NDIM>()(
        context->eigen_device<Device>(), result->tensor<T, NDIM>(),
        context->input(0).tensor<T, NDIM>(), indices, sizes);
  }
};
#endif  // INTEL_MKL

// Forward declarations of the functor specializations defined in the sharded
// source files.
namespace functor {
#define DECLARE_CPU_SPEC(T, NDIM)                                  \
  template <>                                                      \
  void Slice<CPUDevice, T, NDIM>::operator()(                      \
      const CPUDevice& d, typename TTypes<T, NDIM>::Tensor output, \
      typename TTypes<T, NDIM>::ConstTensor input,                 \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& indices,       \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& sizes);        \
  extern template struct Slice<CPUDevice, T, NDIM>;

#define DECLARE_FOR_N(T)  \
  DECLARE_CPU_SPEC(T, 1); \
  DECLARE_CPU_SPEC(T, 2); \
  DECLARE_CPU_SPEC(T, 3); \
  DECLARE_CPU_SPEC(T, 4); \
  DECLARE_CPU_SPEC(T, 5); \
  DECLARE_CPU_SPEC(T, 6); \
  DECLARE_CPU_SPEC(T, 7);

TF_CALL_ALL_TYPES(DECLARE_FOR_N);

#undef DECLARE_FOR_N
#undef DECLARE_CPU_SPEC
}  // namespace functor

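// Register the CPU kernels. The 'begin' and 'size' inputs stay in host memory
// because they are read on the host to compute the output shape.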
#ifndef INTEL_MKL
#define REGISTER_SLICE(type)                             \
  REGISTER_KERNEL_BUILDER(Name("Slice")                  \
                              .Device(DEVICE_CPU)        \
                              .TypeConstraint<type>("T") \
                              .HostMemory("begin")       \
                              .HostMemory("size"),       \
                          SliceOp<CPUDevice, type>)

TF_CALL_POD_STRING_TYPES(REGISTER_SLICE);
TF_CALL_QUANTIZED_TYPES(REGISTER_SLICE);
#undef REGISTER_SLICE
#else
#define REGISTER_SLICE(type)                             \
  REGISTER_KERNEL_BUILDER(Name("Slice")                  \
                              .Device(DEVICE_CPU)        \
                              .TypeConstraint<type>("T") \
                              .HostMemory("begin")       \
                              .HostMemory("size"),       \
                          MklSliceOp<CPUDevice, type>)

TF_CALL_POD_STRING_TYPES(REGISTER_SLICE);
TF_CALL_QUANTIZED_TYPES(REGISTER_SLICE);
#undef REGISTER_SLICE
#endif  // INTEL_MKL

#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T, NDIM)                                  \
  template <>                                                      \
  void Slice<GPUDevice, T, NDIM>::operator()(                      \
      const GPUDevice& d, typename TTypes<T, NDIM>::Tensor output, \
      typename TTypes<T, NDIM>::ConstTensor input,                 \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& indices,       \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& sizes);        \
  extern template struct Slice<GPUDevice, T, NDIM>;

#define DECLARE_FOR_N(T)  \
  DECLARE_GPU_SPEC(T, 1); \
  DECLARE_GPU_SPEC(T, 2); \
  DECLARE_GPU_SPEC(T, 3); \
  DECLARE_GPU_SPEC(T, 4); \
  DECLARE_GPU_SPEC(T, 5); \
  DECLARE_GPU_SPEC(T, 6); \
  DECLARE_GPU_SPEC(T, 7);

TF_CALL_GPU_NUMBER_TYPES(DECLARE_FOR_N);
TF_CALL_complex64(DECLARE_FOR_N);
TF_CALL_complex128(DECLARE_FOR_N);
TF_CALL_bfloat16(DECLARE_FOR_N);
DECLARE_FOR_N(int32);

#undef DECLARE_FOR_N
#undef DECLARE_GPU_SPEC
}  // namespace functor

#define REGISTER_GPU(type)                                     \
  REGISTER_KERNEL_BUILDER(Name("Slice")                        \
                              .Device(DEVICE_GPU)              \
                              .TypeConstraint<type>("T")       \
                              .HostMemory("begin")             \
                              .HostMemory("size")              \
                              .TypeConstraint<int32>("Index"), \
                          SliceOp<GPUDevice, type>)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
TF_CALL_complex64(REGISTER_GPU);
TF_CALL_complex128(REGISTER_GPU);
TF_CALL_bfloat16(REGISTER_GPU);

// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
REGISTER_KERNEL_BUILDER(Name("Slice")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .TypeConstraint<int32>("Index")
                            .HostMemory("input")
                            .HostMemory("begin")
                            .HostMemory("size")
                            .HostMemory("output"),
                        SliceOp<CPUDevice, int32>);

#undef REGISTER_GPU

#endif  // GOOGLE_CUDA

#ifdef TENSORFLOW_USE_SYCL
// Forward declarations of the functor specializations for SYCL.
namespace functor {
#define DECLARE_SYCL_SPEC(T, NDIM)                                  \
  template <>                                                       \
  void Slice<SYCLDevice, T, NDIM>::operator()(                      \
      const SYCLDevice& d, typename TTypes<T, NDIM>::Tensor output, \
      typename TTypes<T, NDIM>::ConstTensor input,                  \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& indices,        \
      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& sizes);         \
  extern template struct Slice<SYCLDevice, T, NDIM>;

#define DECLARE_FOR_N(T)   \
  DECLARE_SYCL_SPEC(T, 1); \
  DECLARE_SYCL_SPEC(T, 2); \
  DECLARE_SYCL_SPEC(T, 3); \
  DECLARE_SYCL_SPEC(T, 4); \
  DECLARE_SYCL_SPEC(T, 5); \
  DECLARE_SYCL_SPEC(T, 6); \
  DECLARE_SYCL_SPEC(T, 7);

TF_CALL_GPU_NUMBER_TYPES_NO_HALF(DECLARE_FOR_N);
DECLARE_FOR_N(int32);
DECLARE_FOR_N(bool);

#undef DECLARE_FOR_N
#undef DECLARE_SYCL_SPEC
}  // namespace functor

#define REGISTER_SYCL(type)                                    \
  REGISTER_KERNEL_BUILDER(Name("Slice")                        \
                              .Device(DEVICE_SYCL)             \
                              .TypeConstraint<type>("T")       \
                              .HostMemory("begin")             \
                              .HostMemory("size")              \
                              .TypeConstraint<int32>("Index"), \
                          SliceOp<SYCLDevice, type>)

TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL);

REGISTER_KERNEL_BUILDER(Name("Slice")
                            .Device(DEVICE_SYCL)
                            .TypeConstraint<int32>("T")
                            .TypeConstraint<int32>("Index")
                            .HostMemory("input")
                            .HostMemory("begin")
                            .HostMemory("size")
                            .HostMemory("output"),
                        SliceOp<CPUDevice, int32>);
#undef REGISTER_SYCL

#endif  // TENSORFLOW_USE_SYCL
}  // namespace tensorflow