Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // Make this file empty (or nearly empty) so that it can be compiled even when
     17 // libxsmm is not available.
     18 
     19 #ifndef TENSORFLOW_USE_LIBXSMM
     20 void dummy_xsmm_conv2d_ensure_file_is_not_empty();
     21 #else
     22 
     23 #define USE_EIGEN_TENSOR
     24 #define EIGEN_USE_THREADS
     25 
     26 #include "tensorflow/core/kernels/xsmm_conv2d.h"
     27 
     28 #include <stdlib.h>
     29 #include <cstring>
     30 
     31 #include "tensorflow/core/framework/op_kernel.h"
     32 #include "tensorflow/core/lib/core/blocking_counter.h"
     33 #include "tensorflow/core/lib/core/threadpool.h"
     34 
     35 #include "libxsmm_main.h"  // TODO(bsteiner): API to avoid incl. header from src/
     36 #include "include/libxsmm_cpuid.h"
     37 #include "include/libxsmm_malloc.h"
     38 
     39 namespace tensorflow {
     40 
     41 // Xsmm*Conv2D are wrappers for libxsmm direct convolutions.
     42 
     43 // Returns true if convolution can be computed efficiently by XsmmConv2D,
     44 // returns false otherwise.
     45 bool CanUseXsmmConv2D(const libxsmm_dnn_conv_desc& desc,
     46                       TensorFormat data_format) {
     47   int VECTOR_SIZE;
     48   int arch = libxsmm_cpuid_x86();
     49 
     50   if (arch == LIBXSMM_X86_AVX512_CORE) {
     51     VECTOR_SIZE = 16;
     52   } else if (arch == LIBXSMM_X86_AVX2) {
     53     VECTOR_SIZE = 8;
     54   } else {
     55     VLOG(1) << "Cannot use XSMM convolutions: unsupported architecture!";
     56     return false;
     57   }
     58 
     59   if (data_format != FORMAT_NHWC) {
     60     VLOG(1) << "Cannot use XSMM convolutions: unsupported format!";
     61     return false;
     62   }
     63   if (desc.K % VECTOR_SIZE != 0) {
     64     VLOG(1) << "Cannot use XSMM convolutions: output features count not"
     65                " divisible by vector size!";
     66     return false;
     67   }
     68   VLOG(2) << "Can use XSMM convolutions.";
     69   return true;
     70 }
     71 
     72 typedef Eigen::ThreadPoolDevice CPUDevice;
     73 
     74 namespace functor {
     75 
     76 static void chk_libxsmm_err(libxsmm_dnn_err_t status, string msg) {
     77   if (status != LIBXSMM_DNN_SUCCESS) {
     78     VLOG(0) << msg << " failed: " << libxsmm_dnn_get_error(status);
     79   }
     80 }
     81 
// Repacks a filter tensor from TensorFlow's RSCK layout (rows, columns,
// input channels, output channels) into libxsmm's blocked custom layout
// [blocksofm][blocksifm][R][S][ifmblock][ofmblock], zero-filling the tail
// lanes when C or K is not a multiple of its block size.
// Only output-feature blocks in [start, end) are written, so the work can
// be sharded across threads by output-block index.
LIBXSMM_INLINE void copy_RSCK_to_custom(const float* rsck, float* kcrs, int R,
                                        int S, int C, int K, int blocksifm,
                                        int blocksofm, int ifmblock,
                                        int ofmblock, int start, int end) {
  // Typed multi-dimensional views over the flat source/destination buffers.
  LIBXSMM_VLA_DECL(4, const float, input, rsck, S, C, K);
  LIBXSMM_VLA_DECL(6, float, output, kcrs, blocksifm, R, S, ifmblock, ofmblock);
  int r, s, k, c, v1, v2;

  for (k = start; k < end; k++) {
    for (c = 0; c < blocksifm; c++) {
      for (r = 0; r < R; r++) {
        for (s = 0; s < S; s++) {
          // Copy the valid (v1, v2) channel lanes of this (k, c) block.
          for (v1 = c * ifmblock; v1 < std::min(C, (c + 1) * ifmblock); v1++) {
            for (v2 = k * ofmblock; v2 < std::min(K, (k + 1) * ofmblock); v2++)
              LIBXSMM_VLA_ACCESS(6, output, k, c, r, s, v1 - c * ifmblock,
                                 v2 - k * ofmblock, blocksifm, R, S, ifmblock,
                                 ofmblock) =
                  LIBXSMM_VLA_ACCESS(4, input, r, s, v1, v2, S, C, K);
            // Zero-pad output-channel lanes past K (only runs in the last
            // K block, where (k + 1) * ofmblock > K).
            for (v2 = K; v2 < (k + 1) * ofmblock; v2++)
              LIBXSMM_VLA_ACCESS(6, output, k, c, r, s, v1 - c * ifmblock,
                                 v2 - k * ofmblock, blocksifm, R, S, ifmblock,
                                 ofmblock) = 0.0f;
          }
          // Zero-pad input-channel lanes past C (only runs in the last
          // C block, where (c + 1) * ifmblock > C).
          for (v1 = C; v1 < (c + 1) * ifmblock; v1++) {
            for (v2 = k * ofmblock; v2 < (k + 1) * ofmblock; v2++)
              LIBXSMM_VLA_ACCESS(6, output, k, c, r, s, v1 - c * ifmblock,
                                 v2 - k * ofmblock, blocksifm, R, S, ifmblock,
                                 ofmblock) = 0.0f;
          }
        }
      }
    }
  }
}
    116 
// Hashable key wrapper around libxsmm_dnn_conv_desc so convolution
// handles can be cached per descriptor in an unordered_map.
class libxsmm_dnn_conv_desc_wrap {
 public:
  const libxsmm_dnn_conv_desc d;

  libxsmm_dnn_conv_desc_wrap(const libxsmm_dnn_conv_desc& d_) : d(d_) {}
  // Equality over the shape (N,C,H,W,K,R,S), stride (u,v) and padding
  // fields only — any other descriptor fields are ignored.
  // NOTE(review): a hash functor paired with this comparator must be
  // computed from this same subset of fields, otherwise keys that compare
  // equal may hash differently, violating std::unordered_map's
  // requirements — confirm the HashFunction used alongside this class.
  bool operator==(const libxsmm_dnn_conv_desc_wrap& w) const {
    return (d.N == w.d.N && d.C == w.d.C && d.H == w.d.H && d.W == w.d.W &&
            d.K == w.d.K && d.R == w.d.R && d.S == w.d.S && d.u == w.d.u &&
            d.v == w.d.v && d.pad_h == w.d.pad_h && d.pad_w == w.d.pad_w);
  }
};
    128 
    129 struct HashFunction {
    130   std::size_t operator()(const libxsmm_dnn_conv_desc_wrap& w) const {
    131     return libxsmm_hash(&w.d, sizeof(w.d), 25071975);
    132   }
    133 };
    134 
    135 class handles {
    136  public:
    137   libxsmm_dnn_layer* find(const libxsmm_dnn_conv_desc_wrap& w) {
    138     std::unordered_map<libxsmm_dnn_conv_desc_wrap, libxsmm_dnn_layer*,
    139                        HashFunction>::iterator i = libxsmm_handles.find(w);
    140     if (i == libxsmm_handles.end()) {
    141       libxsmm_dnn_err_t status;
    142       libxsmm_dnn_layer* libxsmm_handle =
    143           libxsmm_dnn_create_conv_layer(w.d, &status);
    144       chk_libxsmm_err(status, "Create handle");
    145       libxsmm_handles.insert(std::make_pair(w, libxsmm_handle));
    146       return libxsmm_handle;
    147     } else {
    148       return i->second;
    149     }
    150   }
    151   ~handles() {
    152     std::unordered_map<libxsmm_dnn_conv_desc_wrap, libxsmm_dnn_layer*,
    153                        HashFunction>::iterator i;
    154     for (i = libxsmm_handles.begin(); i != libxsmm_handles.end(); i++)
    155       chk_libxsmm_err(libxsmm_dnn_destroy_conv_layer(i->second),
    156                       "Destroy handle");
    157   }
    158 
    159  private:
    160   std::unordered_map<libxsmm_dnn_conv_desc_wrap, libxsmm_dnn_layer*,
    161                      HashFunction>
    162       libxsmm_handles;
    163 };
    164 
    165 static handles libxsmm_handles;
    166 
    167 // #define LIBXSMM_DETAILED_TIMING
    168 
// Shared driver for the forward (FWD), backward-data (BWD) and
// weight-update (UPD) libxsmm convolution paths.
//
// Looks up (or lazily creates) a cached libxsmm handle for `desc`,
// repacks the filter into libxsmm's blocked layout when the pass needs
// it, links and binds all tensors plus scratch, runs the convolution
// across the TensorFlow worker threads, then releases/destroys the
// per-call objects (the handle itself stays cached).
//
// Returns false when libxsmm reports a codegen fallback — the caller
// must then use the generic (non-libxsmm) implementation — and true
// once the libxsmm path has executed.
template <typename InputPtr, typename FilterPtr, typename OutputPtr>
static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
                                   const libxsmm_dnn_conv_desc& desc,
                                   libxsmm_dnn_compute_kind kind,
                                   InputPtr input, FilterPtr filter,
                                   OutputPtr output) {
#if defined(LIBXSMM_DETAILED_TIMING)
  unsigned long long l_tick1, l_tick2, l_tick3, l_tick4, l_tick5, l_tick6,
      l_tick7, l_tick8, l_tick9, l_tick10;
  l_tick1 = libxsmm_timer_tick();
#endif
  // setup scoped allocator, which adopts the allocator from the context
  const libxsmm_tf_allocator<libxsmm_scratch_allocator> tf_allocator(*ctx);
  libxsmm_dnn_err_t status;
  libxsmm_dnn_layer* libxsmm_handle;
  libxsmm_dnn_conv_desc_wrap w(desc);
  void* scratch;

  // Handles are cached per descriptor in the global `libxsmm_handles`.
  // if(kind == LIBXSMM_DNN_COMPUTE_KIND_FWD)
  libxsmm_handle = libxsmm_handles.find(w);
  // else{
  //  libxsmm_handle = libxsmm_dnn_create_conv_layer(desc, &status);
  //  chk_libxsmm_err(status, "Create handle");
  //}

  // Bail out to the generic TensorFlow implementation when libxsmm could
  // not generate specialized code for this shape/kind.
  status = libxsmm_dnn_get_codegen_success(libxsmm_handle, kind);
  if (status == LIBXSMM_DNN_WARN_FALLBACK) {
    return false;  // Use non-libxsmm code
  }
  chk_libxsmm_err(status, "Check codegen status");

  libxsmm_dnn_buffer* libxsmm_input;
  libxsmm_dnn_buffer* libxsmm_output;
  libxsmm_dnn_filter* libxsmm_filter;

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick2 = libxsmm_timer_tick();
#endif

  // Feature-map blocking factors chosen by libxsmm for this handle.
  int ifmblock = (libxsmm_handle->ifmblock);
  int ofmblock = (libxsmm_handle->ofmblock);

  // Number of input/output feature-map blocks (rounded up to cover
  // channel counts that are not multiples of the block size).
  int blocksifm =
      desc.C % ifmblock == 0 ? desc.C / ifmblock : desc.C / ifmblock + 1;
  int blocksofm =
      desc.K % ofmblock == 0 ? desc.K / ofmblock : desc.K / ofmblock + 1;
  // Destination buffer for the filter repacked into libxsmm's blocked
  // layout; 2097152 = 2 MiB alignment.
  float* native_filter =
      (float*)libxsmm_aligned_scratch(blocksofm * blocksifm * desc.R * desc.S *
                                          ifmblock * ofmblock * sizeof(float),
                                      2097152);

  const DeviceBase::CpuWorkerThreads* worker_threads =
      ctx->device()->tensorflow_cpu_worker_threads();

  int num_threads = worker_threads->num_threads;

#if 1
  if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD ||
      kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
    // FWD/BWD consume the filter in libxsmm's custom layout: repack it
    // from RSCK into native_filter, sharded by output-feature block.
    if (blocksofm > num_threads) {
      // NOTE(review): each shard copies work / num_threads blocks, so
      // when blocksofm is not divisible by num_threads the trailing
      // blocksofm % num_threads blocks are copied by no shard — confirm
      // callers guarantee divisibility (or that libxsmm tolerates it).
      int work = blocksofm;
      BlockingCounter count(num_threads);
      for (int i = 0; i < num_threads; ++i) {
        worker_threads->workers->Schedule([=, &count]() {
          int start = work / num_threads * i;
          int end = (start + work / num_threads) > work
                        ? work
                        : start + work / num_threads;
          copy_RSCK_to_custom(filter, native_filter, desc.R, desc.S, desc.C,
                              desc.K, blocksifm, blocksofm, ifmblock, ofmblock,
                              start, end);
          count.DecrementCount();
        });
      }
      count.Wait();
    } else {
      // Fewer blocks than threads: schedule exactly one block per task.
      int work = blocksofm;
      int num_threads = work;  // intentionally shadows the outer count

      BlockingCounter count(num_threads);
      for (int i = 0; i < num_threads; ++i) {
        worker_threads->workers->Schedule([=, &count]() {
          int start = i;
          int end = i + 1;
          copy_RSCK_to_custom(filter, native_filter, desc.R, desc.S, desc.C,
                              desc.K, blocksifm, blocksofm, ifmblock, ofmblock,
                              start, end);
          count.DecrementCount();
        });
      }
      count.Wait();
    }
  } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
    // Added: for weight update
    libxsmm_filter =
        libxsmm_dnn_link_filter(libxsmm_handle, LIBXSMM_DNN_FILTER, filter,
                                LIBXSMM_DNN_TENSOR_FORMAT_RSCK_PTR, &status);
    chk_libxsmm_err(status,
                    "Link filter");  // weight update is in RSCK as
                                     // filter should be returned in RSCK
                                     // format
  }
#else
  memset(native_filter, 0,
         blocksofm * blocksifm * desc.R * desc.S * ifmblock * ofmblock *
             sizeof(float));
#endif

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick3 = libxsmm_timer_tick();
#endif

  // Wrap the caller's NHWC input/output pointers in libxsmm buffer
  // handles (no copy; libxsmm reads/writes the caller's memory).
  libxsmm_input =
      libxsmm_dnn_link_buffer(libxsmm_handle, LIBXSMM_DNN_INPUT, input,
                              LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR, &status);
  chk_libxsmm_err(status, "Link input buffer");
  libxsmm_output =
      libxsmm_dnn_link_buffer(libxsmm_handle, LIBXSMM_DNN_OUTPUT, output,
                              LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR, &status);
  chk_libxsmm_err(status, "Link output buffer");
  if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD ||
      kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
    // FWD/BWD consume the blocked filter repacked above.
    libxsmm_filter = libxsmm_dnn_link_filter(
        libxsmm_handle, LIBXSMM_DNN_FILTER, native_filter,
        LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM_PTR, &status);
    chk_libxsmm_err(status, "Link filter");
  }
  // Bind each tensor under the role it plays in the requested pass.
  if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD) {
    chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input,
                                            LIBXSMM_DNN_REGULAR_INPUT),
                    "Bind input forward");
    chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_output,
                                            LIBXSMM_DNN_REGULAR_OUTPUT),
                    "Bind output forward");
    chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter,
                                            LIBXSMM_DNN_REGULAR_FILTER),
                    "Bind filter forward");
  } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
    // BWD writes the input gradient, so clear it first.
    chk_libxsmm_err(libxsmm_dnn_zero_buffer(libxsmm_input), "Zero input");

    chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input,
                                            LIBXSMM_DNN_GRADIENT_INPUT),
                    "Bind input backward");
    chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_output,
                                            LIBXSMM_DNN_GRADIENT_OUTPUT),
                    "Bind output backward");
    chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter,
                                            LIBXSMM_DNN_REGULAR_FILTER),
                    "Bind filter backward");
  } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
    // UPD writes the filter gradient, so clear it first.
    chk_libxsmm_err(libxsmm_dnn_zero_filter(libxsmm_filter), "Zero filter");

    chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input,
                                            LIBXSMM_DNN_REGULAR_INPUT),
                    "Bind input weight update");
    chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_output,
                                            LIBXSMM_DNN_GRADIENT_OUTPUT),
                    "Bind output weight update");
    chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter,
                                            LIBXSMM_DNN_GRADIENT_FILTER),
                    "Bind filter weight update");
  } else {
    /* shouldn't happen */
  }

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick4 = libxsmm_timer_tick();
#endif

  /* bind scratch */
  scratch = (void*)libxsmm_aligned_scratch(
      libxsmm_dnn_get_scratch_size(libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL,
                                   &status),
      2097152);
  chk_libxsmm_err(status, "scratch allocation");
  chk_libxsmm_err(libxsmm_dnn_bind_scratch(
                      libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, scratch),
                  "binding scratch");

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick5 = libxsmm_timer_tick();
#endif

  if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
    // Backward-data uses a transposed filter inside the handle.
    libxsmm_dnn_transpose_filter(libxsmm_handle, LIBXSMM_DNN_FILTER);
  }

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick6 = libxsmm_timer_tick();
#endif

  // Execute across the worker pool; each task passes its index i as the
  // thread id to libxsmm_dnn_execute_st.
  BlockingCounter counter(num_threads);

  for (int i = 0; i < num_threads; ++i) {
    worker_threads->workers->Schedule([=, &counter]() {
      chk_libxsmm_err(libxsmm_dnn_execute_st(libxsmm_handle, kind, 0, i),
                      "Worker");
      counter.DecrementCount();
    });
  }
  counter.Wait();

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick7 = libxsmm_timer_tick();
#endif

  if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
    // Combine the per-thread partial weight-gradient accumulations.
    libxsmm_dnn_reduce_wu_filters(libxsmm_handle, LIBXSMM_DNN_GRADIENT_FILTER);
  }

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick8 = libxsmm_timer_tick();
#endif

  /* clean up */
  // Unbind everything from the (cached, reused) handle; roles mirror the
  // bind section above.
  chk_libxsmm_err(
      libxsmm_dnn_release_scratch(libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL),
      "release scratch");
  if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD) {
    chk_libxsmm_err(
        libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_REGULAR_INPUT),
        "release input");
    chk_libxsmm_err(
        libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_REGULAR_OUTPUT),
        "release output");
    chk_libxsmm_err(
        libxsmm_dnn_release_filter(libxsmm_handle, LIBXSMM_DNN_REGULAR_FILTER),
        "release filter");
  } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
    chk_libxsmm_err(
        libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_GRADIENT_INPUT),
        "release input");
    chk_libxsmm_err(
        libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_GRADIENT_OUTPUT),
        "release output");
    chk_libxsmm_err(
        libxsmm_dnn_release_filter(libxsmm_handle, LIBXSMM_DNN_REGULAR_FILTER),
        "release filter");
  } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
    chk_libxsmm_err(
        libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_REGULAR_INPUT),
        "release input");
    chk_libxsmm_err(
        libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_GRADIENT_OUTPUT),
        "release output");
    chk_libxsmm_err(
        libxsmm_dnn_release_filter(libxsmm_handle, LIBXSMM_DNN_GRADIENT_FILTER),
        "release filter");
  } else {
    /* shouldn't happen */
  }
  // Destroy the per-call wrapper objects (not the underlying user data).
  chk_libxsmm_err(libxsmm_dnn_destroy_buffer(libxsmm_input), "Destroy input");
  chk_libxsmm_err(libxsmm_dnn_destroy_buffer(libxsmm_output), "Destroy output");
  chk_libxsmm_err(libxsmm_dnn_destroy_filter(libxsmm_filter), "Destroy filter");

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick9 = libxsmm_timer_tick();
#endif

  // if(kind != LIBXSMM_DNN_COMPUTE_KIND_FWD)
  // chk_libxsmm_err(libxsmm_dnn_destroy_conv_layer(libxsmm_handle),
  //               "Destroy handle");

  libxsmm_free(native_filter);
  libxsmm_free(scratch);

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick10 = libxsmm_timer_tick();
  printf(
      "time for convolution (%i, %i, %i, %i, %i): %f, %f, %f, %f, %f, %f, %f, "
      "%f, %f, %f\n",
      desc.N, desc.C, desc.K, desc.R, desc.S,
      libxsmm_timer_duration(l_tick1, l_tick2),
      libxsmm_timer_duration(l_tick2, l_tick3),
      libxsmm_timer_duration(l_tick3, l_tick4),
      libxsmm_timer_duration(l_tick4, l_tick5),
      libxsmm_timer_duration(l_tick5, l_tick6),
      libxsmm_timer_duration(l_tick6, l_tick7),
      libxsmm_timer_duration(l_tick7, l_tick8),
      libxsmm_timer_duration(l_tick8, l_tick9),
      libxsmm_timer_duration(l_tick9, l_tick10),
      libxsmm_timer_duration(l_tick1, l_tick10));
#endif

  return true;  // Succeeded
}
    455 
// CPU specialization of the forward-convolution functor: dispatches to
// the shared libxsmm driver with the FWD compute kind.
// Returns false if libxsmm falls back, so the caller can run the default
// (non-libxsmm) implementation instead.
template <typename T>
struct XsmmFwdConv2D<CPUDevice, T> {
  bool operator()(OpKernelContext* ctx, const libxsmm_dnn_conv_desc& desc,
                  const T* input, const T* filter, T* output) {
    return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_COMPUTE_KIND_FWD,
                                  input, filter, output);
  }
};
    464 
// CPU specialization of the backward-data functor: computes the input
// gradient (written to `input`) via the shared libxsmm driver with the
// BWD compute kind. Returns false on libxsmm fallback.
template <typename T>
struct XsmmBkwInputConv2D<CPUDevice, T> {
  bool operator()(OpKernelContext* ctx, const libxsmm_dnn_conv_desc& desc,
                  T* input, const T* filter, const T* output) {
    return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_COMPUTE_KIND_BWD,
                                  input, filter, output);
  }
};
    473 
// CPU specialization of the weight-update functor: computes the filter
// gradient (written to `filter`) via the shared libxsmm driver with the
// UPD compute kind. Returns false on libxsmm fallback.
template <typename T>
struct XsmmBkwFilterConv2D<CPUDevice, T> {
  bool operator()(OpKernelContext* ctx, const libxsmm_dnn_conv_desc& desc,
                  const T* input, T* filter, const T* output) {
    return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_COMPUTE_KIND_UPD,
                                  input, filter, output);
  }
};
    482 
    483 }  // namespace functor
    484 
    485 template struct functor::XsmmFwdConv2D<CPUDevice, float>;
    486 template struct functor::XsmmBkwInputConv2D<CPUDevice, float>;
    487 template struct functor::XsmmBkwFilterConv2D<CPUDevice, float>;
    488 
    489 }  // namespace tensorflow
    490 
    491 #endif  // TENSORFLOW_USE_LIBXSMM
    492