Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "include/libxsmm.h"
     17 #include "tensorflow/core/framework/fake_input.h"
     18 #include "tensorflow/core/graph/graph.h"
     19 #include "tensorflow/core/graph/node_builder.h"
     20 #include "tensorflow/core/kernels/conv_ops.h"
     21 #include "tensorflow/core/kernels/ops_testutil.h"
     22 #include "tensorflow/core/platform/test.h"
     23 
     24 namespace tensorflow {
     25 namespace {
     26 
     27 typedef struct {
     28   int nImg;
     29   int nIfm;
     30   int nOfm;
     31   int ifhp;
     32   int ifwp;
     33   int ifh;
     34   int ifw;
     35   int ofhp;
     36   int ofwp;
     37   int ofh;
     38   int ofw;
     39   int pad_h;
     40   int pad_w;
     41   int pad_h_in;
     42   int pad_w_in;
     43   int pad_h_out;
     44   int pad_w_out;
     45   int kh;
     46   int kw;
     47   int stride_h;
     48   int stride_w;
     49 } naive_conv_t;
     50 
     51 LIBXSMM_INLINE void naive_copy_NCHW_to_NHWC(const float* nchw, Tensor& nhwc,
     52                                             int N, int H, int W, int C) {
     53   LIBXSMM_VLA_DECL(4, const float, input, nchw, C, H, W);
     54   int n, h, w, c;
     55   auto output = nhwc.flat<float>();
     56   for (n = 0; n < N; n++) {
     57     for (h = 0; h < H; h++) {
     58       for (w = 0; w < W; w++) {
     59         for (c = 0; c < C; c++) {
     60           output(n * H * W * C + h * W * C + w * C + c) =
     61               LIBXSMM_VLA_ACCESS(4, input, n, c, h, w, C, H, W);
     62         }
     63       }
     64     }
     65   }
     66 }
     67 
     68 LIBXSMM_INLINE void naive_copy_KCRS_to_RSCK(const float* kcrs, Tensor& rsck,
     69                                             int R, int S, int C, int K) {
     70   LIBXSMM_VLA_DECL(4, const float, input, kcrs, C, R, S);
     71   int r, s, c, k;
     72   auto output = rsck.flat<float>();
     73 
     74   for (r = 0; r < R; r++) {
     75     for (s = 0; s < S; s++) {
     76       for (c = 0; c < C; c++) {
     77         for (k = 0; k < K; k++) {
     78           output(r * S * C * K + s * C * K + c * K + k) =
     79               LIBXSMM_VLA_ACCESS(4, input, k, c, r, s, C, R, S);
     80         }
     81       }
     82     }
     83   }
     84 }
     85 
     86 LIBXSMM_INLINE void zero_buf(float* buf, long size) {
     87   int i;
     88   for (i = 0; i < size; ++i) {
     89     buf[i] = 0.0f;
     90   }
     91 }
     92 
     93 LIBXSMM_INLINE void copy_buf(Tensor& dst, float* src, long size) {
     94   long i;
     95   auto output = dst.flat<float>();
     96   for (i = 0; i < size; ++i) output(i) = src[i];
     97 }
     98 
     99 LIBXSMM_INLINE void init_buf(float* buf, long size, int initPos, int initOne) {
    100   int i;
    101   zero_buf(buf, size);
    102   for (i = 0; i < size; ++i) {
    103     buf[i] =
    104         (float)((initOne != 0)
    105                     ? 1.0
    106                     : ((initPos != 0) ? drand48() : (0.05 - drand48() / 10.0)));
    107   }
    108 }
    109 
    110 LIBXSMM_INLINE void naive_conv_fp(naive_conv_t* param, const float* input,
    111                                   float* output, const float* filter) {
    112   int nImg = param->nImg;
    113   int nIfm = param->nIfm;
    114   int nOfm = param->nOfm;
    115   int ifhp = param->ifhp;
    116   int ifwp = param->ifwp;
    117   int ofhp = param->ofhp;
    118   int ofwp = param->ofwp;
    119   int ifh = param->ifh;
    120   int ifw = param->ifw;
    121   int ofh = param->ofh;
    122   int ofw = param->ofw;
    123   int pad_h = param->pad_h;
    124   int pad_w = param->pad_w;
    125   int pad_h_in = param->pad_h_in;
    126   int pad_w_in = param->pad_w_in;
    127   int pad_h_out = param->pad_h_out;
    128   int pad_w_out = param->pad_w_out;
    129   int kh = param->kh;
    130   int kw = param->kw;
    131   int stride_h = param->stride_h;
    132   int stride_w = param->stride_w;
    133   /* loop counters */
    134   int img, ofm, ifm, oj, oi, ij, ii, kj, ki;
    135 
    136   LIBXSMM_VLA_DECL(4, float, output_t, output + (pad_w_out * ofwp + pad_h_out),
    137                    nOfm, ofhp, ofwp);
    138   LIBXSMM_VLA_DECL(4, const float, input_t,
    139                    input + (pad_w_in * ifwp + pad_h_in), nIfm, ifhp, ifwp);
    140   LIBXSMM_VLA_DECL(4, const float, filter_t, filter, nIfm, kh, kw);
    141 
    142   for (img = 0; img < nImg; ++img) {
    143     for (ofm = 0; ofm < nOfm; ++ofm) {
    144       for (ifm = 0; ifm < nIfm; ++ifm) {
    145         for (oj = 0; oj < ofh; ++oj) {
    146           ij = oj * stride_h - pad_h;
    147           for (oi = 0; oi < ofw; ++oi) {
    148             ii = oi * stride_w - pad_w;
    149             for (kj = 0; kj < kh; ++kj) {
    150               if (ij + kj < 0 || ij + kj >= ifh) continue;
    151               for (ki = 0; ki < kw; ++ki) {
    152                 if (ii + ki < 0 || ii + ki >= ifw) continue;
    153                 LIBXSMM_VLA_ACCESS(4, output_t, img, ofm, oj, oi, nOfm, ofhp,
    154                                    ofwp) +=
    155                     LIBXSMM_VLA_ACCESS(4, input_t, img, ifm, ij + kj, ii + ki,
    156                                        nIfm, ifhp, ifwp) *
    157                     LIBXSMM_VLA_ACCESS(4, filter_t, ofm, ifm, kj, ki, nIfm, kh,
    158                                        kw);
    159               }
    160             }
    161           }
    162         }
    163       }
    164     }
    165   }
    166 }
    167 
    168 void RunXsmmVsGeneric() {}
    169 
    170 class XsmmConv2DTest : public OpsTestBase {
    171  protected:
    172   void MakeOp(int stride) {
    173     TF_CHECK_OK(NodeDefBuilder("xsmm", "Conv2D")
    174                     .Input(FakeInput(DT_FLOAT))
    175                     .Input(FakeInput(DT_FLOAT))
    176                     .Attr("strides", {1, stride, stride, 1})
    177                     .Attr("padding", "VALID")
    178                     .Finalize(node_def()));
    179 
    180     TF_ASSERT_OK(InitOp());
    181   }
    182 };
    183 
    184 TEST_F(XsmmConv2DTest, Basic) {
    185   MakeOp(1);
    186 
    187   // setup scoped allocator, which uses cpu_allocator() for this scope
    188   const libxsmm_tf_allocator<libxsmm_scratch_allocator> tf_allocator;
    189 
    190   int ifw = 14;   /* input width, "W" */
    191   int ifh = 14;   /* input height, "H" */
    192   int nImg = 32;  /* mini-batch size, "N" */
    193   int nIfm = 64;  /* number of input feature maps, "C" */
    194   int nOfm = 64;  /* number of output feature maps, "K" */
    195   int kh = 3;     /* filter height, "R" */
    196   int kw = 3;     /* filter width, "S" */
    197   int pad = 0;    /* padding in output */
    198   int stride = 1; /* stride when accessing inputs */
    199 
    200   int stride_w = stride;
    201   int stride_h = stride;
    202   int pad_h = pad;
    203   int pad_w = pad;
    204 
    205   int pad_h_in = pad_h;
    206   int pad_w_in = pad_w;
    207 
    208   int pad_h_out = 0;
    209   int pad_w_out = 0;
    210 
    211   /* deriving some values for naive code */
    212   int ofh = (ifh + 2 * pad_h - kh) / stride_h + 1;
    213   int ofw = (ifw + 2 * pad_w - kw) / stride_w + 1;
    214   int ifhp = ifh + 2 * pad_h_in;
    215   int ifwp = ifw + 2 * pad_w_in;
    216   int ofhp = ofh + 2 * pad_h_out;
    217   int ofwp = ofw + 2 * pad_w_out;
    218 
    219   // Initialization of Filter and Image
    220 
    221   /* allocate data */
    222   float* naive_input = (float*)libxsmm_aligned_scratch(
    223       nImg * nIfm * ifhp * ifwp * sizeof(float), 2097152);
    224   float* naive_output = (float*)libxsmm_aligned_scratch(
    225       nImg * nOfm * ofhp * ofwp * sizeof(float), 2097152);
    226   float* naive_filter = (float*)libxsmm_aligned_scratch(
    227       nOfm * nIfm * kh * kw * sizeof(float), 2097152);
    228   /* initialize data */
    229   init_buf(naive_input, nImg * nIfm * ifhp * ifwp, 0, 0);
    230   zero_buf(naive_output, nImg * nOfm * ofhp * ofwp);
    231   init_buf(naive_filter, nOfm * nIfm * kh * kw, 0, 0);
    232 
    233   Tensor image(DT_FLOAT, {nImg, ifhp, ifwp, nIfm});
    234 
    235   Tensor filter(DT_FLOAT, {kh, kw, nIfm, nOfm});
    236 
    237   naive_copy_NCHW_to_NHWC(naive_input, image, nImg, ifhp, ifwp, nIfm);
    238   naive_copy_KCRS_to_RSCK(naive_filter, filter, kh, kw, nIfm, nOfm);
    239 
    240   // Run naive convolution
    241 
    242   naive_conv_t naive_param;
    243 
    244   naive_param.nImg = nImg;
    245   naive_param.nIfm = nIfm;
    246   naive_param.nOfm = nOfm;
    247   naive_param.ifhp = ifhp;
    248   naive_param.ifwp = ifwp;
    249   naive_param.ofhp = ofhp;
    250   naive_param.ofwp = ofwp;
    251   naive_param.ifh = ifh;
    252   naive_param.ifw = ifw;
    253   naive_param.ofh = ofh;
    254   naive_param.ofw = ofw;
    255   naive_param.pad_h = pad_h;
    256   naive_param.pad_w = pad_w;
    257   naive_param.pad_h_in = pad_h_in;
    258   naive_param.pad_w_in = pad_w_in;
    259   naive_param.pad_h_out = pad_h_out;
    260   naive_param.pad_w_out = pad_w_out;
    261   naive_param.kh = kh;
    262   naive_param.kw = kw;
    263   naive_param.stride_h = stride_h;
    264   naive_param.stride_w = stride_w;
    265 
    266   naive_conv_fp(&naive_param, naive_input, naive_output, naive_filter);
    267 
    268   AddInputFromArray<float>(image.shape(), image.flat<float>());
    269   AddInputFromArray<float>(filter.shape(), filter.flat<float>());
    270 
    271   // Run Op (TF)
    272   TF_ASSERT_OK(RunOpKernel());
    273 
    274   // Check the output.
    275   Tensor expected(DT_FLOAT, {nImg, ofhp, ofwp, nOfm});
    276   naive_copy_NCHW_to_NHWC(naive_output, expected, nImg, ofhp, ofwp, nOfm);
    277 
    278   test::ExpectTensorNear<float>(expected, *GetOutput(0), 1e-5);
    279   libxsmm_free(naive_input);
    280   libxsmm_free(naive_output);
    281   libxsmm_free(naive_filter);
    282 }
    283 
    284 /*
    285 
    286 
    287 TEST(XsmmConv2DTest, Basic) {
    288 
    289     auto num_threads =
    290         ctx->device()->tensorflow_cpu_worker_threads()->num_threads;
    291     // See libxsmm_dnn.h for this struct definition.
    292     libxsmm_dnn_conv_desc desc;
    293     desc.N = batch;
    294     desc.C = in_depth;
    295     desc.H = input_rows;
    296     desc.W = input_cols;
    297     desc.K = out_depth;
    298     desc.R = filter_rows;
    299     desc.S = filter_cols;
    300     desc.u = stride_rows;
    301     desc.v = stride_cols;
    302     desc.pad_h = pad_rows;
    303     desc.pad_w = pad_cols;
    304     desc.pad_h_in = pad_rows;  // libxsmm supports only physical padding for now
    305     desc.pad_w_in = pad_cols;  // libxsmm supports only physical padding for now
    306     desc.pad_h_out = 0;
    307     desc.pad_w_out = 0;
    308     desc.threads = num_threads;
    309     desc.algo = LIBXSMM_DNN_CONV_ALGO_DIRECT;
    310     desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_NHWC;
    311     desc.filter_format =
    312 LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM;//LIBXSMM_DNN_TENSOR_FORMAT_RSCK;
    313     desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
    314     desc.options = LIBXSMM_DNN_CONV_OPTION_NONE;
    315     desc.datatype = LIBXSMM_DNN_DATATYPE_F32;
    316 
    317     if (!CanUseXsmmConv2D(desc, data_format)) {
    318       return false;
    319     }
    320 
    321     auto input_ptr = input.template flat<float>().data();
    322     auto filter_ptr = filter.template flat<float>().data();
    323     auto output_ptr = output->template flat<float>().data();
    324 
    325     bool success = functor::XsmmFwdConv2D<CPUDevice, float>()(
    326         ctx, desc, input_ptr, filter_ptr, output_ptr);
    327     return success;
    328 
    329 
    330 
    331 
    332 
    333 
    334 
    335 }
    336 */
    337 }  // namespace
    338 }  // namespace tensorflow
    339