Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #if !TENSORFLOW_USE_SYCL
     17 #error This file must only be included when building with SYCL support
     18 #endif
     19 
     20 #ifndef TENSORFLOW_CORE_KERNELS_POOLING_OP_3D_SYCL_H_
     21 #define TENSORFLOW_CORE_KERNELS_POOLING_OP_3D_SYCL_H_
     22 
     23 #include "tensorflow/core/kernels/pooling_ops_3d.h"
     24 
     25 namespace tensorflow {
     26 
     27 typedef Eigen::SyclDevice SYCLDevice;
     28 
     29 // Helper struct to contain the various pool parameters used in the SYCL
     30 // pooling kernels. Similar to the Pool3dParameters, but with a number of
     31 // convenient constructors.
     32 struct SYCL3DPoolParams {
     33   SYCL3DPoolParams(const int depth, const int batch, const int in_planes,
     34                    const int in_rows, const int in_cols, const int out_planes,
     35                    const int out_rows, const int out_cols,
     36                    const std::array<int64, 3>& window,
     37                    const std::array<int64, 3>& stride,
     38                    const std::array<int64, 3>& padding)
     39       : depth_(depth),
     40         batch_(batch),
     41         in_planes_(in_planes),
     42         in_rows_(in_rows),
     43         in_cols_(in_cols),
     44         window_planes_(window[2]),
     45         window_rows_(window[1]),
     46         window_cols_(window[0]),
     47         stride_planes_(stride[2]),
     48         stride_rows_(stride[1]),
     49         stride_cols_(stride[0]),
     50         out_planes_(out_planes),
     51         out_rows_(out_rows),
     52         out_cols_(out_cols),
     53         pad_planes_(padding[2]),
     54         pad_rows_(padding[1]),
     55         pad_cols_(padding[0]) {}
     56 
     57   SYCL3DPoolParams(const int depth, const int batch, const int in_planes,
     58                    const int in_rows, const int in_cols,
     59                    const std::array<int64, 3>& out_shape,
     60                    const std::array<int64, 3>& window,
     61                    const std::array<int64, 3>& stride,
     62                    const std::array<int64, 3>& padding)
     63       : SYCL3DPoolParams(depth, batch, in_planes, in_rows, in_cols,
     64                          out_shape[2], out_shape[1], out_shape[0], window,
     65                          stride, padding) {}
     66 
     67   SYCL3DPoolParams(const Pool3dParameters& params)
     68       : depth_(params.depth),
     69         batch_(params.tensor_in_batch),
     70         in_planes_(params.tensor_in_planes),
     71         in_rows_(params.tensor_in_rows),
     72         in_cols_(params.tensor_in_cols),
     73         window_planes_(params.window_planes),
     74         window_rows_(params.window_rows),
     75         window_cols_(params.window_cols),
     76         stride_planes_(params.plane_stride),
     77         stride_rows_(params.row_stride),
     78         stride_cols_(params.col_stride),
     79         out_planes_(params.out_plane),
     80         out_rows_(params.out_height),
     81         out_cols_(params.out_width),
     82         pad_planes_(params.pad_planes),
     83         pad_rows_(params.pad_rows),
     84         pad_cols_(params.pad_cols) {}
     85 
     86   const int depth_;
     87   const int batch_;
     88   const int in_planes_;
     89   const int in_rows_;
     90   const int in_cols_;
     91 
     92   const int window_planes_;
     93   const int window_rows_;
     94   const int window_cols_;
     95 
     96   const int stride_planes_;
     97   const int stride_rows_;
     98   const int stride_cols_;
     99 
    100   const int out_planes_;
    101   const int out_rows_;
    102   const int out_cols_;
    103 
    104   const int pad_planes_;
    105   const int pad_rows_;
    106   const int pad_cols_;
    107 };
    108 // MaxPool3d SYCL kernel. Expects the number of threads to be equal to the
    109 // number of elements in the output tensor.
    110 //
    111 // For each output element, find the corresponding input window and run over
    112 // all values in the window to find the maximum value. This value is then
    113 // copied into that output element.
    114 template <typename T>
    115 class MaxPool3DSYCL {
    116   using write_accessor =
    117       cl::sycl::accessor<uint8_t, 1, cl::sycl::access::mode::write,
    118                          cl::sycl::access::target::global_buffer>;
    119   using read_accessor =
    120       cl::sycl::accessor<uint8_t, 1, cl::sycl::access::mode::read,
    121                          cl::sycl::access::target::global_buffer>;
    122 
    123  public:
    124   MaxPool3DSYCL(const int depth, const int batch, const int in_planes,
    125                 const int in_rows, const int in_cols, const int out_planes,
    126                 const int out_rows, const int out_cols,
    127                 const std::array<int64, 3>& window,
    128                 const std::array<int64, 3>& stride,
    129                 const std::array<int64, 3>& padding,
    130                 const read_accessor input_accessor,
    131                 write_accessor output_accessor)
    132       : p_(depth, batch, in_planes, in_rows, in_cols, out_planes, out_rows,
    133            out_cols, window, stride, padding),
    134         input_accessor_(input_accessor),
    135         output_accessor_(output_accessor) {}
    136   void operator()(cl::sycl::item<1> item) {
    137     T* input_data = ConvertToActualTypeSycl(T, input_accessor_);
    138     T* output_data = ConvertToActualTypeSycl(T, output_accessor_);
    139 
    140     int index = item.get_linear_id();
    141     int n = index;
    142     int d = n % p_.depth_;
    143     n /= p_.depth_;
    144     int cstart = (n % p_.out_cols_) * p_.stride_cols_ - p_.pad_cols_;
    145     int cend = std::min(cstart + p_.window_cols_, p_.in_cols_);
    146     cstart = std::max(cstart, 0);
    147     n /= p_.out_cols_;
    148     int rstart = (n % p_.out_rows_) * p_.stride_rows_ - p_.pad_rows_;
    149     int rend = std::min(rstart + p_.window_rows_, p_.in_rows_);
    150     rstart = std::max(rstart, 0);
    151     n /= p_.out_rows_;
    152     int pstart = (n % p_.out_planes_) * p_.stride_planes_ - p_.pad_planes_;
    153     int pend = std::min(pstart + p_.window_planes_, p_.in_planes_);
    154     pstart = std::max(pstart, 0);
    155     n /= p_.out_planes_;
    156     T maxval = Eigen::NumTraits<T>::lowest();
    157     const T* input_data_n =
    158         input_data + n * p_.in_planes_ * p_.in_cols_ * p_.in_rows_ * p_.depth_;
    159     for (int p = pstart; p < pend; ++p) {
    160       for (int r = rstart; r < rend; ++r) {
    161         for (int c = cstart; c < cend; ++c) {
    162           int idx = ((p * p_.in_rows_ + r) * p_.in_cols_ + c) * p_.depth_ + d;
    163           if (input_data_n[idx] > maxval) {
    164             maxval = input_data_n[idx];
    165           }
    166         }
    167       }
    168     }
    169     output_data[index] = maxval;
    170   }
    171 
    172  private:
    173   const SYCL3DPoolParams p_;
    174   const read_accessor input_accessor_;
    175   write_accessor output_accessor_;
    176 };
    177 template <typename T>
    178 struct LaunchPoolingOp<SYCLDevice, T, MAX> {
    179   static void launch(OpKernelContext* context, const Tensor& tensor_in,
    180                      const std::array<int64, 3>& window,
    181                      const std::array<int64, 3>& stride,
    182                      const std::array<int64, 3>& padding,
    183                      TensorFormat data_format, Padding padding_type,
    184                      Tensor* output) {
    185     const SYCLDevice& device = context->eigen_device<SYCLDevice>();
    186     const int out_planes = GetTensorDim(*output, data_format, '0');
    187     const int out_rows = GetTensorDim(*output, data_format, '1');
    188     const int out_cols = GetTensorDim(*output, data_format, '2');
    189     const int batch = GetTensorDim(tensor_in, data_format, 'N');
    190     const int in_planes = GetTensorDim(tensor_in, data_format, '0');
    191     const int in_rows = GetTensorDim(tensor_in, data_format, '1');
    192     const int in_cols = GetTensorDim(tensor_in, data_format, '2');
    193     const int depth = GetTensorDim(tensor_in, data_format, 'C');
    194 
    195     const int num_threads = output->NumElements();
    196 
    197     auto input_buffer =
    198         device.get_sycl_buffer(tensor_in.template flat<T>().data());
    199     auto output_buffer =
    200         device.get_sycl_buffer(output->template flat<T>().data());
    201 
    202     device.sycl_queue().submit([&](cl::sycl::handler& cgh) {
    203       auto input_access =
    204           input_buffer.template get_access<cl::sycl::access::mode::read>(cgh);
    205       auto output_access =
    206           output_buffer.template get_access<cl::sycl::access::mode::write>(cgh);
    207       MaxPool3DSYCL<T> max_pool(depth, batch, in_planes, in_rows, in_cols,
    208                                 out_planes, out_rows, out_cols, window, stride,
    209                                 padding, input_access, output_access);
    210 
    211       cgh.parallel_for(cl::sycl::range<1>(num_threads), max_pool);
    212     });
    213   }
    214 };
    215 // MaxPool3DGrad SYCL kernel. Expects the number of threads to be equal to the
    216 // number of elements in the output backprop tensor (i.e. the number of elements
    217 // in the input data tensor).
    218 //
    219 // For each output backprop element we compute the possible window of values in
    220 // the input backprop tensor which might contribute to this element. Then for
    221 // each error in this window, compute the corresponding input window which was
    222 // pooled into that element in the output. Walk through this input window to
    223 // determine whether the input value is the first maximum value, and so the
    224 // error should be propagated back to the corresponding backprop element.
    225 template <typename T>
    226 class MaxPool3DGradSYCL {
    227   using write_accessor =
    228       cl::sycl::accessor<uint8_t, 1, cl::sycl::access::mode::write,
    229                          cl::sycl::access::target::global_buffer>;
    230   using read_accessor =
    231       cl::sycl::accessor<uint8_t, 1, cl::sycl::access::mode::read,
    232                          cl::sycl::access::target::global_buffer>;
    233 
    234  public:
    235   MaxPool3DGradSYCL(const int depth, const int batch, const int in_planes,
    236                     const int in_rows, const int in_cols,
    237                     const std::array<int64, 3>& output_shape,
    238                     const std::array<int64, 3>& window,
    239                     const std::array<int64, 3>& stride,
    240                     const std::array<int64, 3>& padding,
    241                     const read_accessor input_data_accessor,
    242                     const read_accessor output_data_accessor,
    243                     const read_accessor input_backprop_accessor,
    244                     write_accessor output_backprop_accessor)
    245       : p_(depth, batch, in_planes, in_rows, in_cols, output_shape, window,
    246            stride, padding),
    247         input_data_accessor_(input_data_accessor),
    248         output_data_accessor_(output_data_accessor),
    249         input_backprop_accessor_(input_backprop_accessor),
    250         output_backprop_accessor_(output_backprop_accessor) {}
    251   void operator()(cl::sycl::item<1> item) {
    252     T* input_data = ConvertToActualTypeSycl(T, input_data_accessor_);
    253     T* output_data = ConvertToActualTypeSycl(T, output_data_accessor_);
    254     T* input_backprop = ConvertToActualTypeSycl(T, input_backprop_accessor_);
    255     T* output_backprop = ConvertToActualTypeSycl(T, output_backprop_accessor_);
    256 
    257     const int index = item.get_linear_id();
    258     T output_value = 0;
    259     int n = index;
    260     const int d = n % p_.depth_;
    261     n /= p_.depth_;
    262     const int c = (n % p_.in_cols_) + p_.pad_cols_;
    263     const int poolcstart =
    264         (c < p_.window_cols_) ? 0 : (c - p_.window_cols_) / p_.stride_cols_ + 1;
    265     const int poolcend = std::min(c / p_.stride_cols_ + 1, p_.out_cols_);
    266     n /= p_.in_cols_;
    267     const int r = (n % p_.in_rows_) + p_.pad_rows_;
    268     const int poolrstart =
    269         (r < p_.window_rows_) ? 0 : (r - p_.window_rows_) / p_.stride_rows_ + 1;
    270     const int poolrend = std::min(r / p_.stride_rows_ + 1, p_.out_rows_);
    271     n /= p_.in_rows_;
    272     const int p = (n % p_.in_planes_) + p_.pad_planes_;
    273     const int poolpstart =
    274         (p < p_.window_planes_)
    275             ? 0
    276             : (p - p_.window_planes_) / p_.stride_planes_ + 1;
    277     const int poolpend = std::min(p / p_.stride_planes_ + 1, p_.out_planes_);
    278     n /= p_.in_planes_;
    279     const int index_no_n =
    280         index - n * p_.in_planes_ * p_.in_cols_ * p_.in_rows_ * p_.depth_;
    281 
    282     const T* input_data_n =
    283         input_data + n * p_.in_planes_ * p_.in_cols_ * p_.in_rows_ * p_.depth_;
    284     const T* output_data_n = output_data + n * p_.out_planes_ * p_.out_cols_ *
    285                                                p_.out_rows_ * p_.depth_;
    286     const T* input_backprop_n = input_backprop + n * p_.out_planes_ *
    287                                                      p_.out_cols_ *
    288                                                      p_.out_rows_ * p_.depth_;
    289     for (int poolp = poolpstart; poolp < poolpend; ++poolp) {
    290       int pstart = poolp * p_.stride_planes_ - p_.pad_planes_;
    291       const int pend = std::min(pstart + p_.window_planes_, p_.in_planes_);
    292       pstart = std::max(pstart, 0);
    293 
    294       for (int poolr = poolrstart; poolr < poolrend; ++poolr) {
    295         int rstart = poolr * p_.stride_rows_ - p_.pad_rows_;
    296         const int rend = std::min(rstart + p_.window_rows_, p_.in_rows_);
    297         rstart = std::max(rstart, 0);
    298 
    299         for (int poolc = poolcstart; poolc < poolcend; ++poolc) {
    300           int cstart = poolc * p_.stride_cols_ - p_.pad_cols_;
    301           const int cend = std::min(cstart + p_.window_cols_, p_.in_cols_);
    302           cstart = std::max(cstart, 0);
    303 
    304           const int output_data_idx =
    305               ((poolp * p_.out_rows_ + poolr) * p_.out_cols_ + poolc) *
    306                   p_.depth_ +
    307               d;
    308           bool should_continue = true;
    309           bool is_max = (input_data[index] == output_data_n[output_data_idx]);
    310           for (int win_p = pstart; win_p < pend && should_continue; ++win_p) {
    311             for (int win_r = rstart; win_r < rend && should_continue; ++win_r) {
    312               for (int win_c = cstart; win_c < cend && should_continue;
    313                    ++win_c) {
    314                 const int input_data_idx =
    315                     ((win_p * p_.in_rows_ + win_r) * p_.in_cols_ + win_c) *
    316                         p_.depth_ +
    317                     d;
    318                 if (input_data_idx == index_no_n) {
    319                   should_continue = false;
    320                 } else if (input_data_n[input_data_idx] ==
    321                            output_data_n[output_data_idx]) {
    322                   should_continue = false;
    323                   is_max = false;
    324                 }
    325               }
    326             }
    327           }
    328           if (is_max) {
    329             output_value += input_backprop_n[output_data_idx];
    330           }
    331         }
    332       }
    333     }
    334     output_backprop[index] = output_value;
    335   }
    336 
    337  private:
    338   const SYCL3DPoolParams p_;
    339 
    340   const read_accessor input_data_accessor_;
    341   const read_accessor output_data_accessor_;
    342   const read_accessor input_backprop_accessor_;
    343   write_accessor output_backprop_accessor_;
    344 };
    345 template <typename T>
    346 struct LaunchMaxPooling3dGradOp<SYCLDevice, T> {
    347   static void launch(OpKernelContext* context, const Tensor& tensor_in,
    348                      const Tensor& tensor_out, const Tensor& out_backprop,
    349                      const std::array<int64, 3>& window,
    350                      const std::array<int64, 3>& stride,
    351                      const std::array<int64, 3>& out,
    352                      const std::array<int64, 3>& padding,
    353                      TensorFormat data_format, Tensor* output) {
    354     const SYCLDevice& device = context->eigen_device<SYCLDevice>();
    355     const int batch = GetTensorDim(tensor_in, data_format, 'N');
    356     const int in_planes = GetTensorDim(tensor_in, data_format, '0');
    357     const int in_rows = GetTensorDim(tensor_in, data_format, '1');
    358     const int in_cols = GetTensorDim(tensor_in, data_format, '2');
    359     const int depth = GetTensorDim(tensor_in, data_format, 'C');
    360 
    361     const int output_size = output->NumElements();
    362 
    363     auto input_data_buffer =
    364         device.get_sycl_buffer(tensor_in.template flat<T>().data());
    365     auto output_data_buffer =
    366         device.get_sycl_buffer(tensor_out.template flat<T>().data());
    367     auto input_backprop_buffer =
    368         device.get_sycl_buffer(out_backprop.template flat<T>().data());
    369     auto output_backprop_buffer =
    370         device.get_sycl_buffer(output->template flat<T>().data());
    371 
    372     device.sycl_queue().submit([&](cl::sycl::handler& cgh) {
    373       auto input_data_access =
    374           input_data_buffer.template get_access<cl::sycl::access::mode::read>(
    375               cgh);
    376       auto output_data_access =
    377           output_data_buffer.template get_access<cl::sycl::access::mode::read>(
    378               cgh);
    379       auto input_backprop_access =
    380           input_backprop_buffer
    381               .template get_access<cl::sycl::access::mode::read>(cgh);
    382       auto output_backprop_access =
    383           output_backprop_buffer
    384               .template get_access<cl::sycl::access::mode::write>(cgh);
    385       MaxPool3DGradSYCL<T> max_pool(
    386           depth, batch, in_planes, in_rows, in_cols, out, window, stride,
    387           padding, input_data_access, output_data_access, input_backprop_access,
    388           output_backprop_access);
    389 
    390       cgh.parallel_for(cl::sycl::range<1>(output_size), max_pool);
    391     });
    392   }
    393 };
    394 // MaxPool3DGradGrad SYCL kernel. Expects the number of threads to be equal to
    395 // the number of elements in the output backprop tensor, i.e. the number of
    396 // elements in the output tensor.
    397 //
    398 // For each element in the output backprop tensor, find the corresponding input
    399 // window, and compare the input and output data to find the index of the
    400 // maximum value in the input tensor. This is then the index of the gradient to
    401 // pass through to the output backprop tensor.
    402 template <typename T>
    403 class MaxPool3DGradGradSYCL {
    404   using write_accessor =
    405       cl::sycl::accessor<uint8_t, 1, cl::sycl::access::mode::write,
    406                          cl::sycl::access::target::global_buffer>;
    407   using read_accessor =
    408       cl::sycl::accessor<uint8_t, 1, cl::sycl::access::mode::read,
    409                          cl::sycl::access::target::global_buffer>;
    410 
    411  public:
    412   MaxPool3DGradGradSYCL(const Pool3dParameters& params,
    413                         const read_accessor input_data_accessor,
    414                         const read_accessor output_data_accessor,
    415                         const read_accessor input_backprop_accessor,
    416                         write_accessor output_backprop_accessor)
    417       : p_(params),
    418         input_data_accessor_(input_data_accessor),
    419         output_data_accessor_(output_data_accessor),
    420         input_backprop_accessor_(input_backprop_accessor),
    421         output_backprop_accessor_(output_backprop_accessor) {}
    422   void operator()(cl::sycl::item<1> item) {
    423     T* input_data = ConvertToActualTypeSycl(T, input_data_accessor_);
    424     T* output_data = ConvertToActualTypeSycl(T, output_data_accessor_);
    425     T* input_backprop = ConvertToActualTypeSycl(T, input_backprop_accessor_);
    426     T* output_backprop = ConvertToActualTypeSycl(T, output_backprop_accessor_);
    427 
    428     int index = item.get_linear_id();
    429     int n = index;
    430     int d = n % p_.depth_;
    431     n /= p_.depth_;
    432     int cstart = (n % p_.out_cols_) * p_.stride_cols_ - p_.pad_cols_;
    433     int cend = std::min(cstart + p_.window_cols_, p_.in_cols_);
    434     cstart = std::max(cstart, 0);
    435     n /= p_.out_cols_;
    436     int rstart = (n % p_.out_rows_) * p_.stride_rows_ - p_.pad_rows_;
    437     int rend = std::min(rstart + p_.window_rows_, p_.in_rows_);
    438     rstart = std::max(rstart, 0);
    439     n /= p_.out_rows_;
    440     int pstart = (n % p_.out_planes_) * p_.stride_planes_ - p_.pad_planes_;
    441     int pend = std::min(pstart + p_.window_planes_, p_.in_planes_);
    442     pstart = std::max(pstart, 0);
    443     n /= p_.out_planes_;
    444     int maxidx = -1;
    445     bool should_stop = false;
    446     const T* input_data_n =
    447         input_data + n * p_.in_planes_ * p_.in_cols_ * p_.in_rows_ * p_.depth_;
    448     for (int p = pstart; p < pend && !should_stop; ++p) {
    449       for (int r = rstart; r < rend && !should_stop; ++r) {
    450         for (int c = cstart; c < cend && !should_stop; ++c) {
    451           int idx = ((p * p_.in_rows_ + r) * p_.in_cols_ + c) * p_.depth_ + d;
    452           if (output_data[index] == input_data_n[idx]) {
    453             maxidx = idx;
    454             should_stop = true;
    455           }
    456         }
    457       }
    458     }
    459     if (maxidx != -1) {
    460       output_backprop[index] = input_backprop[n * p_.in_planes_ * p_.in_rows_ *
    461                                                   p_.in_cols_ * p_.depth_ +
    462                                               maxidx];
    463     }
    464   }
    465 
    466  private:
    467   const SYCL3DPoolParams p_;
    468 
    469   const read_accessor input_data_accessor_;
    470   const read_accessor output_data_accessor_;
    471   const read_accessor input_backprop_accessor_;
    472   write_accessor output_backprop_accessor_;
    473 };
    474 template <typename T>
    475 struct LaunchMaxPooling3dGradGradOp<SYCLDevice, T> {
    476   static void launch(OpKernelContext* context, const Pool3dParameters& params,
    477                      const Tensor& tensor_in, const Tensor& tensor_out,
    478                      const Tensor& out_backprop, Tensor* output) {
    479     const SYCLDevice& device = context->eigen_device<SYCLDevice>();
    480 
    481     const int num_threads = output->NumElements();
    482 
    483     auto input_data_buffer =
    484         device.get_sycl_buffer(tensor_in.template flat<T>().data());
    485     auto output_data_buffer =
    486         device.get_sycl_buffer(tensor_out.template flat<T>().data());
    487     auto input_backprop_buffer =
    488         device.get_sycl_buffer(out_backprop.template flat<T>().data());
    489     auto output_backprop_buffer =
    490         device.get_sycl_buffer(output->template flat<T>().data());
    491 
    492     device.sycl_queue().submit([&](cl::sycl::handler& cgh) {
    493       auto input_data_access =
    494           input_data_buffer.template get_access<cl::sycl::access::mode::read>(
    495               cgh);
    496       auto output_data_access =
    497           output_data_buffer.template get_access<cl::sycl::access::mode::read>(
    498               cgh);
    499       auto input_backprop_access =
    500           input_backprop_buffer
    501               .template get_access<cl::sycl::access::mode::read>(cgh);
    502       auto output_backprop_access =
    503           output_backprop_buffer
    504               .template get_access<cl::sycl::access::mode::write>(cgh);
    505       MaxPool3DGradGradSYCL<T> functor(
    506           params, input_data_access, output_data_access, input_backprop_access,
    507           output_backprop_access);
    508 
    509       cgh.parallel_for(cl::sycl::range<1>(num_threads), functor);
    510     });
    511   }
    512 };
    513 // AvgPool3D SYCL kernel. Expects the number of threads to be equal to the
    514 // number of elements in the output tensor.
    515 //
    516 // For each output value find the corresponding input window, and run through
    517 // the window accumulating the values to form an average. We divide each value
    518 // before accumulating to prevent the accumulator from becoming significantly
    519 // bigger than the values we are adding and so decrease any errors.
    520 template <typename T>
    521 class AvgPool3DSYCL {
    522   using write_accessor =
    523       cl::sycl::accessor<uint8_t, 1, cl::sycl::access::mode::write,
    524                          cl::sycl::access::target::global_buffer>;
    525   using read_accessor =
    526       cl::sycl::accessor<uint8_t, 1, cl::sycl::access::mode::read,
    527                          cl::sycl::access::target::global_buffer>;
    528 
    529  public:
    530   AvgPool3DSYCL(const int depth, const int batch, const int in_planes,
    531                 const int in_rows, const int in_cols, const int out_planes,
    532                 const int out_rows, const int out_cols,
    533                 const std::array<int64, 3>& window,
    534                 const std::array<int64, 3>& stride,
    535                 const std::array<int64, 3>& padding,
    536                 const read_accessor input_accessor,
    537                 write_accessor output_accessor)
    538       : p_(depth, batch, in_planes, in_rows, in_cols, out_planes, out_rows,
    539            out_cols, window, stride, padding),
    540         input_accessor_(input_accessor),
    541         output_accessor_(output_accessor) {}
    542   void operator()(cl::sycl::item<1> item) {
    543     T* input_data = ConvertToActualTypeSycl(T, input_accessor_);
    544     T* output_data = ConvertToActualTypeSycl(T, output_accessor_);
    545 
    546     int index = item.get_linear_id();
    547     int n = index;
    548     int d = n % p_.depth_;
    549     n /= p_.depth_;
    550     int cstart = (n % p_.out_cols_) * p_.stride_cols_ - p_.pad_cols_;
    551     int cend = std::min(cstart + p_.window_cols_, p_.in_cols_);
    552     cstart = std::max(cstart, 0);
    553     n /= p_.out_cols_;
    554     int rstart = (n % p_.out_rows_) * p_.stride_rows_ - p_.pad_rows_;
    555     int rend = std::min(rstart + p_.window_rows_, p_.in_rows_);
    556     rstart = std::max(rstart, 0);
    557     n /= p_.out_rows_;
    558     int pstart = (n % p_.out_planes_) * p_.stride_planes_ - p_.pad_planes_;
    559     int pend = std::min(pstart + p_.window_planes_, p_.in_planes_);
    560     pstart = std::max(pstart, 0);
    561     n /= p_.out_planes_;
    562     T accum = T(0);
    563     T count =
    564         static_cast<T>((pend - pstart) * (rend - rstart) * (cend - cstart));
    565     const T* input_data_n =
    566         input_data + n * p_.in_planes_ * p_.in_cols_ * p_.in_rows_ * p_.depth_;
    567     for (int p = pstart; p < pend; ++p) {
    568       for (int r = rstart; r < rend; ++r) {
    569         for (int c = cstart; c < cend; ++c) {
    570           int idx = ((p * p_.in_rows_ + r) * p_.in_cols_ + c) * p_.depth_ + d;
    571           accum += input_data_n[idx] / count;
    572         }
    573       }
    574     }
    575     output_data[index] = accum;
    576   }
    577 
    578  private:
    579   const SYCL3DPoolParams p_;
    580   const read_accessor input_accessor_;
    581   write_accessor output_accessor_;
    582 };
    583 template <typename T>
    584 struct LaunchPoolingOp<SYCLDevice, T, AVG> {
    585   static void launch(OpKernelContext* context, const Tensor& tensor_in,
    586                      const std::array<int64, 3>& window,
    587                      const std::array<int64, 3>& stride,
    588                      const std::array<int64, 3>& padding,
    589                      TensorFormat data_format, Padding padding_type,
    590                      Tensor* output) {
    591     const SYCLDevice& device = context->eigen_device<SYCLDevice>();
    592     const int out_planes = GetTensorDim(*output, data_format, '0');
    593     const int out_rows = GetTensorDim(*output, data_format, '1');
    594     const int out_cols = GetTensorDim(*output, data_format, '2');
    595     const int batch = GetTensorDim(tensor_in, data_format, 'N');
    596     const int in_planes = GetTensorDim(tensor_in, data_format, '0');
    597     const int in_rows = GetTensorDim(tensor_in, data_format, '1');
    598     const int in_cols = GetTensorDim(tensor_in, data_format, '2');
    599     const int depth = GetTensorDim(tensor_in, data_format, 'C');
    600 
    601     const int num_threads = output->NumElements();
    602 
    603     auto input_buffer =
    604         device.get_sycl_buffer(tensor_in.template flat<T>().data());
    605     auto output_buffer =
    606         device.get_sycl_buffer(output->template flat<T>().data());
    607 
    608     device.sycl_queue().submit([&](cl::sycl::handler& cgh) {
    609       auto input_access =
    610           input_buffer.template get_access<cl::sycl::access::mode::read>(cgh);
    611       auto output_access =
    612           output_buffer.template get_access<cl::sycl::access::mode::write>(cgh);
    613       AvgPool3DSYCL<T> avg_pool(depth, batch, in_planes, in_rows, in_cols,
    614                                 out_planes, out_rows, out_cols, window, stride,
    615                                 padding, input_access, output_access);
    616 
    617       cgh.parallel_for(cl::sycl::range<1>(num_threads), avg_pool);
    618     });
    619   }
    620 };
    621 // AvgPool3DGrad SYCL kernel. Expects the number of threads to be equal to the
    622 // number of elements in the output backprop tensor, i.e. the number of
    623 // elements in the input tensor.
    624 //
    625 // For each output backprop index find a window in the input backprop tensor
    626 // which corresponds to all the values of the output which were affected by the
    627 // input value at this index. Then for each gradient in this window, compute
    628 // the size of the input window which was averaged to give this output, and use
    629 // this size to scale the gradient accordingly. Add this scaled gradient to the
    630 // output backprop value.
    631 template <typename T>
    632 class AvgPool3DGradSYCL {
    633   using write_accessor =
    634       cl::sycl::accessor<uint8_t, 1, cl::sycl::access::mode::write,
    635                          cl::sycl::access::target::global_buffer>;
    636   using read_accessor =
    637       cl::sycl::accessor<uint8_t, 1, cl::sycl::access::mode::read,
    638                          cl::sycl::access::target::global_buffer>;
    639 
    640  public:
    641   AvgPool3DGradSYCL(const int depth, const int batch, const int in_planes,
    642                     const int in_rows, const int in_cols,
    643                     const std::array<int64, 3>& out_shape,
    644                     const std::array<int64, 3>& window,
    645                     const std::array<int64, 3>& stride,
    646                     const std::array<int64, 3>& padding,
    647                     const read_accessor input_backprop_accessor,
    648                     write_accessor output_backprop_accessor)
    649       : p_(depth, batch, in_planes, in_rows, in_cols, out_shape, window, stride,
    650            padding),
    651         input_backprop_accessor_(input_backprop_accessor),
    652         output_backprop_accessor_(output_backprop_accessor) {}
    653   void operator()(cl::sycl::item<1> item) {
    654     T* input_backprop = ConvertToActualTypeSycl(T, input_backprop_accessor_);
    655     T* output_backprop = ConvertToActualTypeSycl(T, output_backprop_accessor_);
    656 
    657     const int index = item.get_linear_id();
    658     int n = index;
    659     const int d = n % p_.depth_;
    660     n /= p_.depth_;
    661     const int c = (n % p_.in_cols_) + p_.pad_cols_;
    662     const int poolcstart =
    663         (c < p_.window_cols_) ? 0 : (c - p_.window_cols_) / p_.stride_cols_ + 1;
    664     const int poolcend = std::min(c / p_.stride_cols_ + 1, p_.out_cols_);
    665     n /= p_.in_cols_;
    666     const int r = (n % p_.in_rows_) + p_.pad_rows_;
    667     const int poolrstart =
    668         (r < p_.window_rows_) ? 0 : (r - p_.window_rows_) / p_.stride_rows_ + 1;
    669     const int poolrend = std::min(r / p_.stride_rows_ + 1, p_.out_rows_);
    670     n /= p_.in_rows_;
    671     const int p = (n % p_.in_planes_) + p_.pad_planes_;
    672     const int poolpstart =
    673         (p < p_.window_planes_)
    674             ? 0
    675             : (p - p_.window_planes_) / p_.stride_planes_ + 1;
    676     const int poolpend = std::min(p / p_.stride_planes_ + 1, p_.out_planes_);
    677     n /= p_.in_planes_;
    678 
    679     T gradient = T(0);
    680     const T* input_backprop_n = input_backprop + n * p_.out_planes_ *
    681                                                      p_.out_cols_ *
    682                                                      p_.out_rows_ * p_.depth_;
    683     for (int poolp = poolpstart; poolp < poolpend; ++poolp) {
    684       int pstart = poolp * p_.stride_planes_ - p_.pad_planes_;
    685       const int pend = std::min(pstart + p_.window_planes_, p_.in_planes_);
    686       pstart = std::max(pstart, 0);
    687       const int plane_window_size = pend - pstart;
    688       for (int poolr = poolrstart; poolr < poolrend; ++poolr) {
    689         int rstart = poolr * p_.stride_rows_ - p_.pad_rows_;
    690         const int rend = std::min(rstart + p_.window_rows_, p_.in_rows_);
    691         rstart = std::max(rstart, 0);
    692         const int row_window_size = rend - rstart;
    693         for (int poolc = poolcstart; poolc < poolcend; ++poolc) {
    694           const int idx =
    695               ((poolp * p_.out_rows_ + poolr) * p_.out_cols_ + poolc) *
    696                   p_.depth_ +
    697               d;
    698           int cstart = poolc * p_.stride_cols_ - p_.pad_cols_;
    699           const int cend = std::min(cstart + p_.window_cols_, p_.in_cols_);
    700           cstart = std::max(cstart, 0);
    701           const int col_window_size = cend - cstart;
    702           const int window_size =
    703               plane_window_size * row_window_size * col_window_size;
    704           gradient += input_backprop_n[idx] / static_cast<T>(window_size);
    705         }
    706       }
    707     }
    708     output_backprop[index] = gradient;
    709   }
    710 
    711  private:
    712   const SYCL3DPoolParams p_;
    713   const read_accessor input_backprop_accessor_;
    714   write_accessor output_backprop_accessor_;
    715 };
    716 template <typename T>
    717 struct LaunchAvgPooling3dGradOp<SYCLDevice, T> {
    718   static void launch(OpKernelContext* context,
    719                      const TensorShape& tensor_in_shape,
    720                      const Tensor& out_backprop,
    721                      const std::array<int64, 3>& window,
    722                      const std::array<int64, 3>& stride,
    723                      const std::array<int64, 3>& output_shape,
    724                      const std::array<int64, 3>& padding,
    725                      TensorFormat data_format, Tensor* output) {
    726     const SYCLDevice& device = context->eigen_device<SYCLDevice>();
    727     const int batch = GetTensorDim(tensor_in_shape, data_format, 'N');
    728     const int in_planes = GetTensorDim(tensor_in_shape, data_format, '0');
    729     const int in_rows = GetTensorDim(tensor_in_shape, data_format, '1');
    730     const int in_cols = GetTensorDim(tensor_in_shape, data_format, '2');
    731     const int depth = GetTensorDim(tensor_in_shape, data_format, 'C');
    732 
    733     const int num_threads = output->NumElements();
    734 
    735     auto input_backprop_buffer =
    736         device.get_sycl_buffer(out_backprop.template flat<T>().data());
    737     auto output_backprop_buffer =
    738         device.get_sycl_buffer(output->template flat<T>().data());
    739 
    740     device.sycl_queue().submit([&](cl::sycl::handler& cgh) {
    741       auto input_backprop_access =
    742           input_backprop_buffer
    743               .template get_access<cl::sycl::access::mode::read>(cgh);
    744       auto output_backprop_access =
    745           output_backprop_buffer
    746               .template get_access<cl::sycl::access::mode::write>(cgh);
    747       AvgPool3DGradSYCL<T> functor(
    748           depth, batch, in_planes, in_rows, in_cols, output_shape, window,
    749           stride, padding, input_backprop_access, output_backprop_access);
    750 
    751       cgh.parallel_for(cl::sycl::range<1>(num_threads), functor);
    752     });
    753   }
    754 };
    755 
    756 }  // namespace tensorflow
    757 
    758 #endif  // TENSORFLOW_CORE_KERNELS_POOLING_OP_3D_SYCL_H_
    759