/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/sparse_matmul_op.h"

#include <map>
#include <memory>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/bfloat16.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#ifdef TENSORFLOW_USE_LIBXSMM
#include "include/libxsmm_intrinsics_x86.h"
#include "include/libxsmm_malloc.h"
#include "include/libxsmm_spmdm.h"
#endif

namespace tensorflow {
namespace {

using Eigen::operator==;

template <typename T>
using BasicMatrix = Eigen::Tensor<T, 2, Eigen::RowMajor>;

template <typename T>
using BasicMatrixMap =
    Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Aligned>;

using Matrix = BasicMatrix<float>;
using MatrixMap = BasicMatrixMap<float>;
using CPUDevice = Eigen::ThreadPoolDevice;
using DSizes = Eigen::DSizes<Eigen::DenseIndex, 2>;

// Two commonly used static dsizes. We use Eigen::type2index to allow as much
// compile time optimization as possible.
#ifdef EIGEN_HAS_INDEX_LIST
inline Eigen::IndexList<Eigen::type2index<0>, Eigen::type2index<0>>
dsizes_00() {
  return Eigen::IndexList<Eigen::type2index<0>, Eigen::type2index<0>>();
}
inline Eigen::IndexList<Eigen::type2index<1>, Eigen::type2index<0>>
dsizes_10() {
  return Eigen::IndexList<Eigen::type2index<1>, Eigen::type2index<0>>();
}
#else
inline DSizes dsizes_00() { return DSizes(0, 0); }
inline DSizes dsizes_10() { return DSizes(1, 0); }
#endif

// Blocksizes
// TODO(agarwal): compute these sizes based on cache sizes.
const int K = 64;
const int M = 64;
const int N = 128;

// This stores a sparse representation of a slice of a matrix with size
// (num_rows, num_cols). The slice is represented as a series of blocks of size
// (num_rows, b), where b = block_size for all but the last block, which may
// have fewer columns.
//
// num_rows and block_size are assumed to be <= 256. This allows storing
// different indices as uint8.
//
// For each block, we store all the non zero entries in data/data3 vector and
// the corresponding coordinates of the element in index/index3 vectors. index3
// vector stores index of 3 elements in the same row so that these elements can
// share the same row coordinate. Each entry in Index3 corresponds to 3 entries
// in data3.
//
// Note that all the data/indices of all the blocks are stored in the same
// vectors respectively. To identify block boundaries, we store the block
// offsets using index3_offset/index_offset. If there are n blocks in the slice,
// index3_offset and index_offset have n entries. The indices for the ith block
// are the values in the following range:
// [index3[index3_offset[i-1]], index3[index3_offset[i]]). Similarly for
// index_offset.
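//
// Worked illustration (added for clarity; not part of the original comment):
// suppose block_size = 8 and row 5 of a block has non-zeros at columns 1, 4,
// 6 and 7. The first three non-zeros are packed as one Index3 entry
// {m = 5, k1 = 1, k2 = 4, k3 = 6} with their values appended to data3; the
// leftover fourth non-zero becomes an Index entry {m = 5, k = 7} with its
// value appended to data. After each block, index3_offset and index_offset
// record the running sizes of index3 and index, which is how block boundaries
// are later recovered.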
template <typename T>
struct SparseSlice {
  using ConstMatrixMap = BasicMatrixMap<const T>;

 public:
  // Indices of three elements on the same row.
  struct Index3 {
    uint8 m;  // row
    // columns
    uint8 k1;
    uint8 k2;
    uint8 k3;
  };

  // Index of one element.
  struct Index {
    uint8 m;
    uint8 k;
  };

  SparseSlice(int nrows, int ncols, int bsize)
      : num_rows(nrows), num_cols(ncols), block_size(bsize) {
    DCHECK_LE(nrows, 256);
    DCHECK_LE(block_size, 256);
  }

  // Initializes the slice with data starting at mat(0, col_offset) and with
  // size (num_rows, num_cols).
  // If Transpose is true, implicitly transposes mat.
  template <bool Transpose = false>
  void Initialize(const ConstMatrixMap& mat, int col_offset);

  void Clear();

  // See comments above.
  std::vector<int> index3_offset;
  std::vector<Index3> index3;
  std::vector<T> data3;

  // See comments above. Similar to "index3" except that each element in
  // "index" corresponds to one element in data.
  std::vector<int> index_offset;
  std::vector<Index> index;
  std::vector<T> data;

  // Number of rows and columns for the slice.
  const int num_rows;
  const int num_cols;

  // Block size used to initialize from a matrix.
  const int block_size;
};
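
// Minimal usage sketch (added; it mirrors what CreateSparseSlices below does
// for each grid cell, and the sizes and "data_ptr" are made up for
// illustration only):
//
//   SparseSlice<float>::ConstMatrixMap mat(data_ptr, DSizes(64, 1024));
//   SparseSlice<float> slice(/*nrows=*/64, /*ncols=*/256, /*bsize=*/K);
//   slice.Initialize</*Transpose=*/false>(mat, /*col_offset=*/256);
//   // ... multiply with the slice (e.g. via GEPP below) ...
//   slice.Clear();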

template <typename T>
template <bool Transpose>
void SparseSlice<T>::Initialize(
    const typename SparseSlice<T>::ConstMatrixMap& mat, int col_offset) {
  const int mat_rows = Transpose ? mat.dimension(1) : mat.dimension(0);
  const int mat_cols = Transpose ? mat.dimension(0) : mat.dimension(1);
  DCHECK_LE(num_rows, mat_rows);
  DCHECK_LE(num_cols + col_offset, mat_cols);

  int num_blocks = (num_cols + block_size - 1) / block_size;
  int mat_size = num_rows * num_cols;

  index3_offset.reserve(num_blocks);
  data3.reserve(mat_size);
  index3.reserve(mat_size / 3);

  index_offset.reserve(num_blocks);
  data.reserve(num_blocks * num_rows * 2);
  index.reserve(num_blocks * num_rows * 2);

  Index3 idx3;
  Index idx;
  int data3_size = 0;
  static const T zero(0);
  for (int i = 0; i < num_blocks; ++i) {
    int num_block_cols = std::min(block_size, num_cols - block_size * i);
    for (int row = 0; row < num_rows; ++row) {
      idx3.m = static_cast<uint8>(row);
      // Safety note: The following code has a race, since it checks whether
      // *curr is nonzero and then reads it again on use. However, the result
      // of the race is only that some of the "nonzeros" in the resulting
      // sparse representation may actually be zero, which is harmless.
      const auto* start =
          Transpose ? &mat(col_offset, row) : &mat(row, col_offset);
      const auto* curr = start;
      const int stride = Transpose ? mat.dimension(1) : 1;
      const auto* end = start + stride * num_block_cols;
      uint8 k = 0;
#define NEXT_ELEM \
  curr += stride; \
  ++k;
      while (true) {
        while (curr < end && (*curr == zero)) {
          NEXT_ELEM;
        }
        if (curr >= end) break;
        idx3.k1 = k;
        data3.push_back(*curr);
        NEXT_ELEM;

        while (curr < end && (*curr == zero)) {
          NEXT_ELEM;
        }
        if (curr >= end) break;
        idx3.k2 = k;
        data3.push_back(*curr);
        NEXT_ELEM;

        while (curr < end && (*curr == zero)) {
          NEXT_ELEM;
        }
        if (curr >= end) break;
        idx3.k3 = k;
        data3.push_back(*curr);
        NEXT_ELEM;
        index3.push_back(idx3);
#undef NEXT_ELEM
      }
      int num_inserted_mod = data3.size() % 3;
      // Move some elements to index and data if needed.
      data3_size = data3.size() - num_inserted_mod;
      idx.m = idx3.m;
      switch (num_inserted_mod) {
        case 2:
          idx.k = idx3.k2;
          data.push_back(data3[data3_size + 1]);
          index.push_back(idx);
          TF_FALLTHROUGH_INTENDED;
        case 1:
          idx.k = idx3.k1;
          data.push_back(data3[data3_size]);
          index.push_back(idx);
          data3.resize(data3_size);
      }
    }
    col_offset += block_size;
    index3_offset.push_back(index3.size());
    index_offset.push_back(index.size());
  }
  DCHECK_EQ(index3_offset.size(), num_blocks);
  DCHECK_EQ(index_offset.size(), num_blocks);
  DCHECK_EQ(3 * index3.size(), data3.size());
  DCHECK_EQ(index.size(), data.size());
}

template <typename T>
void SparseSlice<T>::Clear() {
  index3_offset.clear();
  index3.clear();
  data3.clear();
  index_offset.clear();
  index.clear();
  data.clear();
}

using Packet = Eigen::internal::packet_traits<float>::type;
const int kNumOperands = (sizeof(Packet) / sizeof(float));
#define LOAD(x) Eigen::internal::pload<Packet>(x);
#define EXPAND_BFLOAT_L(x, y) \
  const auto y = Eigen::internal::pexpand_bf16_l<Packet>(x);
#define EXPAND_BFLOAT_U(x, y) \
  const auto y = Eigen::internal::pexpand_bf16_u<Packet>(x);
#define STORE(x, y) Eigen::internal::pstore<float>(x, y);
#define FMA(a, b, c, d) d = Eigen::internal::pmadd<Packet>(a, b, c);

#define ALWAYS_INLINE EIGEN_ALWAYS_INLINE

ALWAYS_INLINE float ConvertBfloat16ToFloat(const bfloat16* src) {
  float out = 0;
  auto tmp = reinterpret_cast<bfloat16*>(&out);
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
  tmp[0] = *src;
#else
  tmp[1] = *src;
#endif
  return out;
}

ALWAYS_INLINE Packet ConvertFourBfloat16ToFloat(const bfloat16* src) {
  return Eigen::internal::pload4bf16<Packet>(
      reinterpret_cast<const float*>(src));
}

ALWAYS_INLINE Packet ConvertTwoBfloat16ToFloat(const bfloat16* src) {
  return Eigen::internal::pload2bf16<Packet>(
      reinterpret_cast<const float*>(src));
}
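
// Added note (not in the original source): bfloat16 is the upper 16 bits of
// an IEEE float32, which is why ConvertBfloat16ToFloat above simply writes the
// 16-bit value into the high half of a zeroed float. For example, the
// bfloat16 bit pattern 0x3F80 expands to the float32 bit pattern 0x3F800000,
// i.e. 1.0f; a float32 becomes bfloat16 by dropping its low 16 mantissa bits.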

ALWAYS_INLINE void ScalarMulAdd(const float a, const float** inp, float** out) {
  **out += a * **inp;
  ++*inp;
  ++*out;
}

ALWAYS_INLINE void ScalarMulAdd(const float a, const bfloat16** inp,
                                float** out) {
  float inp_f = ConvertBfloat16ToFloat(*inp);
  **out += a * inp_f;
  ++*inp;
  ++*out;
}

ALWAYS_INLINE void ScalarMulAdd3Way(const float a1, const float a2,
                                    const float a3, const bfloat16** inp1,
                                    const bfloat16** inp2,
                                    const bfloat16** inp3, float** out) {
  float inp1_f = ConvertBfloat16ToFloat(*inp1);
  float inp2_f = ConvertBfloat16ToFloat(*inp2);
  float inp3_f = ConvertBfloat16ToFloat(*inp3);
  **out += a1 * inp1_f + a2 * inp2_f + a3 * inp3_f;
  ++*out;
  ++*inp1;
  ++*inp2;
  ++*inp3;
}

ALWAYS_INLINE void ScalarMulAdd3Way(const float a1, const float a2,
                                    const float a3, const float** inp1,
                                    const float** inp2, const float** inp3,
                                    float** out) {
  **out += a1 * **inp1 + a2 * **inp2 + a3 * **inp3;
  ++*out;
  ++*inp1;
  ++*inp2;
  ++*inp3;
}

ALWAYS_INLINE void LoadSingleScalar(const bfloat16** data, Packet* l) {
  auto tmp = ConvertBfloat16ToFloat(*data);
  *l = Eigen::internal::pset1<Packet>(tmp);
  ++*data;
}

ALWAYS_INLINE void LoadTwoScalars(const bfloat16** data, Packet* l1,
                                  Packet* l2) {
  if (kNumOperands >= 2) {
    auto tmp = ConvertTwoBfloat16ToFloat(*data);
    *l1 = Eigen::internal::pbroadcast_first<Packet>(tmp);
    *l2 = Eigen::internal::pbroadcast_second<Packet>(tmp);
    *data += 2;
  } else {
    LoadSingleScalar(data, l1);
    LoadSingleScalar(data, l2);
  }
}

ALWAYS_INLINE void LoadFourScalars(const bfloat16** data, Packet* l1,
                                   Packet* l2, Packet* l3, Packet* l4) {
  if (kNumOperands >= 4) {
    auto tmp = ConvertFourBfloat16ToFloat(*data);
    *l1 = Eigen::internal::pbroadcast_first<Packet>(tmp);
    *l2 = Eigen::internal::pbroadcast_second<Packet>(tmp);
    *l3 = Eigen::internal::pbroadcast_third<Packet>(tmp);
    *l4 = Eigen::internal::pbroadcast_fourth<Packet>(tmp);
    *data += 4;
  } else {
    LoadTwoScalars(data, l1, l2);
    LoadTwoScalars(data, l3, l4);
  }
}

ALWAYS_INLINE void LoadSingleScalar(const float** data, Packet* l) {
  *l = Eigen::internal::pload1<Packet>(*data);
  ++(*data);
}

ALWAYS_INLINE void LoadTwoScalars(const float** data, Packet* l1, Packet* l2) {
  LoadSingleScalar(data, l1);
  LoadSingleScalar(data, l2);
}

ALWAYS_INLINE void LoadFourScalars(const float** data, Packet* l1, Packet* l2,
                                   Packet* l3, Packet* l4) {
  LoadTwoScalars(data, l1, l2);
  LoadTwoScalars(data, l3, l4);
}

template <typename T>
ALWAYS_INLINE void LoadThreeScalars(const T** data, Packet* l1, Packet* l2,
                                    Packet* l3) {
  LoadTwoScalars(data, l1, l2);
  LoadSingleScalar(data, l3);
}

template <typename T>
ALWAYS_INLINE void LoadSixScalars(const T** data, Packet* l1, Packet* l2,
                                  Packet* l3, Packet* l4, Packet* l5,
                                  Packet* l6) {
  LoadFourScalars(data, l1, l2, l3, l4);
  LoadTwoScalars(data, l5, l6);
}

// Vectorized version of ScalarMulAdd.
ALWAYS_INLINE void MulAdd(const Packet a, const bfloat16** binp, float** out) {
  auto inp = reinterpret_cast<const float*>(*binp);
  const auto b = LOAD(inp);
  EXPAND_BFLOAT_L(b, b_0);
  EXPAND_BFLOAT_U(b, b_1);
  *binp += 2 * kNumOperands;
  auto c1 = LOAD(*out);
  auto c2 = LOAD(*out + kNumOperands);
  FMA(a, b_0, c1, c1);
  FMA(a, b_1, c2, c2);
  STORE(*out, c1);
  STORE(*out + kNumOperands, c2);
  *out += 2 * kNumOperands;
}

// Vectorized version of ScalarMulAdd3Way.
ALWAYS_INLINE void MulAdd3Way(const Packet a1, const Packet a2, const Packet a3,
                              const bfloat16** binp1, const bfloat16** binp2,
                              const bfloat16** binp3, float** out) {
  auto inp1 = reinterpret_cast<const float*>(*binp1);
  auto inp2 = reinterpret_cast<const float*>(*binp2);
  auto inp3 = reinterpret_cast<const float*>(*binp3);
  auto c1 = LOAD(*out);
  auto c2 = LOAD(*out + kNumOperands);
  const auto b1 = LOAD(inp1);
  EXPAND_BFLOAT_L(b1, b1_0);
  EXPAND_BFLOAT_U(b1, b1_1);
  *binp1 += 2 * kNumOperands;
  const auto b2 = LOAD(inp2);
  EXPAND_BFLOAT_L(b2, b2_0);
  EXPAND_BFLOAT_U(b2, b2_1);
  *binp2 += 2 * kNumOperands;
  const auto b3 = LOAD(inp3);
  EXPAND_BFLOAT_L(b3, b3_0);
  EXPAND_BFLOAT_U(b3, b3_1);
  *binp3 += 2 * kNumOperands;
  FMA(a1, b1_0, c1, c1);
  FMA(a1, b1_1, c2, c2);
  FMA(a2, b2_0, c1, c1);
  FMA(a2, b2_1, c2, c2);
  FMA(a3, b3_0, c1, c1);
  FMA(a3, b3_1, c2, c2);
  STORE(*out, c1);
  STORE(*out + kNumOperands, c2);
  *out += 2 * kNumOperands;
}

// Unroll MulAdd3Way for two iterations
ALWAYS_INLINE void TwoMulAdd3Way(const Packet a1, const Packet a2,
                                 const Packet a3, const bfloat16** binp1,
                                 const bfloat16** binp2, const bfloat16** binp3,
                                 float** out) {
  auto inp1 = reinterpret_cast<const float*>(*binp1);
  auto inp2 = reinterpret_cast<const float*>(*binp2);
  auto inp3 = reinterpret_cast<const float*>(*binp3);
  auto c1 = LOAD(*out);
  auto c2 = LOAD(*out + kNumOperands);
  const auto b1 = LOAD(inp1);
  const auto b2 = LOAD(inp2);
  const auto b3 = LOAD(inp3);

  EXPAND_BFLOAT_L(b1, b1_0);
  EXPAND_BFLOAT_U(b1, b1_1);
  EXPAND_BFLOAT_L(b2, b2_0);
  EXPAND_BFLOAT_U(b2, b2_1);
  EXPAND_BFLOAT_L(b3, b3_0);
  EXPAND_BFLOAT_U(b3, b3_1);
  auto c3 = LOAD(*out + 2 * kNumOperands);
  auto c4 = LOAD(*out + 3 * kNumOperands);
  const auto b4 = LOAD(inp1 + kNumOperands);
  const auto b5 = LOAD(inp2 + kNumOperands);
  const auto b6 = LOAD(inp3 + kNumOperands);

  EXPAND_BFLOAT_L(b4, b4_0);
  EXPAND_BFLOAT_U(b4, b4_1);
  EXPAND_BFLOAT_L(b5, b5_0);
  EXPAND_BFLOAT_U(b5, b5_1);
  EXPAND_BFLOAT_L(b6, b6_0);
  EXPAND_BFLOAT_U(b6, b6_1);

  FMA(a1, b1_0, c1, c1);
  FMA(a1, b1_1, c2, c2);
  FMA(a1, b4_0, c3, c3);
  FMA(a1, b4_1, c4, c4);
  FMA(a2, b2_0, c1, c1);
  FMA(a2, b2_1, c2, c2);
  FMA(a2, b5_0, c3, c3);
  FMA(a2, b5_1, c4, c4);
  FMA(a3, b3_0, c1, c1);
  FMA(a3, b3_1, c2, c2);
  FMA(a3, b6_0, c3, c3);
  FMA(a3, b6_1, c4, c4);
  STORE(*out, c1);
  STORE(*out + kNumOperands, c2);
  STORE(*out + 2 * kNumOperands, c3);
  STORE(*out + 3 * kNumOperands, c4);
  *out += 4 * kNumOperands;
  *binp1 += 4 * kNumOperands;
  *binp2 += 4 * kNumOperands;
  *binp3 += 4 * kNumOperands;
}

// Apply MulAdd3Way on 128 operands.
ALWAYS_INLINE void MulAdd3Way128(const Packet a1, const Packet a2,
                                 const Packet a3, const bfloat16** inp1,
                                 const bfloat16** inp2, const bfloat16** inp3,
                                 float** out) {
  for (int k = 0; k < 128 / (8 * kNumOperands); ++k) {
    TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
    TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
  }
}

// Vectorized version of ScalarMulAdd
ALWAYS_INLINE void MulAdd(const Packet a, const float** inp, float** out) {
  const auto b = LOAD(*inp);
  *inp += kNumOperands;
  auto c = LOAD(*out);
  FMA(a, b, c, c);
  STORE(*out, c);
  *out += kNumOperands;
}

// Vectorized version of ScalarMulAdd3Way
ALWAYS_INLINE void MulAdd3Way(const Packet a1, const Packet a2, const Packet a3,
                              const float** inp1, const float** inp2,
                              const float** inp3, float** out) {
  auto c = LOAD(*out);
  const auto b1 = LOAD(*inp1);
  *inp1 += kNumOperands;
  const auto b2 = LOAD(*inp2);
  *inp2 += kNumOperands;
  const auto b3 = LOAD(*inp3);
  *inp3 += kNumOperands;
  FMA(a1, b1, c, c);
  FMA(a2, b2, c, c);
  FMA(a3, b3, c, c);
  STORE(*out, c);
  *out += kNumOperands;
}

// Unroll MulAdd3Way for two iterations
ALWAYS_INLINE void TwoMulAdd3Way(const Packet a1, const Packet a2,
                                 const Packet a3, const float** inp1,
                                 const float** inp2, const float** inp3,
                                 float** out) {
  auto c1 = LOAD(*out);
  const auto b1 = LOAD(*inp1);
  const auto b2 = LOAD(*inp2);
  const auto b3 = LOAD(*inp3);

  auto c2 = LOAD(*out + kNumOperands);
  const auto b4 = LOAD(*inp1 + kNumOperands);
  const auto b5 = LOAD(*inp2 + kNumOperands);
  const auto b6 = LOAD(*inp3 + kNumOperands);

  FMA(a1, b1, c1, c1);
  FMA(a1, b4, c2, c2);
  FMA(a2, b2, c1, c1);
  FMA(a2, b5, c2, c2);
  FMA(a3, b3, c1, c1);
  FMA(a3, b6, c2, c2);
  STORE(*out, c1);
  STORE(*out + kNumOperands, c2);
  *out += 2 * kNumOperands;
  *inp1 += 2 * kNumOperands;
  *inp2 += 2 * kNumOperands;
  *inp3 += 2 * kNumOperands;
}

// Unroll MulAdd3Way for four iterations
ALWAYS_INLINE void FourMulAdd3Way(const Packet a1, const Packet a2,
                                  const Packet a3, const float** inp1,
                                  const float** inp2, const float** inp3,
                                  float** out) {
  TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
  TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
}

// Apply MulAdd3Way on 128 operands.
ALWAYS_INLINE void MulAdd3Way128(const Packet a1, const Packet a2,
                                 const Packet a3, const float** inp1,
                                 const float** inp2, const float** inp3,
                                 float** out) {
  if (kNumOperands == 8) {
    FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
    FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
    FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
    FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
  } else {
    DCHECK_LE(4 * kNumOperands, 128);
    for (int i = 0; i < 128 / (4 * kNumOperands); ++i) {
      MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
      MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
      MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
      MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
    }
  }
}
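
// Added note (not in the original source): the "128 operands" variants rely
// on simple packet arithmetic. With AVX, kNumOperands == 8 floats per Packet,
// so the float path covers 128 columns with 4 calls to FourMulAdd3Way, each of
// which advances the pointers by 4 * kNumOperands = 32 floats. The bfloat16
// path loops 128 / (8 * kNumOperands) = 2 times over pairs of TwoMulAdd3Way,
// since each bfloat16 load expands into two float packets.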

// Computes product of "left_slices" with "num_cols" columns of "right", and
// stores the output in *"output".
// Note that left_slices is a list of SparseSlices, which are conceptually
// assumed to be concatenated along the column dimension. Also each SparseSlice
// is encoded as a list of blocks with up to N columns. See SparseSlice for
// more details.
template <typename TL, typename TR, int Cols>
inline void GEPP(
    const std::vector<SparseSlice<TL>*>& left_slices,
    const Eigen::TensorMap<Eigen::Tensor<const TR, 2, Eigen::RowMajor>,
                           Eigen::Aligned>& right,
    const int num_cols, Matrix* output) {
  const int cols = (Cols == -1) ? num_cols : Cols;
  DCHECK_EQ(num_cols, cols);
  const int right_num_cols = right.dimension(1);
  const int output_num_cols = output->dimension(1);
  static const int kNumOperandsR = kNumOperands * sizeof(float) / sizeof(TR);
  const int cols_mod = cols % kNumOperandsR;
  int k_offset = 0;
  // Pre-compute pointers for output matrix.
  float* out_ptrs[M];
  float* const out_start = &(*output)(0, 0);
  for (int j = 0; j < M; ++j) {
    out_ptrs[j] = out_start + output_num_cols * j;
  }
  for (const auto* left_slice : left_slices) {
    const auto& left = *left_slice;
    const auto* data3 = (!left.data3.empty()) ? &left.data3[0] : nullptr;
    const auto* data = (!left.data.empty()) ? &left.data[0] : nullptr;
    const int num_blocks = left.index3_offset.size();
    int begin3 = 0;
    int begin = 0;
    for (int i = 0; i < num_blocks; ++i) {
      // Pre-compute pointers for right matrix
      const TR* right_ptrs[K];
      const auto* const right_start = &right(k_offset, 0);
      DCHECK_LT(k_offset, right.dimension(0));
      for (int j = 0; j < K; ++j) {
        right_ptrs[j] = right_start + right_num_cols * j;
      }

      const int end3 = left.index3_offset[i];
      int j = begin3;
      // Loop unrolled for 2 iterations.
      for (; j + 1 < end3; j += 2) {
        Packet l1, l2, l3, nl1, nl2, nl3;
        LoadSixScalars(&data3, &l1, &l2, &l3, &nl1, &nl2, &nl3);
        const auto& index = left.index3[j];
        const auto& nindex = left.index3[j + 1];
        float* out = out_ptrs[index.m];
        float* nout = out_ptrs[nindex.m];
        const auto* r1 = right_ptrs[index.k1];
        const auto* r2 = right_ptrs[index.k2];
        const auto* r3 = right_ptrs[index.k3];

        const auto* nr1 = right_ptrs[nindex.k1];
        const auto* nr2 = right_ptrs[nindex.k2];
        const auto* nr3 = right_ptrs[nindex.k3];
        if (cols == 128) {
          MulAdd3Way128(l1, l2, l3, &r1, &r2, &r3, &out);
          MulAdd3Way128(nl1, nl2, nl3, &nr1, &nr2, &nr3, &nout);
        } else {
          for (int n = 0; n < cols / kNumOperandsR; ++n) {
            MulAdd3Way(l1, l2, l3, &r1, &r2, &r3, &out);
            MulAdd3Way(nl1, nl2, nl3, &nr1, &nr2, &nr3, &nout);
          }

          const float sl1 = Eigen::internal::pfirst<Packet>(l1);
          const float sl2 = Eigen::internal::pfirst<Packet>(l2);
          const float sl3 = Eigen::internal::pfirst<Packet>(l3);
          const float nsl1 = Eigen::internal::pfirst<Packet>(nl1);
          const float nsl2 = Eigen::internal::pfirst<Packet>(nl2);
          const float nsl3 = Eigen::internal::pfirst<Packet>(nl3);
          for (int k = 0; k < cols_mod; ++k) {
            ScalarMulAdd3Way(sl1, sl2, sl3, &r1, &r2, &r3, &out);
            ScalarMulAdd3Way(nsl1, nsl2, nsl3, &nr1, &nr2, &nr3, &nout);
          }
        }
      }
      if (j < end3) {
        Packet l1, l2, l3;
        LoadThreeScalars(&data3, &l1, &l2, &l3);

        const auto& index = left.index3[j];
        float* out = out_ptrs[index.m];
        const auto* r1 = right_ptrs[index.k1];
        const auto* r2 = right_ptrs[index.k2];
        const auto* r3 = right_ptrs[index.k3];
        if (cols == 128) {
          MulAdd3Way128(l1, l2, l3, &r1, &r2, &r3, &out);
        } else {
          for (int n = 0; n < cols / kNumOperandsR; ++n) {
            MulAdd3Way(l1, l2, l3, &r1, &r2, &r3, &out);
          }
          const float sl1 = Eigen::internal::pfirst<Packet>(l1);
          const float sl2 = Eigen::internal::pfirst<Packet>(l2);
          const float sl3 = Eigen::internal::pfirst<Packet>(l3);
          for (int k = 0; k < cols_mod; ++k) {
            ScalarMulAdd3Way(sl1, sl2, sl3, &r1, &r2, &r3, &out);
          }
        }
      }
      begin3 = end3;
      int end = left.index_offset[i];
      // Loop unrolled for 4 iterations.
      j = begin;
      for (; j + 3 < end; j += 4) {
        Packet l, nl, n2l, n3l;
        LoadFourScalars(&data, &l, &nl, &n2l, &n3l);

        const auto& index = left.index[j];
        const auto& nindex = left.index[j + 1];
        const auto& n2index = left.index[j + 2];
        const auto& n3index = left.index[j + 3];
        const auto* r = right_ptrs[index.k];
        const auto* nr = right_ptrs[nindex.k];
        const auto* n2r = right_ptrs[n2index.k];
        const auto* n3r = right_ptrs[n3index.k];
        float* out = out_ptrs[index.m];
        float* nout = out_ptrs[nindex.m];
        float* n2out = out_ptrs[n2index.m];
        float* n3out = out_ptrs[n3index.m];

        for (int n = 0; n < cols / kNumOperandsR; ++n) {
          MulAdd(l, &r, &out);
          MulAdd(nl, &nr, &nout);
          MulAdd(n2l, &n2r, &n2out);
          MulAdd(n3l, &n3r, &n3out);
        }

        const float sl1 = Eigen::internal::pfirst<Packet>(l);
        const float sl2 = Eigen::internal::pfirst<Packet>(nl);
        const float sl3 = Eigen::internal::pfirst<Packet>(n2l);
        const float sl4 = Eigen::internal::pfirst<Packet>(n3l);
        for (int k = 0; k < cols_mod; ++k) {
          ScalarMulAdd(sl1, &r, &out);
          ScalarMulAdd(sl2, &nr, &nout);
          ScalarMulAdd(sl3, &n2r, &n2out);
          ScalarMulAdd(sl4, &n3r, &n3out);
        }
      }
      while (j < end) {
        Packet l;
        LoadSingleScalar(&data, &l);
        const auto& index = left.index[j];
        const auto* r = right_ptrs[index.k];
        float* out = out_ptrs[index.m];
        for (int n = 0; n < cols / kNumOperandsR; ++n) {
          MulAdd(l, &r, &out);
        }
        const float sl = Eigen::internal::pfirst<Packet>(l);
        for (int k = 0; k < cols_mod; ++k) {
          ScalarMulAdd(sl, &r, &out);
        }
        j++;
      }
      k_offset += left.block_size;
      begin = end;
    }
  }
}

#undef LOAD
#undef EXPAND_BFLOAT_L
#undef EXPAND_BFLOAT_U
#undef STORE
#undef FMA

}  // namespace

template <typename TL, typename TR>
class SparseMatMul {
  using MatrixL = BasicMatrix<TL>;
  using MatrixR = BasicMatrix<TR>;
  using ConstMatrixMapL = BasicMatrixMap<const TL>;
  using ConstMatrixMapR = BasicMatrixMap<const TR>;
  using MatrixMapR = BasicMatrixMap<TR>;

 public:
  // Not used; added to match interface of LibxsmmSparseMatMul
  struct TensorInfoCache {};

  // Perform matrix multiplication of "left" and "right", and store the result
  // in *"output".
 public:
  static inline void Compute(TensorInfoCache* cache,
                             const ConstMatrixMapL& left,
                             const ConstMatrixMapR& right, bool transpose_left,
                             const DeviceBase::CpuWorkerThreads* thread_pool,
                             bool transpose_output, MatrixMap* output);

 private:
  // Computes multiplication of left and num_cols columns of right, and stores
  // the output block in *"output" at offsets "output_row_offset" and
  // "output_col_offset". If assign is true, assigns the value to that block,
  // else adds the values to the existing values.
  static inline void ComputeOutputBlock(
      const std::vector<SparseSlice<TL>*>& left, const ConstMatrixMapR& right,
      int num_cols, int output_row_offset, int output_col_offset, bool assign,
      bool transpose_output, MatrixMap* output);

  // Encodes "mat" using a sparse representation and stores that in
  // "mat_slices". "mat" is broken into a grid with sizes "slice_num_rows" and
  // "slice_num_cols", each grid element is converted into a SparseSlice and
  // stored in mat_slices. "slice_block_size" is used to perform further column
  // blocking of each slice.
  static inline std::unique_ptr<BlockingCounter> CreateSparseSlices(
      const ConstMatrixMapL& mat, bool transpose, int slice_num_rows,
      int slice_block_size, int slice_num_cols,
      std::vector<std::vector<SparseSlice<TL>*>>* mat_slices,
      const DeviceBase::CpuWorkerThreads* thread_pool);

  // This function chops "mat" along column dimension into pieces with at most
  // N columns, and concatenates the pieces one after the other in "buffer". It
  // returns the list of the pieces in "slices". It returns a BlockingCounter
  // which should be used to wait for the shuffle operations to complete.
  static inline std::unique_ptr<BlockingCounter> CreateDenseSlices(
      const ConstMatrixMapR& mat, int row_start, int num_rows, int col_start,
      int num_cols, const DeviceBase::CpuWorkerThreads* thread_pool,
      MatrixR* buffer, std::vector<ConstMatrixMapR*>* slices);

  // Helper function for CreateDenseSlices to move the data around. It returns
  // a BlockingCounter which should be used to wait for the shuffle operations
  // to complete.
  static inline BlockingCounter* ShuffleMatrix(
      const ConstMatrixMapR& mat, int slice_row_start, int slice_num_rows,
      int slice_col_start, int slice_num_cols, const int N,
      const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer);

  // Helper function for CreateDenseSlices to create slices.
  static inline void SliceMatrix(const MatrixR& mat, const int num_rows,
                                 const int num_slices,
                                 std::vector<ConstMatrixMapR*>* slices);

  // Heuristics to compute various block sizes.
  // KR, NR: block sizes for "right". We run blocking iterations that operate
  // on matrices with at most this size.
  // KL: grid size along the column dimension used while encoding left.
  // IB, JB: number of left and right slices to multiply together. This is used
  // for ordering different ComputeOutputBlock operations inside each blocking
  // iteration so as to potentially reduce the working set size.
  static inline void ComputeBlockSizes(const ConstMatrixMapL& left,
                                       const ConstMatrixMapR& right,
                                       bool transpose_left, int num_threads,
                                       int* KR, int* NR, int* KL, int* JB,
                                       int* IB);

  TF_DISALLOW_COPY_AND_ASSIGN(SparseMatMul);
};

#ifdef TENSORFLOW_USE_LIBXSMM
template <typename TL, typename TR>
class LibxsmmSparseMatMul {
  using MatrixL = BasicMatrix<TL>;
  using MatrixR = BasicMatrix<TR>;
  using ConstMatrixMapL = BasicMatrixMap<const TL>;
  using ConstMatrixMapR = BasicMatrixMap<const TR>;
  using MatrixMapR = BasicMatrixMap<TR>;

 public:
  // This structure contains a set of libxsmm kernels for sizes that have been
  // encountered previously by this operator so that libxsmm does not need to
  // reallocate its scratchpad memory each time (which hurts performance
  // substantially).
  struct TensorInfoCache {
    struct TensorInfoCacheEntry {
      // Parameters for kernel
      int M;
      int K;
      int N;
      int max_threads;
      // libxsmm handle and matrix data
      libxsmm_spmdm_handle handle;
      libxsmm_CSR_sparseslice* output_csr;
      // Chain to non-libxsmm implementation's cache in case that ever becomes
      // useful (it is an empty struct right now)
      typename SparseMatMul<TL, TR>::TensorInfoCache
          non_libxsmm_cache;  // Currently not used
    };
    // protects entries; invariant: entries is a valid std::multimap
    tensorflow::mutex lock;
    // Because there could be multiple matrix multiplies with the same sizes
    // going on at the same time, we need to allow multiple cache entries for a
    // given set of parameters. Taking and returning entries is used to make
    // sure the same cache entry is not used from two threads at a time.
    std::multimap<std::tuple<int, int, int, int>,
                  std::unique_ptr<TensorInfoCacheEntry>>
        entries GUARDED_BY(lock);

    TensorInfoCache() : lock(), entries() {}
    // Look up and remove first entry with these parameters, creating one if
    // there isn't one
    std::unique_ptr<TensorInfoCacheEntry> take_cache_entry(int M, int K, int N,
                                                           int max_threads)
        LOCKS_EXCLUDED(lock) {
      tensorflow::mutex_lock ml(lock);
      auto key = std::make_tuple(M, K, N, max_threads);
      auto it = entries.find(key);
      if (it != entries.end()) {
        auto val = std::move(it->second);
        entries.erase(it);
        return val;
      } else {
        std::unique_ptr<TensorInfoCacheEntry> e{
            new TensorInfoCacheEntry{M, K, N, max_threads, {}, nullptr}};
        // setup scoped allocator, which uses cpu_allocator() for this scope
        const libxsmm_tf_allocator<libxsmm_scratch_allocator> tf_allocator;
        libxsmm_spmdm_init(M, N, K, max_threads, &e->handle, &e->output_csr);
        return e;
      }
    }
    // Add a cache entry with certain parameters
    void return_cache_entry(std::unique_ptr<TensorInfoCacheEntry> e)
        LOCKS_EXCLUDED(lock) {
      tensorflow::mutex_lock ml(lock);
      auto key = std::make_tuple(e->M, e->K, e->N, e->max_threads);
      entries.insert(std::make_pair(key, std::move(e)));
    }
    ~TensorInfoCache() {
      tensorflow::mutex_lock ml(lock);
      for (auto& p : entries) {
        libxsmm_spmdm_destroy(&p.second->handle);
      }
      entries.clear();
    }

   private:
    TF_DISALLOW_COPY_AND_ASSIGN(TensorInfoCache);
  };
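
  // Usage note (added for clarity, not in the original source): Compute()
  // below checks an entry out with take_cache_entry(M, K, N, num_threads),
  // uses the libxsmm handle it holds for one multiply, and then hands it back
  // with return_cache_entry(). Because take removes the entry from the
  // multimap, two concurrent multiplies with identical shapes simply end up
  // creating, and later caching, two separate handles.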
  // Perform matrix multiplication of "left" and "right", and store the result
  // in *"output".
 public:
  static inline void Compute(TensorInfoCache* cache,
                             const ConstMatrixMapL& left,
                             const ConstMatrixMapR& right, bool transpose_left,
                             const DeviceBase::CpuWorkerThreads* thread_pool,
                             bool transpose_output, MatrixMap* output);

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(LibxsmmSparseMatMul);
};
#endif

template <typename TL, typename TR,
          template <typename TL2, typename TR2> class DoMatMul>
class SparseMatMulOp : public OpKernel {
  using MatrixR = BasicMatrix<TR>;
  using ConstMatrixMapR = BasicMatrixMap<const TR>;

 public:
  explicit SparseMatMulOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("a_is_sparse", &a_is_sparse_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("b_is_sparse", &b_is_sparse_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& a = ctx->input(0);
    const Tensor& b = ctx->input(1);
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(a.shape()),
                errors::InvalidArgument("a is not a matrix"));
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(b.shape()),
                errors::InvalidArgument("b is not a matrix"));

    const int m = transpose_a_ ? a.dim_size(1) : a.dim_size(0);
    const int k = transpose_a_ ? a.dim_size(0) : a.dim_size(1);
    const int n = transpose_b_ ? b.dim_size(0) : b.dim_size(1);
    const int k2 = transpose_b_ ? b.dim_size(1) : b.dim_size(0);

    OP_REQUIRES(ctx, k == k2,
                errors::InvalidArgument(
                    "Matrix size incompatible: a: ", a.shape().DebugString(),
                    ", b: ", b.shape().DebugString()));
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({m, n}), &output));

    if (k == 0) {
      // If the inner dimension k in the matrix multiplication is zero, we fill
      // the output with zeros.
      functor::SetZeroFunctor<CPUDevice, float> f;
      f(ctx->eigen_device<CPUDevice>(), output->flat<float>());
      return;
    }

    auto out = output->matrix<float>();

    std::unique_ptr<Tensor> a_float;
    std::unique_ptr<Tensor> b_float;
    if (!a_is_sparse_ && !b_is_sparse_) {
      auto left = &a;
      auto right = &b;
      // TODO(agarwal): multi-thread the conversions from bfloat16 to float.
      if (std::is_same<TL, bfloat16>::value) {
        a_float.reset(new Tensor(DT_FLOAT, a.shape()));
        BFloat16ToFloat(a.flat<bfloat16>().data(),
                        a_float->flat<float>().data(), a.NumElements());
        left = a_float.get();
      }
      if (std::is_same<TR, bfloat16>::value) {
        b_float.reset(new Tensor(DT_FLOAT, b.shape()));
        BFloat16ToFloat(b.flat<bfloat16>().data(),
                        b_float->flat<float>().data(), b.NumElements());
        right = b_float.get();
      }
      Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
      dim_pair[0].first = transpose_a_ ? 0 : 1;
      dim_pair[0].second = transpose_b_ ? 1 : 0;

      out.device(ctx->template eigen_device<CPUDevice>()) =
          left->matrix<float>().contract(right->matrix<float>(), dim_pair);
      return;
    }

    auto left = &a;
    auto right = &b;
    bool transpose_output = false;
    bool transpose_a = transpose_a_;
    bool transpose_b = transpose_b_;
    if (!a_is_sparse_) {
      // Swap the order of multiplications using the identity:
      // A * B = (B' * A')'.
      std::swap(left, right);
      std::swap(transpose_a, transpose_b);
      transpose_a = !transpose_a;
      transpose_b = !transpose_b;
      transpose_output = !transpose_output;
    }
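    // Clarifying note (added, not in the original source): the swap above
    // keeps the operand that was marked sparse on the left, because the sparse
    // kernel below (SparseMatMul via GEPP) only exploits sparsity in its left
    // argument. When only b is sparse we compute b' * a' instead and transpose
    // the result when writing the output.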

    std::unique_ptr<Tensor> right_tr;
    if (transpose_b) {
      // TODO(agarwal): avoid transposing the matrix here and directly handle
      // transpose in CreateDenseSlices.
      right_tr.reset(
          new Tensor(right->dtype(),
                     TensorShape({right->dim_size(1), right->dim_size(0)})));

      const auto perm = dsizes_10();
      if (transpose_output) {
        right_tr->matrix<TL>().device(
            ctx->template eigen_device<CPUDevice>()) =
            right->matrix<TL>().shuffle(perm);
      } else {
        right_tr->matrix<TR>().device(
            ctx->template eigen_device<CPUDevice>()) =
            right->matrix<TR>().shuffle(perm);
      }
      right = right_tr.get();
    }

    if (transpose_output) {
      DoMatMul<TR, TL>::Compute(&this->cache_tr_, left->matrix<TR>(),
                                right->matrix<TL>(), transpose_a,
                                ctx->device()->tensorflow_cpu_worker_threads(),
                                transpose_output, &out);
    } else {
      DoMatMul<TL, TR>::Compute(&this->cache_nt_, left->matrix<TL>(),
                                right->matrix<TR>(), transpose_a,
                                ctx->device()->tensorflow_cpu_worker_threads(),
                                transpose_output, &out);
    }
  }

 private:
  bool transpose_a_;
  bool transpose_b_;
  bool a_is_sparse_;
  bool b_is_sparse_;

  // Cache for non-transposed-output multiply
  typename DoMatMul<TL, TR>::TensorInfoCache cache_nt_;
  // Cache for transposed-output multiply
  typename DoMatMul<TR, TL>::TensorInfoCache cache_tr_;

  TF_DISALLOW_COPY_AND_ASSIGN(SparseMatMulOp);
};

template <typename TL, typename TR>
inline void SparseMatMul<TL, TR>::ComputeOutputBlock(
    const std::vector<SparseSlice<TL>*>& left,
    const typename SparseMatMul<TL, TR>::ConstMatrixMapR& right, int num_cols,
    int output_row_offset, int output_col_offset, bool assign,
    bool transpose_output, MatrixMap* output) {
  const auto perm = dsizes_10();
  int num_rows = left[0]->num_rows;
  const int rhs_num_cols = right.dimension(1);
  DCHECK_LE(num_cols, rhs_num_cols);
  Matrix out(num_rows, rhs_num_cols);
  out.setZero();
  if (num_cols == N) {
    GEPP<TL, TR, N>(left, right, num_cols, &out);
  } else {
    GEPP<TL, TR, -1>(left, right, num_cols, &out);
  }
  if (!assign) {
    const DSizes begin(output_row_offset, output_col_offset);
    const DSizes sizes(num_rows, num_cols);
    if (transpose_output) {
      if (num_cols == rhs_num_cols) {
        output->shuffle(perm).slice(begin, sizes) += out;
      } else {
        const auto zero = dsizes_00();
        output->shuffle(perm).slice(begin, sizes) += out.slice(zero, sizes);
      }
    } else {
      if (num_cols == rhs_num_cols) {
        output->slice(begin, sizes) += out;
      } else {
        const auto zero = dsizes_00();
        output->slice(begin, sizes) += out.slice(zero, sizes);
      }
    }
  } else {
    std::unique_ptr<Matrix> out_tr;
    if (transpose_output) {
      out_tr.reset(new Matrix(rhs_num_cols, num_rows));
      *out_tr = out.shuffle(perm);
      std::swap(output_row_offset, output_col_offset);
      std::swap(num_rows, num_cols);
    }
    const Matrix& final_out = transpose_output ? *out_tr : out;
    for (int i = 0; i < num_rows; ++i) {
      memcpy(&(*output)(output_row_offset + i, output_col_offset),
             &final_out(i, 0), num_cols * sizeof(float));
    }
  }
}

template <typename TL, typename TR>
inline std::unique_ptr<BlockingCounter>
SparseMatMul<TL, TR>::CreateSparseSlices(
    const typename SparseMatMul<TL, TR>::ConstMatrixMapL& mat, bool transpose,
    int slice_num_rows, int slice_block_size, int slice_num_cols,
    std::vector<std::vector<SparseSlice<TL>*>>* mat_slices,
    const DeviceBase::CpuWorkerThreads* thread_pool) {
  const int mat_num_rows = transpose ? mat.dimension(1) : mat.dimension(0);
  const int mat_num_cols = transpose ? mat.dimension(0) : mat.dimension(1);
  const int num_slices_dim0 =
      std::max(1, (mat_num_rows + slice_num_rows - 1) / slice_num_rows);
  const int num_slices_dim1 =
      std::max(1, (mat_num_cols + slice_num_cols - 1) / slice_num_cols);
  mat_slices->resize(num_slices_dim0);
  BlockingCounter* counter =
      new BlockingCounter(num_slices_dim0 * num_slices_dim1);
  auto work = [counter, transpose](SparseSlice<TL>* sparse_slice,
                                   SparseMatMul<TL, TR>::ConstMatrixMapL* slice,
                                   int col_offset) {
    if (transpose) {
      sparse_slice->template Initialize<true>(*slice, col_offset);
    } else {
      sparse_slice->template Initialize<false>(*slice, col_offset);
    }
    delete slice;
    counter->DecrementCount();
  };
  for (int i = 0; i < num_slices_dim0; ++i) {
    (*mat_slices)[i].resize(num_slices_dim1);
    int num_rows =
        std::min<int>(slice_num_rows, mat_num_rows - i * slice_num_rows);
    for (int j = 0; j < num_slices_dim1; ++j) {
      int num_cols =
          std::min<int>(slice_num_cols, mat_num_cols - j * slice_num_cols);
      SparseMatMul<TL, TR>::ConstMatrixMapL* slice = nullptr;
      if (transpose) {
        slice = new SparseMatMul<TL, TR>::ConstMatrixMapL(
            &mat(0, i * slice_num_rows), mat.dimensions());
      } else {
        DSizes d(num_rows, mat_num_cols);
        slice = new SparseMatMul<TL, TR>::ConstMatrixMapL(
            &mat(i * slice_num_rows, 0), d);
      }
      auto* sparse_slice =
          new SparseSlice<TL>(num_rows, num_cols, slice_block_size);
      (*mat_slices)[i][j] = sparse_slice;
      thread_pool->workers->Schedule(
          [=]() { work(sparse_slice, slice, slice_num_cols * j); });
    }
  }
  return std::unique_ptr<BlockingCounter>(counter);
}
#define LOAD(x) Eigen::internal::ploadu<Packet>((x));
#define INTERLEAVE(x) Eigen::internal::pinterleave4x64<Packet>(x);
#define STORE(x, y) Eigen::internal::pstoreu<float>(x, y);

template <int NUM_ELEM = -1>
ALWAYS_INLINE void CopyAndMayBeInterleaveBfloat16(void* bdst, const void* bsrc,
                                                  int num_elements) {
  DCHECK_GE(kNumOperands, 8);
  static const int kStep = kNumOperands * sizeof(float) / sizeof(bfloat16);
  const int num = (NUM_ELEM == -1) ? num_elements : NUM_ELEM;
  DCHECK_EQ(num, num_elements);
  const float* src = reinterpret_cast<const float*>(bsrc);
  float* dst = reinterpret_cast<float*>(bdst);
  for (int index = 0; index + kStep <= num; index += kStep) {
    auto in = LOAD(src);
    auto tmp = INTERLEAVE(in);
    STORE(dst, tmp);
    src += kNumOperands;
    dst += kNumOperands;
  }
  if (num % kStep != 0) {
    memcpy(dst, src, (num % kStep) * sizeof(bfloat16));
  }
}

template <typename T>
ALWAYS_INLINE void CopyAndMayBeInterleave(void* dst, const void* src,
                                          int num_elements) {
  if (std::is_same<T, float>::value || kNumOperands < 8) {
    memcpy(dst, src, num_elements * sizeof(T));
  } else if (std::is_same<T, bfloat16>::value) {
    if (num_elements == N) {
      CopyAndMayBeInterleaveBfloat16<N>(dst, src, num_elements);
    } else {
      CopyAndMayBeInterleaveBfloat16<-1>(dst, src, num_elements);
    }
  } else {
    LOG(FATAL) << "Unsupported type";
  }
}

#undef LOAD
#undef INTERLEAVE
#undef STORE

template <typename TL, typename TR>
inline BlockingCounter* SparseMatMul<TL, TR>::ShuffleMatrix(
    const typename SparseMatMul<TL, TR>::ConstMatrixMapR& mat,
    int slice_row_start, int slice_num_rows, int slice_col_start,
    int slice_num_cols, const int N,
    const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer) {
  DCHECK_EQ(N % 2, 0);
  DCHECK_LE(kNumOperands * sizeof(float) / sizeof(TR), N);
  int num_threads = std::min(thread_pool->num_threads, 16);
  BlockingCounter* counter = new BlockingCounter(num_threads);
  DCHECK_EQ(N, buffer->dimension(1));
  auto shuffle_work = [&mat, slice_row_start, slice_num_rows, slice_col_start,
                       slice_num_cols, N, buffer, counter](int s, int e) {
    const int row_start = s % slice_num_rows + slice_row_start;
    const int col_start = s / slice_num_rows * N + slice_col_start;
    auto* out_start = &(*buffer)(s, 0);
    const auto* input_start = &mat(row_start, col_start);
    const auto* input_end = &mat(slice_row_start + slice_num_rows - 1,
                                 slice_col_start + slice_num_cols - 1);
    const int mat_num_cols = mat.dimension(1);
    const int row_slice_size = slice_num_rows * mat_num_cols;

    const int aligned_end = slice_num_cols / N * slice_num_rows;
    const int e1 = std::min(e, aligned_end);
    while (s < e1) {
      CopyAndMayBeInterleave<TR>(out_start, input_start, N);
      out_start += N;
      input_start += mat_num_cols;
      if (input_start > input_end) {
        input_start = input_start - row_slice_size + N;
      }
      ++s;
    }
    int s1 = std::max(s, aligned_end);
    const int copy_num_cols = slice_num_cols % N;
    while (s1 < e) {
      CopyAndMayBeInterleave<TR>(out_start, input_start, copy_num_cols);
      out_start += N;
      input_start += mat_num_cols;
      ++s1;
    }
    if (counter) counter->DecrementCount();
  };

  int start = 0;
  int end = 0;
  int num_out_rows = (slice_num_cols + N - 1) / N * slice_num_rows;
  DCHECK_LE(num_out_rows, buffer->dimension(0));
  for (int i = std::max(1, num_threads); i > 0; --i) {
    end = start + num_out_rows / i;
    thread_pool->workers->Schedule([=]() { shuffle_work(start, end); });
    num_out_rows -= (end - start);
    start = end;
  }
  return counter;
}
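
// Added illustration (not in the original source): ShuffleMatrix lays out a
// (slice_num_rows x slice_num_cols) block of "mat" as slice_num_cols / N
// vertical strips of N columns, stacked on top of each other in "buffer".
// For example, with N = 128 a 256 x 512 block becomes a 1024 x 128 buffer:
// buffer rows 0-255 hold columns 0-127 of the block, rows 256-511 hold
// columns 128-255, and so on, so each strip that GEPP later consumes is
// contiguous. For bfloat16 inputs the copy additionally interleaves values so
// that the EXPAND_BFLOAT_L/U loads in the kernels see them in the right order.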

template <typename TL, typename TR>
inline void SparseMatMul<TL, TR>::SliceMatrix(
    const MatrixR& mat, const int num_rows, const int num_slices,
    std::vector<typename SparseMatMul<TL, TR>::ConstMatrixMapR*>* slices) {
  slices->resize(num_slices);
  DSizes d(num_rows, mat.dimension(1));
  DCHECK_LE(num_rows * num_slices, mat.dimension(0));
  for (int i = 0; i < num_slices; ++i) {
    (*slices)[i] = new ConstMatrixMapR(&mat(i * num_rows, 0), d);
  }
}

template <typename TL, typename TR>
inline std::unique_ptr<BlockingCounter> SparseMatMul<TL, TR>::CreateDenseSlices(
    const typename SparseMatMul<TL, TR>::ConstMatrixMapR& mat, int row_start,
    int num_rows, int col_start, int num_cols,
    const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer,
    std::vector<typename SparseMatMul<TL, TR>::ConstMatrixMapR*>* slices) {
  std::unique_ptr<BlockingCounter> shuffle_counter(ShuffleMatrix(
      mat, row_start, num_rows, col_start, num_cols, N, thread_pool, buffer));
  const int num_slices = (num_cols + N - 1) / N;
  SliceMatrix(*buffer, num_rows, num_slices, slices);
  return shuffle_counter;
}

template <typename TL, typename TR>
inline void SparseMatMul<TL, TR>::ComputeBlockSizes(
    const typename SparseMatMul<TL, TR>::ConstMatrixMapL& left,
    const typename SparseMatMul<TL, TR>::ConstMatrixMapR& right,
    bool transpose_left, int num_threads, int* KR, int* NR, int* KL, int* JB,
    int* IB) {
  // Heuristics for calculating block sizes
  // Assume two hyperthreads per core.
  const int est_num_cores = std::max(1, (num_threads + 1) / 2);
  // Use block of rhs with at most 128K floats per core.
  const int mem = est_num_cores * 128 * 1024;
  *KR = std::min(static_cast<int>(right.dimension(0)), mem / 256);
  *NR = right.dimension(1);
  if (*KR * *NR > mem) {
    // 4096 may be enough to amortize the cost of writes.
    *KR = std::min<int>(*KR, 4096);
  }
  // Use sizes that are multiples of K and 256.
  *KR = std::max(1, *KR / K) * K;
  *NR = std::max(1, *NR / 256) * 256;
  if (*KR * *NR > mem) {
    *NR = mem / *KR;
  }
  *NR = std::max(1, *NR / 256) * 256;

  const int left_dim0 = transpose_left ? left.dimension(1) : left.dimension(0);
  const int left_dim1 = transpose_left ? left.dimension(0) : left.dimension(1);
  for (*KL = 1024; *KL > K; *KL /= 2) {
    if (*KR % *KL == 0 &&
        std::max<int>(1, left_dim0 / 64) * (left_dim1 / *KL) > est_num_cores) {
      break;
    }
  }
  DCHECK_EQ(*KL % K, 0);
  DCHECK_GE(*KR, *KL);
  if (*KR < right.dimension(0)) {
    CHECK_EQ(*KR % *KL, 0);
  }

  *JB = std::max(1, static_cast<int>(sqrt(num_threads) / 2.0));
  *IB = 8 * *JB;
  DCHECK_EQ(N * sizeof(float) % 64, size_t{0});
}
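
// Worked example for the heuristics above (added, not in the original source):
// with num_threads = 16 the code assumes est_num_cores = 8, so the right-hand
// side working set is capped at mem = 8 * 128 * 1024 = 1,048,576 floats
// (4 MiB). KR starts at min(right rows, mem / 256 = 4096) and NR at the full
// right width; both are then rounded down to multiples of K = 64 and 256
// respectively, and NR shrinks further if KR * NR still exceeds mem. KL is the
// largest value in {1024, 512, 256, 128} that divides KR and still leaves more
// left-side grid cells (roughly (rows / 64) * (cols / KL)) than estimated
// cores, falling back to K = 64 otherwise.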

#ifdef TENSORFLOW_USE_LIBXSMM

template <typename F>
void do_on_all_threads(const DeviceBase::CpuWorkerThreads* thread_pool,
                       const F& f) {
  int num_threads = thread_pool->num_threads;
  if (num_threads == 0) {
    LOG(FATAL) << "Have 0 threads in thread pool";
  } else if (num_threads == 1) {
    f(0);
  } else {
    BlockingCounter counter(num_threads - 1);
    for (int i = 1; i < num_threads; ++i) {
      thread_pool->workers->Schedule([&, i]() {
        f(i);
        counter.DecrementCount();
      });
    }
    f(0);
    counter.Wait();
  }
}

template <typename T>
struct empty_type_wrapper {};

// Copies of interface to libxsmm_spmdm_createSparseSlice_*_notrans_thread to
// allow overloading
void wrapper_libxsmm_spmdm_createSparseSlice_generic_thread(
    empty_type_wrapper<float>, const libxsmm_spmdm_handle* handle, char transA,
    const float* A, libxsmm_CSR_sparseslice* libxsmm_output_csr_a, int block_id,
    int tid, int nthreads) {
  return libxsmm_spmdm_createSparseSlice_fp32_thread(
      handle, transA, A, libxsmm_output_csr_a, block_id, tid, nthreads);
}
void wrapper_libxsmm_spmdm_createSparseSlice_generic_thread(
    empty_type_wrapper<bfloat16>, const libxsmm_spmdm_handle* handle,
    char transA, const bfloat16* A,
    libxsmm_CSR_sparseslice* libxsmm_output_csr_a, int block_id, int tid,
    int nthreads) {
  return libxsmm_spmdm_createSparseSlice_bfloat16_thread(
      handle, transA, reinterpret_cast<const uint16*>(A), libxsmm_output_csr_a,
      block_id, tid, nthreads);
}

void wrapper_libxsmm_spmdm_compute_generic_thread(
    empty_type_wrapper<bfloat16>, const libxsmm_spmdm_handle* handle,
    char transA, char transB, const bfloat16* alpha,
    libxsmm_CSR_sparseslice* A_sparse, const bfloat16* B, char transC,
    const bfloat16* beta, float* C, int block_id, int tid, int nthreads) {
  return libxsmm_spmdm_compute_bfloat16_thread(
      handle, transA, transB, reinterpret_cast<const uint16*>(alpha), A_sparse,
      reinterpret_cast<const uint16*>(B), transC,
      reinterpret_cast<const uint16*>(beta), C, block_id, tid, nthreads);
}
void wrapper_libxsmm_spmdm_compute_generic_thread(
    empty_type_wrapper<float>, const libxsmm_spmdm_handle* handle, char transA,
    char transB, const float* alpha, libxsmm_CSR_sparseslice* A_sparse,
    const float* B, char transC, const float* beta, float* C, int block_id,
    int tid, int nthreads) {
  return libxsmm_spmdm_compute_fp32_thread(handle, transA, transB, alpha,
                                           A_sparse, B, transC, beta, C,
                                           block_id, tid, nthreads);
}

template <typename TL, typename TR>
inline void LibxsmmSparseMatMul<TL, TR>::Compute(
    typename LibxsmmSparseMatMul<TL, TR>::TensorInfoCache* cache,
    const typename LibxsmmSparseMatMul<TL, TR>::ConstMatrixMapL& left,
    const typename LibxsmmSparseMatMul<TL, TR>::ConstMatrixMapR& right,
    bool transpose_left, const DeviceBase::CpuWorkerThreads* thread_pool,
    bool transpose_output, MatrixMap* output) {
  if (false) {
    // Not handled by libxsmm currently
    SparseMatMul<TL, TR>::Compute(
        nullptr /* Assumes no cached data for fallback */, left, right,
        transpose_left, thread_pool, transpose_output, output);
    return;
  }
  const int num_threads = thread_pool->num_threads;
  const int left_dim0 = transpose_left ? left.dimension(1) : left.dimension(0);
  const int left_dim1 = transpose_left ? left.dimension(0) : left.dimension(1);
  const int right_dim0 = right.dimension(0);
  const int right_dim1 = right.dimension(1);
  CHECK_EQ(left_dim1, right_dim0);
  CHECK_EQ(left_dim0,
           (transpose_output ? output->dimension(1) : output->dimension(0)));
  CHECK_EQ(right_dim1,
           (transpose_output ? output->dimension(0) : output->dimension(1)));
  if (left_dim0 < 32 || left_dim1 < 32 || right_dim1 < 32) {
    // Causes problems in libxsmm
    SparseMatMul<TL, TR>::Compute(
        nullptr /* Assumes no cached data for fallback */, left, right,
        transpose_left, thread_pool, transpose_output, output);
    return;
  }
  auto left_data = left.data();
  auto right_data = right.data();
  auto output_data = output->data();
  // Initialize libxsmm for this matrix; make sure another thread doesn't use
  // this handle
  auto entry =
      cache->take_cache_entry(left_dim0, right_dim0, right_dim1, num_threads);
  // Convert the left matrix to compressed sparse row (CSR) format
  ptrdiff_t total_num_creation_blocks =
      libxsmm_spmdm_get_num_createSparseSlice_blocks(&entry->handle);
  std::atomic<int> cur_create_block_number;
  cur_create_block_number.store(0);
  do_on_all_threads(thread_pool, [&](int i) {
    while (true) {
      int work_item = cur_create_block_number.fetch_add(1);
      if (work_item >= total_num_creation_blocks) break;
      wrapper_libxsmm_spmdm_createSparseSlice_generic_thread(
          empty_type_wrapper<TL>{}, &entry->handle,
          (transpose_left ? 'T' : 'N'), left_data, entry->output_csr, work_item,
          i, num_threads);
    }
  });
  // Do matrix-matrix multiplication
  ptrdiff_t total_num_mult_blocks =
      libxsmm_spmdm_get_num_compute_blocks(&entry->handle);
  std::atomic<int> cur_mult_block_number;
  cur_mult_block_number.store(0);
  do_on_all_threads(thread_pool, [&](int i) {
    while (true) {
      int work_item = cur_mult_block_number.fetch_add(1);
      if (work_item >= total_num_mult_blocks) break;
      const TL alpha(1.0);  // Stored in a variable so we can get a pointer
      const TL beta(0.0);   // Stored in a variable so we can get a pointer
      wrapper_libxsmm_spmdm_compute_generic_thread(
          empty_type_wrapper<TL>{}, &entry->handle,
          (transpose_left ? 'T' : 'N'), 'N', &alpha, entry->output_csr,
          right_data, (transpose_output ? 'T' : 'N'), &beta, output_data,
          work_item, i, num_threads);
    }
  });
  // Put handle + CSR storage back into cache
  cache->return_cache_entry(std::move(entry));
}

#endif  // TENSORFLOW_USE_LIBXSMM

// Here is an overview of the SparseMatMul code. Note that we assume that the
// left matrix is sparse.
//
// The matrix "left" is divided into a grid with block size (M, KL). Each
// block is encoded as a SparseSlice. These grid elements are stored as
// std::vector<std::vector<SparseSlice>>. Each element of the outer vector
// represents M rows of the left matrix. Let's call these elements l_i and
// let's call each element of the inner vector L_mk.
//
// The matrix "right" is divided into a grid with block size KR * NR. Let's
// denote the blocks on the right as R_kn. Note that we ensure that KL divides
// KR so that for each element R_kn, we don't need to multiply it with any
// partial L_mk blocks.
//
// We then multiply each right side block R_kn with the full "left" matrix and
// update the output. These iterations are run sequentially since R_kn are
// packed into the same underlying temporary buffer.
//
// In each iteration we do the following:
// 1. Create slices r_j of R_kn: We split R_kn into vertical blocks with N
//    (=128) columns and then concatenate these slices into a buffer. This is
//    done so that each slice r_j of R_kn is stored contiguously in memory.
//    Note that if R_kn has dimensions (KR, NR), we create NR / N slices, and
//    the buffer has dimensions (KR * NR / N, N) (assuming N divides NR).
// 2. For each (l_i, r_j), we compute the inner product using the GEPP function
//    and update the output block o_ij. These calls are further blocked to
//    reduce the working set size. In each iteration we take IB elements from
//    {l_i} and JB elements from {r_j} and compute the IB * JB inner products.
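//
// A concrete reading of the above with the constants defined near the top of
// this file (added note, not part of the original comment): each SparseSlice
// covers at most M = 64 rows and KL columns of "left" (KL is a multiple of
// K = 64 chosen by ComputeBlockSizes), and each dense slice r_j is an up to
// N = 128 column strip of R_kn. GEPP therefore multiplies an up to 64-row
// sparse block against an up to 128-column dense strip, which is the shape
// the hard-coded unrolls such as MulAdd3Way128 are sized for.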
template <typename TL, typename TR>
inline void SparseMatMul<TL, TR>::Compute(
    typename SparseMatMul<TL, TR>::TensorInfoCache* /*cache*/,
    const typename SparseMatMul<TL, TR>::ConstMatrixMapL& left,
    const typename SparseMatMul<TL, TR>::ConstMatrixMapR& right,
    bool transpose_left, const DeviceBase::CpuWorkerThreads* thread_pool,
    bool transpose_output, MatrixMap* output) {
  const int num_threads = thread_pool->num_threads;
  int KR, NR, KL, JB, IB;
  ComputeBlockSizes(left, right, transpose_left, num_threads, &KR, &NR, &KL,
                    &JB, &IB);
  // Slice the left matrix
  std::vector<std::vector<SparseSlice<TL>*>> left_slices;
  std::unique_ptr<BlockingCounter> sparse_slice_counter =
      CreateSparseSlices(ConstMatrixMapL(left.data(), left.dimensions()),
                         transpose_left, M, K, KL, &left_slices, thread_pool);
  const int num_left_slices = left_slices.size();

  const int right_dim0 = right.dimension(0);
  const int right_dim1 = right.dimension(1);
  // Allocate buffer for storing slices of right matrix.
  // Note buffer needs enough space to hold at most a KR * NR matrix since that
  // is the block size per iteration.
  const int buffer_num_rows =
      std::min(KR, right_dim0) * (std::min(NR, right_dim1) + N - 1) / N;
  MatrixR buffer(buffer_num_rows, N);
  std::vector<ConstMatrixMapR*> right_slices;

  std::vector<SparseSlice<TL>*> block_left_slices;
  std::vector<std::function<void(void)>> tasks;
  // Number of blocks based on block sizes of KR * NR.
  const int num_k_blocks = (right_dim0 + KR - 1) / KR;
  const int num_n_blocks = (right_dim1 + NR - 1) / NR;
  std::unique_ptr<BlockingCounter> dense_slice_counter;

  for (int nb = 0; nb < num_n_blocks; ++nb) {
    const int right_num_cols =
        std::min(NR, static_cast<int>(right_dim1 - NR * nb));
    for (int kb = 0; kb < num_k_blocks; ++kb) {
      const int right_num_rows =
          std::min(KR, static_cast<int>(right_dim0 - KR * kb));
      dense_slice_counter = CreateDenseSlices(
          right, kb * KR, right_num_rows, nb * NR, right_num_cols, thread_pool,
          &buffer, &right_slices);
      const int num_right_slices = right_slices.size();
      tasks.reserve(num_left_slices * num_right_slices);
      for (int j_outer = 0; j_outer < num_right_slices; j_outer += JB) {
        for (int i_outer = 0; i_outer < num_left_slices; i_outer += IB) {
          for (int j_inner = j_outer;
               j_inner < std::min(num_right_slices, j_outer + JB); ++j_inner) {
            const int num_cols = std::min(N, right_num_cols - N * j_inner);
            for (int i_inner = i_outer;
                 i_inner < std::min(num_left_slices, i_outer + IB); ++i_inner) {
              block_left_slices.clear();
              int begin = kb * KR / KL;
              int end = std::min<int>((kb + 1) * KR / KL,
                                      (right.dimension(0) + KL - 1) / KL);
              DCHECK_LT(begin, end);
              block_left_slices.insert(block_left_slices.begin(),
                                       left_slices[i_inner].begin() + begin,
                                       left_slices[i_inner].begin() + end);
              tasks.push_back(std::bind(
                  &ComputeOutputBlock, block_left_slices,
                  std::ref(*right_slices[j_inner]), num_cols, M * i_inner,
                  N * j_inner + nb * NR, kb == 0, transpose_output, output));
            }
          }
        }
      }
      if (sparse_slice_counter) {
        sparse_slice_counter->Wait();
        sparse_slice_counter.reset(nullptr);
      }
      if (dense_slice_counter) {
        dense_slice_counter->Wait();
        dense_slice_counter.reset(nullptr);
      }
      BlockingCounter bc(tasks.size());
      for (const auto& t : tasks) {
        thread_pool->workers->Schedule([&bc, &t]() {
          t();
          bc.DecrementCount();
        });
      }
      bc.Wait();
      tasks.clear();
      gtl::STLDeleteElements(&right_slices);
      right_slices.clear();
    }
  }
  for (auto& left_slice : left_slices) {
    gtl::STLDeleteElements(&left_slice);
  }
}

#define REGISTER_SPARSE_MATMUL(TA, TB)                   \
  REGISTER_KERNEL_BUILDER(Name("SparseMatMul")           \
                              .Device(DEVICE_CPU)        \
                              .TypeConstraint<TA>("Ta")  \
                              .TypeConstraint<TB>("Tb"), \
                          SparseMatMulOp<TA, TB, SparseMatMul>);
#ifdef TENSORFLOW_USE_LIBXSMM
#define REGISTER_SPARSE_MATMUL_LIBXSMM(TA, TB)           \
  REGISTER_KERNEL_BUILDER(Name("SparseMatMul")           \
                              .Device(DEVICE_CPU)        \
                              .TypeConstraint<TA>("Ta")  \
                              .TypeConstraint<TB>("Tb"), \
                          SparseMatMulOp<TA, TB, LibxsmmSparseMatMul>);
#endif

REGISTER_SPARSE_MATMUL(bfloat16, bfloat16);

REGISTER_SPARSE_MATMUL(float, bfloat16);

REGISTER_SPARSE_MATMUL(bfloat16, float);

#ifdef TENSORFLOW_USE_LIBXSMM
REGISTER_SPARSE_MATMUL_LIBXSMM(float, float);
#else
REGISTER_SPARSE_MATMUL(float, float);
#endif

#undef REGISTER_SPARSE_MATMUL

}  // end namespace tensorflow