/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Make this file empty (or nearly empty) so that it can be compiled even when
// libxsmm is not available.

#ifndef TENSORFLOW_USE_LIBXSMM
void dummy_xsmm_conv2d_ensure_file_is_not_empty();
#else

#define USE_EIGEN_TENSOR
#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/xsmm_conv2d.h"

#include <stdlib.h>
#include <algorithm>
#include <cstring>
#include <unordered_map>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/threadpool.h"

#include "libxsmm_main.h"  // TODO(bsteiner): API to avoid incl. header from src/
#include "include/libxsmm_cpuid.h"
#include "include/libxsmm_malloc.h"

namespace tensorflow {

// Xsmm*Conv2D are wrappers for libxsmm direct convolutions.

// Returns true if the convolution can be computed efficiently by XsmmConv2D,
// and false otherwise.
bool CanUseXsmmConv2D(const libxsmm_dnn_conv_desc& desc,
                      TensorFormat data_format) {
  int VECTOR_SIZE;
  int arch = libxsmm_cpuid_x86();

  if (arch == LIBXSMM_X86_AVX512_CORE) {
    VECTOR_SIZE = 16;
  } else if (arch == LIBXSMM_X86_AVX2) {
    VECTOR_SIZE = 8;
  } else {
    VLOG(1) << "Cannot use XSMM convolutions: unsupported architecture!";
    return false;
  }

  if (data_format != FORMAT_NHWC) {
    VLOG(1) << "Cannot use XSMM convolutions: unsupported format!";
    return false;
  }
  if (desc.K % VECTOR_SIZE != 0) {
    VLOG(1) << "Cannot use XSMM convolutions: output feature count not"
               " divisible by vector size!";
    return false;
  }
  VLOG(2) << "Can use XSMM convolutions.";
  return true;
}
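// Illustrative usage sketch (not part of the original kernel): a caller would
// fill in a libxsmm_dnn_conv_desc and consult CanUseXsmmConv2D before taking
// the libxsmm path. Only the fields compared by libxsmm_dnn_conv_desc_wrap
// further below are shown; any remaining descriptor fields required by the
// libxsmm API are omitted here and would have to be set as that API expects.
//
//   libxsmm_dnn_conv_desc desc;
//   desc.N = 32;                     // minibatch size
//   desc.C = 64;                     // input feature maps
//   desc.H = 56;    desc.W = 56;     // input spatial dims
//   desc.K = 128;                    // output feature maps (divisible by VECTOR_SIZE)
//   desc.R = 3;     desc.S = 3;      // filter spatial dims
//   desc.u = 1;     desc.v = 1;      // strides
//   desc.pad_h = 1; desc.pad_w = 1;  // padding
//   if (CanUseXsmmConv2D(desc, FORMAT_NHWC)) {
//     // Dispatch to functor::XsmmFwdConv2D<CPUDevice, float>.
//   }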
typedef Eigen::ThreadPoolDevice CPUDevice;

namespace functor {

static void chk_libxsmm_err(libxsmm_dnn_err_t status, string msg) {
  if (status != LIBXSMM_DNN_SUCCESS) {
    VLOG(0) << msg << " failed: " << libxsmm_dnn_get_error(status);
  }
}

// Repacks an RSCK (rows, columns, input channels, output channels) filter into
// libxsmm's blocked custom layout, zero-padding any partial channel blocks.
LIBXSMM_INLINE void copy_RSCK_to_custom(const float* rsck, float* kcrs, int R,
                                        int S, int C, int K, int blocksifm,
                                        int blocksofm, int ifmblock,
                                        int ofmblock, int start, int end) {
  LIBXSMM_VLA_DECL(4, const float, input, rsck, S, C, K);
  LIBXSMM_VLA_DECL(6, float, output, kcrs, blocksifm, R, S, ifmblock, ofmblock);
  int r, s, k, c, v1, v2;

  for (k = start; k < end; k++) {
    for (c = 0; c < blocksifm; c++) {
      for (r = 0; r < R; r++) {
        for (s = 0; s < S; s++) {
          for (v1 = c * ifmblock; v1 < std::min(C, (c + 1) * ifmblock); v1++) {
            for (v2 = k * ofmblock; v2 < std::min(K, (k + 1) * ofmblock); v2++)
              LIBXSMM_VLA_ACCESS(6, output, k, c, r, s, v1 - c * ifmblock,
                                 v2 - k * ofmblock, blocksifm, R, S, ifmblock,
                                 ofmblock) =
                  LIBXSMM_VLA_ACCESS(4, input, r, s, v1, v2, S, C, K);
            for (v2 = K; v2 < (k + 1) * ofmblock; v2++)
              LIBXSMM_VLA_ACCESS(6, output, k, c, r, s, v1 - c * ifmblock,
                                 v2 - k * ofmblock, blocksifm, R, S, ifmblock,
                                 ofmblock) = 0.0f;
          }
          for (v1 = C; v1 < (c + 1) * ifmblock; v1++) {
            for (v2 = k * ofmblock; v2 < (k + 1) * ofmblock; v2++)
              LIBXSMM_VLA_ACCESS(6, output, k, c, r, s, v1 - c * ifmblock,
                                 v2 - k * ofmblock, blocksifm, R, S, ifmblock,
                                 ofmblock) = 0.0f;
          }
        }
      }
    }
  }
}

// Wraps a libxsmm_dnn_conv_desc so that it can be used as a hash-map key.
class libxsmm_dnn_conv_desc_wrap {
 public:
  const libxsmm_dnn_conv_desc d;

  libxsmm_dnn_conv_desc_wrap(const libxsmm_dnn_conv_desc& d_) : d(d_) {}
  bool operator==(const libxsmm_dnn_conv_desc_wrap& w) const {
    return (d.N == w.d.N && d.C == w.d.C && d.H == w.d.H && d.W == w.d.W &&
            d.K == w.d.K && d.R == w.d.R && d.S == w.d.S && d.u == w.d.u &&
            d.v == w.d.v && d.pad_h == w.d.pad_h && d.pad_w == w.d.pad_w);
  }
};

struct HashFunction {
  std::size_t operator()(const libxsmm_dnn_conv_desc_wrap& w) const {
    return libxsmm_hash(&w.d, sizeof(w.d), 25071975);
  }
};

// Caches libxsmm convolution handles keyed by their descriptor so that
// repeated calls with the same shape reuse the generated code.
class handles {
 public:
  libxsmm_dnn_layer* find(const libxsmm_dnn_conv_desc_wrap& w) {
    std::unordered_map<libxsmm_dnn_conv_desc_wrap, libxsmm_dnn_layer*,
                       HashFunction>::iterator i = libxsmm_handles.find(w);
    if (i == libxsmm_handles.end()) {
      libxsmm_dnn_err_t status;
      libxsmm_dnn_layer* libxsmm_handle =
          libxsmm_dnn_create_conv_layer(w.d, &status);
      chk_libxsmm_err(status, "Create handle");
      libxsmm_handles.insert(std::make_pair(w, libxsmm_handle));
      return libxsmm_handle;
    } else {
      return i->second;
    }
  }
  ~handles() {
    std::unordered_map<libxsmm_dnn_conv_desc_wrap, libxsmm_dnn_layer*,
                       HashFunction>::iterator i;
    for (i = libxsmm_handles.begin(); i != libxsmm_handles.end(); i++)
      chk_libxsmm_err(libxsmm_dnn_destroy_conv_layer(i->second),
                      "Destroy handle");
  }

 private:
  std::unordered_map<libxsmm_dnn_conv_desc_wrap, libxsmm_dnn_layer*,
                     HashFunction>
      libxsmm_handles;
};

static handles libxsmm_handles;
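// Worked example (illustrative, not part of the original kernel): the blocked
// filter buffer filled by copy_RSCK_to_custom holds
//   blocksofm * blocksifm * R * S * ifmblock * ofmblock
// floats, where blocksifm/blocksofm are the channel counts rounded up to whole
// blocks (see CallLibxsmmConvGeneric below). For C = 67 input channels and
// ifmblock = 16:
//   blocksifm = 67 % 16 == 0 ? 67 / 16 : 67 / 16 + 1;  // == 5
// so the last input-channel block carries channels 64..66 and its remaining
// 13 lanes are zero-filled by the copy routine above.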
// #define LIBXSMM_DETAILED_TIMING

template <typename InputPtr, typename FilterPtr, typename OutputPtr>
static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
                                   const libxsmm_dnn_conv_desc& desc,
                                   libxsmm_dnn_compute_kind kind,
                                   InputPtr input, FilterPtr filter,
                                   OutputPtr output) {
#if defined(LIBXSMM_DETAILED_TIMING)
  unsigned long long l_tick1, l_tick2, l_tick3, l_tick4, l_tick5, l_tick6,
      l_tick7, l_tick8, l_tick9, l_tick10;
  l_tick1 = libxsmm_timer_tick();
#endif
  // Set up a scoped allocator that adopts the allocator from the context.
  const libxsmm_tf_allocator<libxsmm_scratch_allocator> tf_allocator(*ctx);
  libxsmm_dnn_err_t status;
  libxsmm_dnn_layer* libxsmm_handle;
  libxsmm_dnn_conv_desc_wrap w(desc);
  void* scratch;

  // if(kind == LIBXSMM_DNN_COMPUTE_KIND_FWD)
  libxsmm_handle = libxsmm_handles.find(w);
  // else{
  //   libxsmm_handle = libxsmm_dnn_create_conv_layer(desc, &status);
  //   chk_libxsmm_err(status, "Create handle");
  // }

  status = libxsmm_dnn_get_codegen_success(libxsmm_handle, kind);
  if (status == LIBXSMM_DNN_WARN_FALLBACK) {
    return false;  // Fall back to the non-libxsmm code path.
  }
  chk_libxsmm_err(status, "Check codegen status");

  libxsmm_dnn_buffer* libxsmm_input;
  libxsmm_dnn_buffer* libxsmm_output;
  libxsmm_dnn_filter* libxsmm_filter;

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick2 = libxsmm_timer_tick();
#endif

  int ifmblock = (libxsmm_handle->ifmblock);
  int ofmblock = (libxsmm_handle->ofmblock);

  // Round the channel counts up to whole blocks.
  int blocksifm =
      desc.C % ifmblock == 0 ? desc.C / ifmblock : desc.C / ifmblock + 1;
  int blocksofm =
      desc.K % ofmblock == 0 ? desc.K / ofmblock : desc.K / ofmblock + 1;
  float* native_filter =
      (float*)libxsmm_aligned_scratch(blocksofm * blocksifm * desc.R * desc.S *
                                          ifmblock * ofmblock * sizeof(float),
                                      2097152);

  const DeviceBase::CpuWorkerThreads* worker_threads =
      ctx->device()->tensorflow_cpu_worker_threads();

  int num_threads = worker_threads->num_threads;

#if 1
  if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD ||
      kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
    if (blocksofm > num_threads) {
      int work = blocksofm;
      BlockingCounter count(num_threads);
      for (int i = 0; i < num_threads; ++i) {
        worker_threads->workers->Schedule([=, &count]() {
          int start = work / num_threads * i;
          int end = (start + work / num_threads) > work
                        ? work
                        : start + work / num_threads;
          copy_RSCK_to_custom(filter, native_filter, desc.R, desc.S, desc.C,
                              desc.K, blocksifm, blocksofm, ifmblock, ofmblock,
                              start, end);
          count.DecrementCount();
        });
      }
      count.Wait();
    } else {
      int work = blocksofm;
      int num_threads = work;

      BlockingCounter count(num_threads);
      for (int i = 0; i < num_threads; ++i) {
        worker_threads->workers->Schedule([=, &count]() {
          int start = i;
          int end = i + 1;
          copy_RSCK_to_custom(filter, native_filter, desc.R, desc.S, desc.C,
                              desc.K, blocksifm, blocksofm, ifmblock, ofmblock,
                              start, end);
          count.DecrementCount();
        });
      }
      count.Wait();
    }
  } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
    // Added: for weight update.
    libxsmm_filter =
        libxsmm_dnn_link_filter(libxsmm_handle, LIBXSMM_DNN_FILTER, filter,
                                LIBXSMM_DNN_TENSOR_FORMAT_RSCK_PTR, &status);
    chk_libxsmm_err(status,
                    "Link filter");  // The weight-update gradient stays in
                                     // RSCK, since the filter should be
                                     // returned in RSCK format.
  }
#else
  memset(native_filter, 0,
         blocksofm * blocksifm * desc.R * desc.S * ifmblock * ofmblock *
             sizeof(float));
#endif

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick3 = libxsmm_timer_tick();
#endif

  libxsmm_input =
      libxsmm_dnn_link_buffer(libxsmm_handle, LIBXSMM_DNN_INPUT, input,
                              LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR, &status);
  chk_libxsmm_err(status, "Link input buffer");
  libxsmm_output =
      libxsmm_dnn_link_buffer(libxsmm_handle, LIBXSMM_DNN_OUTPUT, output,
                              LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR, &status);
  chk_libxsmm_err(status, "Link output buffer");
  if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD ||
      kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
    libxsmm_filter = libxsmm_dnn_link_filter(
        libxsmm_handle, LIBXSMM_DNN_FILTER, native_filter,
        LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM_PTR, &status);
    chk_libxsmm_err(status, "Link filter");
  }
  if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD) {
    chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input,
                                            LIBXSMM_DNN_REGULAR_INPUT),
                    "Bind input forward");
    chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_output,
                                            LIBXSMM_DNN_REGULAR_OUTPUT),
                    "Bind output forward");
    chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter,
                                            LIBXSMM_DNN_REGULAR_FILTER),
                    "Bind filter forward");
  } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
    chk_libxsmm_err(libxsmm_dnn_zero_buffer(libxsmm_input), "Zero input");

    chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input,
                                            LIBXSMM_DNN_GRADIENT_INPUT),
                    "Bind input backward");
    chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_output,
                                            LIBXSMM_DNN_GRADIENT_OUTPUT),
                    "Bind output backward");
    chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter,
                                            LIBXSMM_DNN_REGULAR_FILTER),
                    "Bind filter backward");
  } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
    chk_libxsmm_err(libxsmm_dnn_zero_filter(libxsmm_filter), "Zero filter");

    chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input,
                                            LIBXSMM_DNN_REGULAR_INPUT),
                    "Bind input weight update");
    chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_output,
                                            LIBXSMM_DNN_GRADIENT_OUTPUT),
                    "Bind output weight update");
    chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter,
                                            LIBXSMM_DNN_GRADIENT_FILTER),
                    "Bind filter weight update");
  } else {
    /* shouldn't happen */
  }

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick4 = libxsmm_timer_tick();
#endif

  /* bind scratch */
  scratch = (void*)libxsmm_aligned_scratch(
      libxsmm_dnn_get_scratch_size(libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL,
                                   &status),
      2097152);
  chk_libxsmm_err(status, "scratch allocation");
  chk_libxsmm_err(libxsmm_dnn_bind_scratch(
                      libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, scratch),
                  "binding scratch");

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick5 = libxsmm_timer_tick();
#endif

  if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
    libxsmm_dnn_transpose_filter(libxsmm_handle, LIBXSMM_DNN_FILTER);
  }

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick6 = libxsmm_timer_tick();
#endif

  BlockingCounter counter(num_threads);

  for (int i = 0; i < num_threads; ++i) {
    worker_threads->workers->Schedule([=, &counter]() {
      chk_libxsmm_err(libxsmm_dnn_execute_st(libxsmm_handle, kind, 0, i),
                      "Worker");
      counter.DecrementCount();
    });
  }
  counter.Wait();

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick7 = libxsmm_timer_tick();
#endif

  if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
    libxsmm_dnn_reduce_wu_filters(libxsmm_handle, LIBXSMM_DNN_GRADIENT_FILTER);
  }

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick8 = libxsmm_timer_tick();
#endif

  /* clean up */
  chk_libxsmm_err(
      libxsmm_dnn_release_scratch(libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL),
      "release scratch");
  if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD) {
    chk_libxsmm_err(
        libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_REGULAR_INPUT),
        "release input");
    chk_libxsmm_err(
        libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_REGULAR_OUTPUT),
        "release output");
    chk_libxsmm_err(
        libxsmm_dnn_release_filter(libxsmm_handle, LIBXSMM_DNN_REGULAR_FILTER),
        "release filter");
  } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
    chk_libxsmm_err(
        libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_GRADIENT_INPUT),
        "release input");
    chk_libxsmm_err(
        libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_GRADIENT_OUTPUT),
        "release output");
    chk_libxsmm_err(
        libxsmm_dnn_release_filter(libxsmm_handle, LIBXSMM_DNN_REGULAR_FILTER),
        "release filter");
  } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
    chk_libxsmm_err(
        libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_REGULAR_INPUT),
        "release input");
    chk_libxsmm_err(
        libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_GRADIENT_OUTPUT),
        "release output");
    chk_libxsmm_err(
        libxsmm_dnn_release_filter(libxsmm_handle, LIBXSMM_DNN_GRADIENT_FILTER),
        "release filter");
  } else {
    /* shouldn't happen */
  }
  chk_libxsmm_err(libxsmm_dnn_destroy_buffer(libxsmm_input), "Destroy input");
  chk_libxsmm_err(libxsmm_dnn_destroy_buffer(libxsmm_output), "Destroy output");
  chk_libxsmm_err(libxsmm_dnn_destroy_filter(libxsmm_filter), "Destroy filter");

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick9 = libxsmm_timer_tick();
#endif

  // if(kind != LIBXSMM_DNN_COMPUTE_KIND_FWD)
  //   chk_libxsmm_err(libxsmm_dnn_destroy_conv_layer(libxsmm_handle),
  //                   "Destroy handle");

  libxsmm_free(native_filter);
  libxsmm_free(scratch);

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick10 = libxsmm_timer_tick();
  printf(
      "time for convolution (%i, %i, %i, %i, %i): %f, %f, %f, %f, %f, %f, %f, "
      "%f, %f, %f\n",
      desc.N, desc.C, desc.K, desc.R, desc.S,
      libxsmm_timer_duration(l_tick1, l_tick2),
      libxsmm_timer_duration(l_tick2, l_tick3),
      libxsmm_timer_duration(l_tick3, l_tick4),
      libxsmm_timer_duration(l_tick4, l_tick5),
      libxsmm_timer_duration(l_tick5, l_tick6),
      libxsmm_timer_duration(l_tick6, l_tick7),
      libxsmm_timer_duration(l_tick7, l_tick8),
      libxsmm_timer_duration(l_tick8, l_tick9),
      libxsmm_timer_duration(l_tick9, l_tick10),
      libxsmm_timer_duration(l_tick1, l_tick10));
#endif

  return true;  // Succeeded
}

template <typename T>
struct XsmmFwdConv2D<CPUDevice, T> {
  bool operator()(OpKernelContext* ctx, const libxsmm_dnn_conv_desc& desc,
                  const T* input, const T* filter, T* output) {
    return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_COMPUTE_KIND_FWD,
                                  input, filter, output);
  }
};

template <typename T>
struct XsmmBkwInputConv2D<CPUDevice, T> {
  bool operator()(OpKernelContext* ctx, const libxsmm_dnn_conv_desc& desc,
                  T* input, const T* filter, const T* output) {
    return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_COMPUTE_KIND_BWD,
                                  input, filter, output);
  }
};

template <typename T>
struct XsmmBkwFilterConv2D<CPUDevice, T> {
  bool operator()(OpKernelContext* ctx, const libxsmm_dnn_conv_desc& desc,
                  const T* input, T* filter, const T* output) {
    return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_COMPUTE_KIND_UPD,
                                  input, filter, output);
  }
};

}  // namespace functor

template struct functor::XsmmFwdConv2D<CPUDevice, float>;
template struct functor::XsmmBkwInputConv2D<CPUDevice, float>;
template struct functor::XsmmBkwFilterConv2D<CPUDevice, float>;

}  // namespace tensorflow

#endif  // TENSORFLOW_USE_LIBXSMM