/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/sparse_matmul_op.h"

#include <map>
#include <memory>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/bfloat16.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#ifdef TENSORFLOW_USE_LIBXSMM
#include "include/libxsmm_intrinsics_x86.h"
#include "include/libxsmm_malloc.h"
#include "include/libxsmm_spmdm.h"
#endif

namespace tensorflow {
namespace {

using Eigen::operator==;

template <typename T>
using BasicMatrix = Eigen::Tensor<T, 2, Eigen::RowMajor>;

template <typename T>
using BasicMatrixMap =
    Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Aligned>;

using Matrix = BasicMatrix<float>;
using MatrixMap = BasicMatrixMap<float>;
using CPUDevice = Eigen::ThreadPoolDevice;
using DSizes = Eigen::DSizes<Eigen::DenseIndex, 2>;

// Two commonly used static dsizes. We use Eigen::type2index to allow as much
// compile time optimization as possible.
#ifdef EIGEN_HAS_INDEX_LIST
inline Eigen::IndexList<Eigen::type2index<0>, Eigen::type2index<0>>
dsizes_00() {
  return Eigen::IndexList<Eigen::type2index<0>, Eigen::type2index<0>>();
}
inline Eigen::IndexList<Eigen::type2index<1>, Eigen::type2index<0>>
dsizes_10() {
  return Eigen::IndexList<Eigen::type2index<1>, Eigen::type2index<0>>();
}
#else
inline DSizes dsizes_00() { return DSizes(0, 0); }
inline DSizes dsizes_10() { return DSizes(1, 0); }
#endif

// Block sizes.
// TODO(agarwal): compute these sizes based on cache sizes.
const int K = 64;
const int M = 64;
const int N = 128;

// This stores a sparse representation of a slice of a matrix with size
// (num_rows, num_cols). The slice is represented as a series of blocks of size
// (num_rows, b), where b = block_size for all but the last block, which may
// have fewer columns.
//
// num_rows and block_size are assumed to be <= 256. This allows storing
// different indices as uint8.
//
// For each block, the nonzero entries are stored in the data/data3 vectors and
// the coordinates of those entries in the index/index3 vectors. Each entry in
// index3 holds the indices of 3 elements in the same row, so that those
// elements can share a single row coordinate; each Index3 therefore
// corresponds to 3 entries in data3.
//
// Note that the data/indices of all the blocks are stored in the same vectors
// respectively. To identify block boundaries, we store the block offsets using
// index3_offset/index_offset. If there are n blocks in the slice, index3_offset
// and index_offset have n entries each. The index3 entries for the ith block
// occupy positions [index3_offset[i-1], index3_offset[i]) of index3 (with
// index3_offset[-1] taken to be 0). Similarly for index_offset and index.
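//
// Illustrative example (a sketch, not from the original comment): for a block
// whose row 0 has nonzeros at columns 0, 1, 2, 3 and 5, Initialize() produces
//   index3 = {{m=0, k1=0, k2=1, k3=2}}    data3 = {v(0,0), v(0,1), v(0,2)}
//   index  = {{m=0, k=5}, {m=0, k=3}}     data  = {v(0,5), v(0,3)}
//   index3_offset = {1}                   index_offset = {2}
// i.e. complete triples land in index3/data3, and the 1-2 leftover nonzeros of
// each row land in index/data.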
template <typename T>
struct SparseSlice {
  using ConstMatrixMap = BasicMatrixMap<const T>;

 public:
  // Indices of three elements on the same row.
  struct Index3 {
    uint8 m;  // row
    // columns
    uint8 k1;
    uint8 k2;
    uint8 k3;
  };

  // Index of one element.
  struct Index {
    uint8 m;
    uint8 k;
  };

  SparseSlice(int nrows, int ncols, int bsize)
      : num_rows(nrows), num_cols(ncols), block_size(bsize) {
    DCHECK_LE(nrows, 256);
    DCHECK_LE(block_size, 256);
  }

  // Initializes the slice with data starting at mat(0, col_offset) and with
  // size (num_rows, num_cols).
  // If Transpose is true, implicitly transposes mat.
  template <bool Transpose = false>
  void Initialize(const ConstMatrixMap& mat, int col_offset);

  void Clear();

  // See comments above.
  std::vector<int> index3_offset;
  std::vector<Index3> index3;
  std::vector<T> data3;

  // See comments above. Similar to "index3" except that each element in "index"
  // corresponds to one element in data.
  std::vector<int> index_offset;
  std::vector<Index> index;
  std::vector<T> data;

  // Number of rows and columns for the slice.
  const int num_rows;
  const int num_cols;

  // Block size used to initialize from a matrix.
  const int block_size;
};
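
// A minimal usage sketch (illustrative only; CreateSparseSlices below does the
// real work): encode a 64x64 block of an already-mapped float matrix
// "mat_map", starting at column 128, using 64-column sub-blocks:
//   SparseSlice<float> slice(64, 64, 64);
//   slice.Initialize<false>(mat_map, 128);
//   ...
//   slice.Clear();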

template <typename T>
template <bool Transpose>
void SparseSlice<T>::Initialize(
    const typename SparseSlice<T>::ConstMatrixMap& mat, int col_offset) {
  const int mat_rows = Transpose ? mat.dimension(1) : mat.dimension(0);
  const int mat_cols = Transpose ? mat.dimension(0) : mat.dimension(1);
  DCHECK_LE(num_rows, mat_rows);
  DCHECK_LE(num_cols + col_offset, mat_cols);

  int num_blocks = (num_cols + block_size - 1) / block_size;
  int mat_size = num_rows * num_cols;

  index3_offset.reserve(num_blocks);
  data3.reserve(mat_size);
  index3.reserve(mat_size / 3);

  index_offset.reserve(num_blocks);
  data.reserve(num_blocks * num_rows * 2);
  index.reserve(num_blocks * num_rows * 2);

  Index3 idx3;
  Index idx;
  int data3_size = 0;
  static const T zero(0);
  for (int i = 0; i < num_blocks; ++i) {
    int num_block_cols = std::min(block_size, num_cols - block_size * i);
    for (int row = 0; row < num_rows; ++row) {
      idx3.m = static_cast<uint8>(row);
      // Safety note: The following code has a race, since it checks whether
      // *curr is nonzero and then reads it again on use.  However, the result
      // of the race is only that some of the "nonzeros" in the resulting sparse
      // representation may actually be zero, which is harmless.
      const auto* start =
          Transpose ? &mat(col_offset, row) : &mat(row, col_offset);
      const auto* curr = start;
      const int stride = Transpose ? mat.dimension(1) : 1;
      const auto* end = start + stride * num_block_cols;
      uint8 k = 0;
#define NEXT_ELEM \
  curr += stride; \
  ++k;
      while (true) {
        while (curr < end && (*curr == zero)) {
          NEXT_ELEM;
        }
        if (curr >= end) break;
        idx3.k1 = k;
        data3.push_back(*curr);
        NEXT_ELEM;

        while (curr < end && (*curr == zero)) {
          NEXT_ELEM;
        }
        if (curr >= end) break;
        idx3.k2 = k;
        data3.push_back(*curr);
        NEXT_ELEM;

        while (curr < end && (*curr == zero)) {
          NEXT_ELEM;
        }
        if (curr >= end) break;
        idx3.k3 = k;
        data3.push_back(*curr);
        NEXT_ELEM;
        index3.push_back(idx3);
#undef NEXT_ELEM
      }
      int num_inserted_mod = data3.size() % 3;
      // Move some elements to index and data if needed.
      data3_size = data3.size() - num_inserted_mod;
      idx.m = idx3.m;
      switch (num_inserted_mod) {
        case 2:
          idx.k = idx3.k2;
          data.push_back(data3[data3_size + 1]);
          index.push_back(idx);
          TF_FALLTHROUGH_INTENDED;
        case 1:
          idx.k = idx3.k1;
          data.push_back(data3[data3_size]);
          index.push_back(idx);
          data3.resize(data3_size);
      }
    }
    col_offset += block_size;
    index3_offset.push_back(index3.size());
    index_offset.push_back(index.size());
  }
  DCHECK_EQ(index3_offset.size(), num_blocks);
  DCHECK_EQ(index_offset.size(), num_blocks);
  DCHECK_EQ(3 * index3.size(), data3.size());
  DCHECK_EQ(index.size(), data.size());
}

template <typename T>
void SparseSlice<T>::Clear() {
  index3_offset.clear();
  index3.clear();
  data3.clear();
  index_offset.clear();
  index.clear();
  data.clear();
}

using Packet = Eigen::internal::packet_traits<float>::type;
const int kNumOperands = (sizeof(Packet) / sizeof(float));
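// For example, with AVX the Packet holds 8 floats (kNumOperands == 8), and
// with SSE it holds 4; the exact value depends on how Eigen was configured.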
#define LOAD(x) Eigen::internal::pload<Packet>(x);
#define EXPAND_BFLOAT_L(x, y) \
  const auto y = Eigen::internal::pexpand_bf16_l<Packet>(x);
#define EXPAND_BFLOAT_U(x, y) \
  const auto y = Eigen::internal::pexpand_bf16_u<Packet>(x);
#define STORE(x, y) Eigen::internal::pstore<float>(x, y);
#define FMA(a, b, c, d) d = Eigen::internal::pmadd<Packet>(a, b, c);

#define ALWAYS_INLINE EIGEN_ALWAYS_INLINE

ALWAYS_INLINE float ConvertBfloat16ToFloat(const bfloat16* src) {
  float out = 0;
  auto tmp = reinterpret_cast<bfloat16*>(&out);
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
  tmp[0] = *src;
#else
  tmp[1] = *src;
#endif
  return out;
}
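// For example, ConvertBfloat16ToFloat applied to the bfloat16 bit pattern
// 0x3F80 yields the float with bits 0x3F800000, i.e. 1.0f, since bfloat16 is
// simply the upper 16 bits of an IEEE float32.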

ALWAYS_INLINE Packet ConvertFourBfloat16ToFloat(const bfloat16* src) {
  return Eigen::internal::pload4bf16<Packet>(
      reinterpret_cast<const float*>(src));
}

ALWAYS_INLINE Packet ConvertTwoBfloat16ToFloat(const bfloat16* src) {
  return Eigen::internal::pload2bf16<Packet>(
      reinterpret_cast<const float*>(src));
}

ALWAYS_INLINE void ScalarMulAdd(const float a, const float** inp, float** out) {
  **out += a * **inp;
  ++*inp;
  ++*out;
}

ALWAYS_INLINE void ScalarMulAdd(const float a, const bfloat16** inp,
                                float** out) {
  float inp_f = ConvertBfloat16ToFloat(*inp);
  **out += a * inp_f;
  ++*inp;
  ++*out;
}
ALWAYS_INLINE void ScalarMulAdd3Way(const float a1, const float a2,
                                    const float a3, const bfloat16** inp1,
                                    const bfloat16** inp2,
                                    const bfloat16** inp3, float** out) {
  float inp1_f = ConvertBfloat16ToFloat(*inp1);
  float inp2_f = ConvertBfloat16ToFloat(*inp2);
  float inp3_f = ConvertBfloat16ToFloat(*inp3);
  **out += a1 * inp1_f + a2 * inp2_f + a3 * inp3_f;
  ++*out;
  ++*inp1;
  ++*inp2;
  ++*inp3;
}

ALWAYS_INLINE void ScalarMulAdd3Way(const float a1, const float a2,
                                    const float a3, const float** inp1,
                                    const float** inp2, const float** inp3,
                                    float** out) {
  **out += a1 * **inp1 + a2 * **inp2 + a3 * **inp3;
  ++*out;
  ++*inp1;
  ++*inp2;
  ++*inp3;
}

ALWAYS_INLINE void LoadSingleScalar(const bfloat16** data, Packet* l) {
  auto tmp = ConvertBfloat16ToFloat(*data);
  *l = Eigen::internal::pset1<Packet>(tmp);
  ++*data;
}

ALWAYS_INLINE void LoadTwoScalars(const bfloat16** data, Packet* l1,
                                  Packet* l2) {
  if (kNumOperands >= 2) {
    auto tmp = ConvertTwoBfloat16ToFloat(*data);
    *l1 = Eigen::internal::pbroadcast_first<Packet>(tmp);
    *l2 = Eigen::internal::pbroadcast_second<Packet>(tmp);
    *data += 2;
  } else {
    LoadSingleScalar(data, l1);
    LoadSingleScalar(data, l2);
  }
}

ALWAYS_INLINE void LoadFourScalars(const bfloat16** data, Packet* l1,
                                   Packet* l2, Packet* l3, Packet* l4) {
  if (kNumOperands >= 4) {
    auto tmp = ConvertFourBfloat16ToFloat(*data);
    *l1 = Eigen::internal::pbroadcast_first<Packet>(tmp);
    *l2 = Eigen::internal::pbroadcast_second<Packet>(tmp);
    *l3 = Eigen::internal::pbroadcast_third<Packet>(tmp);
    *l4 = Eigen::internal::pbroadcast_fourth<Packet>(tmp);
    *data += 4;
  } else {
    LoadTwoScalars(data, l1, l2);
    LoadTwoScalars(data, l3, l4);
  }
}

ALWAYS_INLINE void LoadSingleScalar(const float** data, Packet* l) {
  *l = Eigen::internal::pload1<Packet>(*data);
  ++(*data);
}

ALWAYS_INLINE void LoadTwoScalars(const float** data, Packet* l1, Packet* l2) {
  LoadSingleScalar(data, l1);
  LoadSingleScalar(data, l2);
}

ALWAYS_INLINE void LoadFourScalars(const float** data, Packet* l1, Packet* l2,
                                   Packet* l3, Packet* l4) {
  LoadTwoScalars(data, l1, l2);
  LoadTwoScalars(data, l3, l4);
}

template <typename T>
ALWAYS_INLINE void LoadThreeScalars(const T** data, Packet* l1, Packet* l2,
                                    Packet* l3) {
  LoadTwoScalars(data, l1, l2);
  LoadSingleScalar(data, l3);
}

template <typename T>
ALWAYS_INLINE void LoadSixScalars(const T** data, Packet* l1, Packet* l2,
                                  Packet* l3, Packet* l4, Packet* l5,
                                  Packet* l6) {
  LoadFourScalars(data, l1, l2, l3, l4);
  LoadTwoScalars(data, l5, l6);
}

// Vectorized version of ScalarMulAdd.
ALWAYS_INLINE void MulAdd(const Packet a, const bfloat16** binp, float** out) {
  auto inp = reinterpret_cast<const float*>(*binp);
  const auto b = LOAD(inp);
  EXPAND_BFLOAT_L(b, b_0);
  EXPAND_BFLOAT_U(b, b_1);
  *binp += 2 * kNumOperands;
  auto c1 = LOAD(*out);
  auto c2 = LOAD(*out + kNumOperands);
  FMA(a, b_0, c1, c1);
  FMA(a, b_1, c2, c2);
  STORE(*out, c1);
  STORE(*out + kNumOperands, c2);
  *out += 2 * kNumOperands;
}

// Vectorized version of ScalarMulAdd3Way.
ALWAYS_INLINE void MulAdd3Way(const Packet a1, const Packet a2, const Packet a3,
                              const bfloat16** binp1, const bfloat16** binp2,
                              const bfloat16** binp3, float** out) {
  auto inp1 = reinterpret_cast<const float*>(*binp1);
  auto inp2 = reinterpret_cast<const float*>(*binp2);
  auto inp3 = reinterpret_cast<const float*>(*binp3);
  auto c1 = LOAD(*out);
  auto c2 = LOAD(*out + kNumOperands);
  const auto b1 = LOAD(inp1);
  EXPAND_BFLOAT_L(b1, b1_0);
  EXPAND_BFLOAT_U(b1, b1_1);
  *binp1 += 2 * kNumOperands;
  const auto b2 = LOAD(inp2);
  EXPAND_BFLOAT_L(b2, b2_0);
  EXPAND_BFLOAT_U(b2, b2_1);
  *binp2 += 2 * kNumOperands;
  const auto b3 = LOAD(inp3);
  EXPAND_BFLOAT_L(b3, b3_0);
  EXPAND_BFLOAT_U(b3, b3_1);
  *binp3 += 2 * kNumOperands;
  FMA(a1, b1_0, c1, c1);
  FMA(a1, b1_1, c2, c2);
  FMA(a2, b2_0, c1, c1);
  FMA(a2, b2_1, c2, c2);
  FMA(a3, b3_0, c1, c1);
  FMA(a3, b3_1, c2, c2);
  STORE(*out, c1);
  STORE(*out + kNumOperands, c2);
  *out += 2 * kNumOperands;
}

// Unroll MulAdd3Way for two iterations
ALWAYS_INLINE void TwoMulAdd3Way(const Packet a1, const Packet a2,
                                 const Packet a3, const bfloat16** binp1,
                                 const bfloat16** binp2, const bfloat16** binp3,
                                 float** out) {
  auto inp1 = reinterpret_cast<const float*>(*binp1);
  auto inp2 = reinterpret_cast<const float*>(*binp2);
  auto inp3 = reinterpret_cast<const float*>(*binp3);
  auto c1 = LOAD(*out);
  auto c2 = LOAD(*out + kNumOperands);
  const auto b1 = LOAD(inp1);
  const auto b2 = LOAD(inp2);
  const auto b3 = LOAD(inp3);

  EXPAND_BFLOAT_L(b1, b1_0);
  EXPAND_BFLOAT_U(b1, b1_1);
  EXPAND_BFLOAT_L(b2, b2_0);
  EXPAND_BFLOAT_U(b2, b2_1);
  EXPAND_BFLOAT_L(b3, b3_0);
  EXPAND_BFLOAT_U(b3, b3_1);
  auto c3 = LOAD(*out + 2 * kNumOperands);
  auto c4 = LOAD(*out + 3 * kNumOperands);
  const auto b4 = LOAD(inp1 + kNumOperands);
  const auto b5 = LOAD(inp2 + kNumOperands);
  const auto b6 = LOAD(inp3 + kNumOperands);

  EXPAND_BFLOAT_L(b4, b4_0);
  EXPAND_BFLOAT_U(b4, b4_1);
  EXPAND_BFLOAT_L(b5, b5_0);
  EXPAND_BFLOAT_U(b5, b5_1);
  EXPAND_BFLOAT_L(b6, b6_0);
  EXPAND_BFLOAT_U(b6, b6_1);

  FMA(a1, b1_0, c1, c1);
  FMA(a1, b1_1, c2, c2);
  FMA(a1, b4_0, c3, c3);
  FMA(a1, b4_1, c4, c4);
  FMA(a2, b2_0, c1, c1);
  FMA(a2, b2_1, c2, c2);
  FMA(a2, b5_0, c3, c3);
  FMA(a2, b5_1, c4, c4);
  FMA(a3, b3_0, c1, c1);
  FMA(a3, b3_1, c2, c2);
  FMA(a3, b6_0, c3, c3);
  FMA(a3, b6_1, c4, c4);
  STORE(*out, c1);
  STORE(*out + kNumOperands, c2);
  STORE(*out + 2 * kNumOperands, c3);
  STORE(*out + 3 * kNumOperands, c4);
  *out += 4 * kNumOperands;
  *binp1 += 4 * kNumOperands;
  *binp2 += 4 * kNumOperands;
  *binp3 += 4 * kNumOperands;
}

// Apply MulAdd3Way on 128 operands.
ALWAYS_INLINE void MulAdd3Way128(const Packet a1, const Packet a2,
                                 const Packet a3, const bfloat16** inp1,
                                 const bfloat16** inp2, const bfloat16** inp3,
                                 float** out) {
  for (int k = 0; k < 128 / (8 * kNumOperands); ++k) {
    TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
    TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
  }
}
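// (Arithmetic check for the bfloat16 MulAdd3Way128 above: each TwoMulAdd3Way
// call covers 4 * kNumOperands output floats, so with 8-float packets the loop
// runs 2 iterations x 2 calls x 32 floats = 128, and with 4-float packets
// 4 x 2 x 16 = 128.)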

// Vectorized version of ScalarMulAdd
ALWAYS_INLINE void MulAdd(const Packet a, const float** inp, float** out) {
  const auto b = LOAD(*inp);
  *inp += kNumOperands;
  auto c = LOAD(*out);
  FMA(a, b, c, c);
  STORE(*out, c);
  *out += kNumOperands;
}

// Vectorized version of ScalarMulAdd3Way
ALWAYS_INLINE void MulAdd3Way(const Packet a1, const Packet a2, const Packet a3,
                              const float** inp1, const float** inp2,
                              const float** inp3, float** out) {
  auto c = LOAD(*out);
  const auto b1 = LOAD(*inp1);
  *inp1 += kNumOperands;
  const auto b2 = LOAD(*inp2);
  *inp2 += kNumOperands;
  const auto b3 = LOAD(*inp3);
  *inp3 += kNumOperands;
  FMA(a1, b1, c, c);
  FMA(a2, b2, c, c);
  FMA(a3, b3, c, c);
  STORE(*out, c);
  *out += kNumOperands;
}

// Unroll MulAdd3Way for two iterations
ALWAYS_INLINE void TwoMulAdd3Way(const Packet a1, const Packet a2,
                                 const Packet a3, const float** inp1,
                                 const float** inp2, const float** inp3,
                                 float** out) {
  auto c1 = LOAD(*out);
  const auto b1 = LOAD(*inp1);
  const auto b2 = LOAD(*inp2);
  const auto b3 = LOAD(*inp3);

  auto c2 = LOAD(*out + kNumOperands);
  const auto b4 = LOAD(*inp1 + kNumOperands);
  const auto b5 = LOAD(*inp2 + kNumOperands);
  const auto b6 = LOAD(*inp3 + kNumOperands);

  FMA(a1, b1, c1, c1);
  FMA(a1, b4, c2, c2);
  FMA(a2, b2, c1, c1);
  FMA(a2, b5, c2, c2);
  FMA(a3, b3, c1, c1);
  FMA(a3, b6, c2, c2);
  STORE(*out, c1);
  STORE(*out + kNumOperands, c2);
  *out += 2 * kNumOperands;
  *inp1 += 2 * kNumOperands;
  *inp2 += 2 * kNumOperands;
  *inp3 += 2 * kNumOperands;
}

// Unroll MulAdd3Way for four iterations
ALWAYS_INLINE void FourMulAdd3Way(const Packet a1, const Packet a2,
                                  const Packet a3, const float** inp1,
                                  const float** inp2, const float** inp3,
                                  float** out) {
  TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
  TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
}

// Apply MulAdd3Way on 128 operands.
ALWAYS_INLINE void MulAdd3Way128(const Packet a1, const Packet a2,
                                 const Packet a3, const float** inp1,
                                 const float** inp2, const float** inp3,
                                 float** out) {
  if (kNumOperands == 8) {
    FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
    FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
    FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
    FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
  } else {
    DCHECK_LE(4 * kNumOperands, 128);
    for (int i = 0; i < 128 / (4 * kNumOperands); ++i) {
      MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
      MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
      MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
      MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
    }
  }
}
// Computes the product of "left_slices" with "num_cols" columns of "right",
// and stores the output in *"output".
// Note that left_slices is a list of SparseSlices, which are conceptually
// assumed to be concatenated along the column dimension. Also, each SparseSlice
// is encoded as a list of blocks with up to N columns. See SparseSlice for more
// details.
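//
// Conceptually (an illustrative restatement), for every nonzero left(m, k)
// stored in a slice block this accumulates
//   output(m, c) += left(m, k) * right(k_offset + k, c)   for c in [0, cols),
// where k_offset is the running row offset of that block into "right".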
template <typename TL, typename TR, int Cols>
inline void GEPP(
    const std::vector<SparseSlice<TL>*>& left_slices,
    const Eigen::TensorMap<Eigen::Tensor<const TR, 2, Eigen::RowMajor>,
                           Eigen::Aligned>& right,
    const int num_cols, Matrix* output) {
  const int cols = (Cols == -1) ? num_cols : Cols;
  DCHECK_EQ(num_cols, cols);
  const int right_num_cols = right.dimension(1);
  const int output_num_cols = output->dimension(1);
  static const int kNumOperandsR = kNumOperands * sizeof(float) / sizeof(TR);
  const int cols_mod = cols % kNumOperandsR;
  int k_offset = 0;
  // Pre-compute pointers for output matrix.
  float* out_ptrs[M];
  float* const out_start = &(*output)(0, 0);
  for (int j = 0; j < M; ++j) {
    out_ptrs[j] = out_start + output_num_cols * j;
  }
  for (const auto* left_slice : left_slices) {
    const auto& left = *left_slice;
    const auto* data3 = (!left.data3.empty()) ? &left.data3[0] : nullptr;
    const auto* data = (!left.data.empty()) ? &left.data[0] : nullptr;
    const int num_blocks = left.index3_offset.size();
    int begin3 = 0;
    int begin = 0;
    for (int i = 0; i < num_blocks; ++i) {
      // Pre-compute pointers for right matrix
      const TR* right_ptrs[K];
      const auto* const right_start = &right(k_offset, 0);
      DCHECK_LT(k_offset, right.dimension(0));
      for (int j = 0; j < K; ++j) {
        right_ptrs[j] = right_start + right_num_cols * j;
      }

      const int end3 = left.index3_offset[i];
      int j = begin3;
      // Loop unrolled for 2 iterations.
      for (; j + 1 < end3; j += 2) {
        Packet l1, l2, l3, nl1, nl2, nl3;
        LoadSixScalars(&data3, &l1, &l2, &l3, &nl1, &nl2, &nl3);
        const auto& index = left.index3[j];
        const auto& nindex = left.index3[j + 1];
        float* out = out_ptrs[index.m];
        float* nout = out_ptrs[nindex.m];
        const auto* r1 = right_ptrs[index.k1];
        const auto* r2 = right_ptrs[index.k2];
        const auto* r3 = right_ptrs[index.k3];

        const auto* nr1 = right_ptrs[nindex.k1];
        const auto* nr2 = right_ptrs[nindex.k2];
        const auto* nr3 = right_ptrs[nindex.k3];
        if (cols == 128) {
          MulAdd3Way128(l1, l2, l3, &r1, &r2, &r3, &out);
          MulAdd3Way128(nl1, nl2, nl3, &nr1, &nr2, &nr3, &nout);
        } else {
          for (int n = 0; n < cols / kNumOperandsR; ++n) {
            MulAdd3Way(l1, l2, l3, &r1, &r2, &r3, &out);
            MulAdd3Way(nl1, nl2, nl3, &nr1, &nr2, &nr3, &nout);
          }

          const float sl1 = Eigen::internal::pfirst<Packet>(l1);
          const float sl2 = Eigen::internal::pfirst<Packet>(l2);
          const float sl3 = Eigen::internal::pfirst<Packet>(l3);
          const float nsl1 = Eigen::internal::pfirst<Packet>(nl1);
          const float nsl2 = Eigen::internal::pfirst<Packet>(nl2);
          const float nsl3 = Eigen::internal::pfirst<Packet>(nl3);
          for (int k = 0; k < cols_mod; ++k) {
            ScalarMulAdd3Way(sl1, sl2, sl3, &r1, &r2, &r3, &out);
            ScalarMulAdd3Way(nsl1, nsl2, nsl3, &nr1, &nr2, &nr3, &nout);
          }
        }
      }
      if (j < end3) {
        Packet l1, l2, l3;
        LoadThreeScalars(&data3, &l1, &l2, &l3);

        const auto& index = left.index3[j];
        float* out = out_ptrs[index.m];
        const auto* r1 = right_ptrs[index.k1];
        const auto* r2 = right_ptrs[index.k2];
        const auto* r3 = right_ptrs[index.k3];
        if (cols == 128) {
          MulAdd3Way128(l1, l2, l3, &r1, &r2, &r3, &out);
        } else {
          for (int n = 0; n < cols / kNumOperandsR; ++n) {
            MulAdd3Way(l1, l2, l3, &r1, &r2, &r3, &out);
          }
          const float sl1 = Eigen::internal::pfirst<Packet>(l1);
          const float sl2 = Eigen::internal::pfirst<Packet>(l2);
          const float sl3 = Eigen::internal::pfirst<Packet>(l3);
          for (int k = 0; k < cols_mod; ++k) {
            ScalarMulAdd3Way(sl1, sl2, sl3, &r1, &r2, &r3, &out);
          }
        }
      }
      begin3 = end3;
      int end = left.index_offset[i];
      // Loop unrolled for 4 iterations.
      j = begin;
      for (; j + 3 < end; j += 4) {
        Packet l, nl, n2l, n3l;
        LoadFourScalars(&data, &l, &nl, &n2l, &n3l);

        const auto& index = left.index[j];
        const auto& nindex = left.index[j + 1];
        const auto& n2index = left.index[j + 2];
        const auto& n3index = left.index[j + 3];
        const auto* r = right_ptrs[index.k];
        const auto* nr = right_ptrs[nindex.k];
        const auto* n2r = right_ptrs[n2index.k];
        const auto* n3r = right_ptrs[n3index.k];
        float* out = out_ptrs[index.m];
        float* nout = out_ptrs[nindex.m];
        float* n2out = out_ptrs[n2index.m];
        float* n3out = out_ptrs[n3index.m];

        for (int n = 0; n < cols / kNumOperandsR; ++n) {
          MulAdd(l, &r, &out);
          MulAdd(nl, &nr, &nout);
          MulAdd(n2l, &n2r, &n2out);
          MulAdd(n3l, &n3r, &n3out);
        }

        const float sl1 = Eigen::internal::pfirst<Packet>(l);
        const float sl2 = Eigen::internal::pfirst<Packet>(nl);
        const float sl3 = Eigen::internal::pfirst<Packet>(n2l);
        const float sl4 = Eigen::internal::pfirst<Packet>(n3l);
        for (int k = 0; k < cols_mod; ++k) {
          ScalarMulAdd(sl1, &r, &out);
          ScalarMulAdd(sl2, &nr, &nout);
          ScalarMulAdd(sl3, &n2r, &n2out);
          ScalarMulAdd(sl4, &n3r, &n3out);
        }
      }
      while (j < end) {
        Packet l;
        LoadSingleScalar(&data, &l);
        const auto& index = left.index[j];
        const auto* r = right_ptrs[index.k];
        float* out = out_ptrs[index.m];
        for (int n = 0; n < cols / kNumOperandsR; ++n) {
          MulAdd(l, &r, &out);
        }
        const float sl = Eigen::internal::pfirst<Packet>(l);
        for (int k = 0; k < cols_mod; ++k) {
          ScalarMulAdd(sl, &r, &out);
        }
        j++;
      }
      k_offset += left.block_size;
      begin = end;
    }
  }
}

#undef LOAD
#undef EXPAND_BFLOAT_L
#undef EXPAND_BFLOAT_U
#undef STORE
#undef FMA

}  // namespace

template <typename TL, typename TR>
class SparseMatMul {
  using MatrixL = BasicMatrix<TL>;
  using MatrixR = BasicMatrix<TR>;
  using ConstMatrixMapL = BasicMatrixMap<const TL>;
  using ConstMatrixMapR = BasicMatrixMap<const TR>;
  using MatrixMapR = BasicMatrixMap<TR>;

 public:
  // Not used; added to match interface of LibxsmmSparseMatMul
  struct TensorInfoCache {};

  // Perform matrix multiplication of "left" and "right", and store the result
  // in *"output".
 public:
  static inline void Compute(TensorInfoCache* cache,
                             const ConstMatrixMapL& left,
                             const ConstMatrixMapR& right, bool transpose_left,
                             const DeviceBase::CpuWorkerThreads* thread_pool,
                             bool transpose_output, MatrixMap* output);

 private:
  // Computes multiplication of left and num_cols columns of right, and stores
  // the output block in *"output" at offsets "output_row_offset" and
  // "output_col_offset". If assign is true, assigns the value to that block,
  // else adds the values to the existing values.
  static inline void ComputeOutputBlock(
      const std::vector<SparseSlice<TL>*>& left, const ConstMatrixMapR& right,
      int num_cols, int output_row_offset, int output_col_offset, bool assign,
      bool transpose_output, MatrixMap* output);

  // Encodes "mat" using a sparse representation and stores that in
  // "mat_slices". "mat" is broken into a grid of cells of size
  // "slice_num_rows" x "slice_num_cols"; each grid cell is converted into a
  // SparseSlice and stored in "mat_slices". "slice_block_size" is used to
  // perform further column blocking of each slice.
  static inline std::unique_ptr<BlockingCounter> CreateSparseSlices(
      const ConstMatrixMapL& mat, bool transpose, int slice_num_rows,
      int slice_block_size, int slice_num_cols,
      std::vector<std::vector<SparseSlice<TL>*>>* mat_slices,
      const DeviceBase::CpuWorkerThreads* thread_pool);

  // This function chops "mat" along the column dimension into pieces with at
  // most N columns, concatenates the pieces one after the other in "buffer",
  // and returns the list of pieces in "slices". The returned BlockingCounter
  // should be used to wait for the shuffle operations to complete.
  static inline std::unique_ptr<BlockingCounter> CreateDenseSlices(
      const ConstMatrixMapR& mat, int row_start, int num_rows, int col_start,
      int num_cols, const DeviceBase::CpuWorkerThreads* thread_pool,
      MatrixR* buffer, std::vector<ConstMatrixMapR*>* slices);

  // Helper function for CreateDenseSlices to move the data around. It returns a
  // BlockingCounter which should be used to wait for the shuffle operations to
  // complete.
  static inline BlockingCounter* ShuffleMatrix(
      const ConstMatrixMapR& mat, int slice_row_start, int slice_num_rows,
      int slice_col_start, int slice_num_cols, const int N,
      const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer);

  // Helper function for CreateDenseSlices to create slices.
  static inline void SliceMatrix(const MatrixR& mat, const int num_rows,
                                 const int num_slices,
                                 std::vector<ConstMatrixMapR*>* slices);

  // Heuristics to compute various block sizes.
  // KR, NR: block sizes for "right". We run blocking iterations that operate on
  // matrices of at most this size.
  // KL: grid size along the column dimension used while encoding left.
  // IB, JB: number of left and right slices to multiply together. This is used
  // for ordering different ComputeOutputBlock operations inside each blocking
  // iteration so as to potentially reduce the working set size.
  static inline void ComputeBlockSizes(const ConstMatrixMapL& left,
                                       const ConstMatrixMapR& right,
                                       bool transpose_left, int num_threads,
                                       int* KR, int* NR, int* KL, int* JB,
                                       int* IB);

  TF_DISALLOW_COPY_AND_ASSIGN(SparseMatMul);
};

#ifdef TENSORFLOW_USE_LIBXSMM
template <typename TL, typename TR>
class LibxsmmSparseMatMul {
  using MatrixL = BasicMatrix<TL>;
  using MatrixR = BasicMatrix<TR>;
  using ConstMatrixMapL = BasicMatrixMap<const TL>;
  using ConstMatrixMapR = BasicMatrixMap<const TR>;
  using MatrixMapR = BasicMatrixMap<TR>;

 public:
  // This structure contains a set of libxsmm kernels for sizes that have been
  // encountered previously by this operator so that libxsmm does not need to
  // reallocate its scratchpad memory each time (which hurts performance
  // substantially).
  struct TensorInfoCache {
    struct TensorInfoCacheEntry {
      // Parameters for kernel
      int M;
      int K;
      int N;
      int max_threads;
      // libxsmm handle and matrix data
      libxsmm_spmdm_handle handle;
      libxsmm_CSR_sparseslice* output_csr;
      // Chain to non-libxsmm implementation's cache in case that ever becomes
      // useful (it is an empty struct right now)
      typename SparseMatMul<TL, TR>::TensorInfoCache
          non_libxsmm_cache;  // Currently not used
    };
    // protects entries; invariant: entries is a valid std::multimap
    tensorflow::mutex lock;
    // Because there could be multiple matrix multiplies with the same sizes
    // going on at the same time, we need to allow multiple cache entries for a
    // given set of parameters. Taking and returning entries is used to make
    // sure the same cache entry is not used from two threads at a time.
    std::multimap<std::tuple<int, int, int, int>,
                  std::unique_ptr<TensorInfoCacheEntry>>
        entries GUARDED_BY(lock);

    TensorInfoCache() : lock(), entries() {}
    // Look up and remove first entry with these parameters, creating one if
    // there isn't one
    std::unique_ptr<TensorInfoCacheEntry> take_cache_entry(int M, int K, int N,
                                                           int max_threads)
        LOCKS_EXCLUDED(lock) {
      tensorflow::mutex_lock ml(lock);
      auto key = std::make_tuple(M, K, N, max_threads);
      auto it = entries.find(key);
      if (it != entries.end()) {
        auto val = std::move(it->second);
        entries.erase(it);
        return val;
      } else {
        std::unique_ptr<TensorInfoCacheEntry> e{
            new TensorInfoCacheEntry{M, K, N, max_threads, {}, nullptr}};
        // setup scoped allocator, which uses cpu_allocator() for this scope
        const libxsmm_tf_allocator<libxsmm_scratch_allocator> tf_allocator;
        libxsmm_spmdm_init(M, N, K, max_threads, &e->handle, &e->output_csr);
        return e;
      }
    }
    // Add a cache entry with certain parameters
    void return_cache_entry(std::unique_ptr<TensorInfoCacheEntry> e)
        LOCKS_EXCLUDED(lock) {
      tensorflow::mutex_lock ml(lock);
      auto key = std::make_tuple(e->M, e->K, e->N, e->max_threads);
      entries.insert(std::make_pair(key, std::move(e)));
    }
    ~TensorInfoCache() {
      tensorflow::mutex_lock ml(lock);
      for (auto& p : entries) {
        libxsmm_spmdm_destroy(&p.second->handle);
      }
      entries.clear();
    }

   private:
    TF_DISALLOW_COPY_AND_ASSIGN(TensorInfoCache);
  };
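
  // Typical usage of the cache (an illustrative sketch): an entry is taken for
  // the duration of one multiply and handed back afterwards so another thread
  // can reuse the libxsmm handle:
  //   auto e = cache->take_cache_entry(M, K, N, max_threads);
  //   ... run the libxsmm spmdm kernels via e->handle and e->output_csr ...
  //   cache->return_cache_entry(std::move(e));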

  // Perform matrix multiplication of "left" and "right", and store the result
  // in *"output".
 public:
  static inline void Compute(TensorInfoCache* cache,
                             const ConstMatrixMapL& left,
                             const ConstMatrixMapR& right, bool transpose_left,
                             const DeviceBase::CpuWorkerThreads* thread_pool,
                             bool transpose_output, MatrixMap* output);

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(LibxsmmSparseMatMul);
};
#endif

template <typename TL, typename TR,
          template <typename TL2, typename TR2> class DoMatMul>
class SparseMatMulOp : public OpKernel {
  using MatrixR = BasicMatrix<TR>;
  using ConstMatrixMapR = BasicMatrixMap<const TR>;

 public:
  explicit SparseMatMulOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("a_is_sparse", &a_is_sparse_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("b_is_sparse", &b_is_sparse_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& a = ctx->input(0);
    const Tensor& b = ctx->input(1);
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(a.shape()),
                errors::InvalidArgument("a is not a matrix"));
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(b.shape()),
                errors::InvalidArgument("b is not a matrix"));

    const int m = transpose_a_ ? a.dim_size(1) : a.dim_size(0);
    const int k = transpose_a_ ? a.dim_size(0) : a.dim_size(1);
    const int n = transpose_b_ ? b.dim_size(0) : b.dim_size(1);
    const int k2 = transpose_b_ ? b.dim_size(1) : b.dim_size(0);

    OP_REQUIRES(ctx, k == k2,
                errors::InvalidArgument(
                    "Matrix size incompatible: a: ", a.shape().DebugString(),
                    ", b: ", b.shape().DebugString()));
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({m, n}), &output));

    if (k == 0) {
      // If the inner dimension k in the matrix multiplication is zero, we fill
      // the output with zeros.
      functor::SetZeroFunctor<CPUDevice, float> f;
      f(ctx->eigen_device<CPUDevice>(), output->flat<float>());
      return;
    }

    auto out = output->matrix<float>();

    std::unique_ptr<Tensor> a_float;
    std::unique_ptr<Tensor> b_float;
    if (!a_is_sparse_ && !b_is_sparse_) {
      auto left = &a;
      auto right = &b;
      // TODO(agarwal): multi-thread the conversions from bfloat16 to float.
      if (std::is_same<TL, bfloat16>::value) {
        a_float.reset(new Tensor(DT_FLOAT, a.shape()));
        BFloat16ToFloat(a.flat<bfloat16>().data(),
                        a_float->flat<float>().data(), a.NumElements());
        left = a_float.get();
      }
      if (std::is_same<TR, bfloat16>::value) {
        b_float.reset(new Tensor(DT_FLOAT, b.shape()));
        BFloat16ToFloat(b.flat<bfloat16>().data(),
                        b_float->flat<float>().data(), b.NumElements());
        right = b_float.get();
      }
      Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
      dim_pair[0].first = transpose_a_ ? 0 : 1;
      dim_pair[0].second = transpose_b_ ? 1 : 0;

      out.device(ctx->template eigen_device<CPUDevice>()) =
          left->matrix<float>().contract(right->matrix<float>(), dim_pair);
      return;
    }

    auto left = &a;
    auto right = &b;
    bool transpose_output = false;
    bool transpose_a = transpose_a_;
    bool transpose_b = transpose_b_;
    if (!a_is_sparse_) {
      // Swap the order of multiplications using the identity:
      // A * B = (B' * A')'.
      std::swap(left, right);
      std::swap(transpose_a, transpose_b);
      transpose_a = !transpose_a;
      transpose_b = !transpose_b;
      transpose_output = !transpose_output;
    }

    std::unique_ptr<Tensor> right_tr;
    if (transpose_b) {
      // TODO(agarwal): avoid transposing the matrix here and directly handle
      // transpose in CreateDenseSlices.
      right_tr.reset(
          new Tensor(right->dtype(),
                     TensorShape({right->dim_size(1), right->dim_size(0)})));

      const auto perm = dsizes_10();
      if (transpose_output) {
        right_tr->matrix<TL>().device(ctx->template eigen_device<CPUDevice>()) =
            right->matrix<TL>().shuffle(perm);
      } else {
        right_tr->matrix<TR>().device(ctx->template eigen_device<CPUDevice>()) =
            right->matrix<TR>().shuffle(perm);
      }
      right = right_tr.get();
    }

    if (transpose_output) {
      DoMatMul<TR, TL>::Compute(&this->cache_tr_, left->matrix<TR>(),
                                right->matrix<TL>(), transpose_a,
                                ctx->device()->tensorflow_cpu_worker_threads(),
                                transpose_output, &out);
    } else {
      DoMatMul<TL, TR>::Compute(&this->cache_nt_, left->matrix<TL>(),
                                right->matrix<TR>(), transpose_a,
                                ctx->device()->tensorflow_cpu_worker_threads(),
                                transpose_output, &out);
    }
  }

 private:
  bool transpose_a_;
  bool transpose_b_;
  bool a_is_sparse_;
  bool b_is_sparse_;

  // Cache for non-transposed-output multiply
  typename DoMatMul<TL, TR>::TensorInfoCache cache_nt_;
  // Cache for transposed-output multiply
  typename DoMatMul<TR, TL>::TensorInfoCache cache_tr_;

  TF_DISALLOW_COPY_AND_ASSIGN(SparseMatMulOp);
};
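
// (For reference, and assuming the usual TF 1.x Python API: this kernel
// typically backs tf.matmul calls that set a_is_sparse/b_is_sparse, e.g.
//   c = tf.matmul(a, b, a_is_sparse=True)
// where "a" is a float32/bfloat16 matrix that is mostly zeros.)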

template <typename TL, typename TR>
inline void SparseMatMul<TL, TR>::ComputeOutputBlock(
    const std::vector<SparseSlice<TL>*>& left,
    const typename SparseMatMul<TL, TR>::ConstMatrixMapR& right, int num_cols,
    int output_row_offset, int output_col_offset, bool assign,
    bool transpose_output, MatrixMap* output) {
  const auto perm = dsizes_10();
  int num_rows = left[0]->num_rows;
  const int rhs_num_cols = right.dimension(1);
  DCHECK_LE(num_cols, rhs_num_cols);
  Matrix out(num_rows, rhs_num_cols);
  out.setZero();
  if (num_cols == N) {
    GEPP<TL, TR, N>(left, right, num_cols, &out);
  } else {
    GEPP<TL, TR, -1>(left, right, num_cols, &out);
  }
  if (!assign) {
    const DSizes begin(output_row_offset, output_col_offset);
    const DSizes sizes(num_rows, num_cols);
    if (transpose_output) {
      if (num_cols == rhs_num_cols) {
        output->shuffle(perm).slice(begin, sizes) += out;
      } else {
        const auto zero = dsizes_00();
        output->shuffle(perm).slice(begin, sizes) += out.slice(zero, sizes);
      }
    } else {
      if (num_cols == rhs_num_cols) {
        output->slice(begin, sizes) += out;
      } else {
        const auto zero = dsizes_00();
        output->slice(begin, sizes) += out.slice(zero, sizes);
      }
    }
  } else {
    std::unique_ptr<Matrix> out_tr;
    if (transpose_output) {
      out_tr.reset(new Matrix(rhs_num_cols, num_rows));
      *out_tr = out.shuffle(perm);
      std::swap(output_row_offset, output_col_offset);
      std::swap(num_rows, num_cols);
    }
    const Matrix& final_out = transpose_output ? *out_tr : out;
    for (int i = 0; i < num_rows; ++i) {
      memcpy(&(*output)(output_row_offset + i, output_col_offset),
             &final_out(i, 0), num_cols * sizeof(float));
    }
  }
}

template <typename TL, typename TR>
inline std::unique_ptr<BlockingCounter>
SparseMatMul<TL, TR>::CreateSparseSlices(
    const typename SparseMatMul<TL, TR>::ConstMatrixMapL& mat, bool transpose,
    int slice_num_rows, int slice_block_size, int slice_num_cols,
    std::vector<std::vector<SparseSlice<TL>*>>* mat_slices,
    const DeviceBase::CpuWorkerThreads* thread_pool) {
  const int mat_num_rows = transpose ? mat.dimension(1) : mat.dimension(0);
  const int mat_num_cols = transpose ? mat.dimension(0) : mat.dimension(1);
  const int num_slices_dim0 =
      std::max(1, (mat_num_rows + slice_num_rows - 1) / slice_num_rows);
  const int num_slices_dim1 =
      std::max(1, (mat_num_cols + slice_num_cols - 1) / slice_num_cols);
  mat_slices->resize(num_slices_dim0);
  BlockingCounter* counter =
      new BlockingCounter(num_slices_dim0 * num_slices_dim1);
  auto work = [counter, transpose](SparseSlice<TL>* sparse_slice,
                                   SparseMatMul<TL, TR>::ConstMatrixMapL* slice,
                                   int col_offset) {
    if (transpose) {
      sparse_slice->template Initialize<true>(*slice, col_offset);
    } else {
      sparse_slice->template Initialize<false>(*slice, col_offset);
    }
    delete slice;
    counter->DecrementCount();
  };
  for (int i = 0; i < num_slices_dim0; ++i) {
    (*mat_slices)[i].resize(num_slices_dim1);
    int num_rows =
        std::min<int>(slice_num_rows, mat_num_rows - i * slice_num_rows);
    for (int j = 0; j < num_slices_dim1; ++j) {
      int num_cols =
          std::min<int>(slice_num_cols, mat_num_cols - j * slice_num_cols);
      SparseMatMul<TL, TR>::ConstMatrixMapL* slice = nullptr;
      if (transpose) {
        slice = new SparseMatMul<TL, TR>::ConstMatrixMapL(
            &mat(0, i * slice_num_rows), mat.dimensions());
      } else {
        DSizes d(num_rows, mat_num_cols);
        slice = new SparseMatMul<TL, TR>::ConstMatrixMapL(
            &mat(i * slice_num_rows, 0), d);
      }
      auto* sparse_slice =
          new SparseSlice<TL>(num_rows, num_cols, slice_block_size);
      (*mat_slices)[i][j] = sparse_slice;
      thread_pool->workers->Schedule(
          [=]() { work(sparse_slice, slice, slice_num_cols * j); });
    }
  }
  return std::unique_ptr<BlockingCounter>(counter);
}
#define LOAD(x) Eigen::internal::ploadu<Packet>((x));
#define INTERLEAVE(x) Eigen::internal::pinterleave4x64<Packet>(x);
#define STORE(x, y) Eigen::internal::pstoreu<float>(x, y);

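// Copies "num_elements" bfloat16 values from "bsrc" to "bdst". When the packet
// is wide enough, 64-bit chunks are interleaved on the way (pinterleave4x64);
// the tail that does not fill a complete step is copied verbatim. The
// interleaved layout is what the pexpand_bf16_l/_u based MulAdd routines above
// consume.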
template <int NUM_ELEM = -1>
ALWAYS_INLINE void CopyAndMayBeInterleaveBfloat16(void* bdst, const void* bsrc,
                                                  int num_elements) {
  DCHECK_GE(kNumOperands, 8);
  static const int kStep = kNumOperands * sizeof(float) / sizeof(bfloat16);
  const int num = (NUM_ELEM == -1) ? num_elements : NUM_ELEM;
  DCHECK_EQ(num, num_elements);
  const float* src = reinterpret_cast<const float*>(bsrc);
  float* dst = reinterpret_cast<float*>(bdst);
  for (int index = 0; index + kStep <= num; index += kStep) {
    auto in = LOAD(src);
    auto tmp = INTERLEAVE(in);
    STORE(dst, tmp);
    src += kNumOperands;
    dst += kNumOperands;
  }
  if (num % kStep != 0) {
    memcpy(dst, src, (num % kStep) * sizeof(bfloat16));
  }
}

template <typename T>
ALWAYS_INLINE void CopyAndMayBeInterleave(void* dst, const void* src,
                                          int num_elements) {
  if (std::is_same<T, float>::value || kNumOperands < 8) {
    memcpy(dst, src, num_elements * sizeof(T));
  } else if (std::is_same<T, bfloat16>::value) {
    if (num_elements == N) {
      CopyAndMayBeInterleaveBfloat16<N>(dst, src, num_elements);
    } else {
      CopyAndMayBeInterleaveBfloat16<-1>(dst, src, num_elements);
    }
  } else {
    LOG(FATAL) << "Unsupported type";
  }
}

#undef LOAD
#undef INTERLEAVE
#undef STORE
   1224 
   1225 template <typename TL, typename TR>
   1226 inline BlockingCounter* SparseMatMul<TL, TR>::ShuffleMatrix(
   1227     const typename SparseMatMul<TL, TR>::ConstMatrixMapR& mat,
   1228     int slice_row_start, int slice_num_rows, int slice_col_start,
   1229     int slice_num_cols, const int N,
   1230     const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer) {
   1231   DCHECK_EQ(N % 2, 0);
   1232   DCHECK_LE(kNumOperands * sizeof(float) / sizeof(TR), N);
   1233   int num_threads = std::min(thread_pool->num_threads, 16);
   1234   BlockingCounter* counter = new BlockingCounter(num_threads);
   1235   DCHECK_EQ(N, buffer->dimension(1));
   1236   auto shuffle_work = [&mat, slice_row_start, slice_num_rows, slice_col_start,
   1237                        slice_num_cols, N, buffer, counter](int s, int e) {
   1238     const int row_start = s % slice_num_rows + slice_row_start;
   1239     const int col_start = s / slice_num_rows * N + slice_col_start;
   1240     auto* out_start = &(*buffer)(s, 0);
   1241     const auto* input_start = &mat(row_start, col_start);
   1242     const auto* input_end = &mat(slice_row_start + slice_num_rows - 1,
   1243                                  slice_col_start + slice_num_cols - 1);
   1244     const int mat_num_cols = mat.dimension(1);
   1245     const int row_slice_size = slice_num_rows * mat_num_cols;
   1246 
   1247     const int aligned_end = slice_num_cols / N * slice_num_rows;
   1248     const int e1 = std::min(e, aligned_end);
   1249     while (s < e1) {
   1250       CopyAndMayBeInterleave<TR>(out_start, input_start, N);
   1251       out_start += N;
   1252       input_start += mat_num_cols;
   1253       if (input_start > input_end) {
   1254         input_start = input_start - row_slice_size + N;
   1255       }
   1256       ++s;
   1257     }
   1258     int s1 = std::max(s, aligned_end);
   1259     const int copy_num_cols = slice_num_cols % N;
   1260     while (s1 < e) {
   1261       CopyAndMayBeInterleave<TR>(out_start, input_start, copy_num_cols);
   1262       out_start += N;
   1263       input_start += mat_num_cols;
   1264       ++s1;
   1265     }
   1266     if (counter) counter->DecrementCount();
   1267   };
   1268 
   1269   int start = 0;
   1270   int end = 0;
   1271   int num_out_rows = (slice_num_cols + N - 1) / N * slice_num_rows;
   1272   DCHECK_LE(num_out_rows, buffer->dimension(0));
   1273   for (int i = std::max(1, num_threads); i > 0; --i) {
   1274     end = start + num_out_rows / i;
   1275     thread_pool->workers->Schedule([=]() { shuffle_work(start, end); });
   1276     num_out_rows -= (end - start);
   1277     start = end;
   1278   }
   1279   return counter;
   1280 }
   1281 
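        // Wraps the shuffled `buffer` in `num_slices` ConstMatrixMapR views of
        // `num_rows` rows each, one per N-column strip.  The views are heap
        // allocated; the caller is responsible for deleting them.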
   1282 template <typename TL, typename TR>
   1283 inline void SparseMatMul<TL, TR>::SliceMatrix(
   1284     const MatrixR& mat, const int num_rows, const int num_slices,
   1285     std::vector<typename SparseMatMul<TL, TR>::ConstMatrixMapR*>* slices) {
   1286   slices->resize(num_slices);
   1287   DSizes d(num_rows, mat.dimension(1));
   1288   DCHECK_LE(num_rows * num_slices, mat.dimension(0));
   1289   for (int i = 0; i < num_slices; ++i) {
   1290     (*slices)[i] = new ConstMatrixMapR(&mat(i * num_rows, 0), d);
   1291   }
   1292 }
   1293 
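        // Shuffles a (num_rows x num_cols) block of the right matrix into `buffer`
        // and exposes it as (num_cols + N - 1) / N contiguous slices.  The shuffle
        // runs asynchronously on the worker threads; the returned counter must be
        // waited on before the slices are read.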
   1294 template <typename TL, typename TR>
   1295 inline std::unique_ptr<BlockingCounter> SparseMatMul<TL, TR>::CreateDenseSlices(
   1296     const typename SparseMatMul<TL, TR>::ConstMatrixMapR& mat, int row_start,
   1297     int num_rows, int col_start, int num_cols,
   1298     const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer,
   1299     std::vector<typename SparseMatMul<TL, TR>::ConstMatrixMapR*>* slices) {
   1300   std::unique_ptr<BlockingCounter> shuffle_counter(ShuffleMatrix(
   1301       mat, row_start, num_rows, col_start, num_cols, N, thread_pool, buffer));
   1302   const int num_slices = (num_cols + N - 1) / N;
   1303   SliceMatrix(*buffer, num_rows, num_slices, slices);
   1304   return shuffle_counter;
   1305 }
   1306 
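        // Picks the blocking parameters used by SparseMatMul::Compute: (KR x NR)
        // blocks of the right matrix, the KL column-block size for the left matrix,
        // and the IB / JB grouping factors used when scheduling tasks.
        //
        // Worked example (hypothetical shapes, for illustration only): with
        // num_threads = 8 and 2048 x 2048 operands, est_num_cores = 4 and
        // mem = 4 * 128 * 1024 = 524288, so KR = 2048; NR first rounds to 2048 and
        // then shrinks to mem / KR = 256.  The KL search stops at 1024
        // (2048 % 1024 == 0 and (2048 / 64) * (2048 / 1024) = 64 > 4), and
        // JB = max(1, int(sqrt(8) / 2)) = 1, giving IB = 8.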
   1307 template <typename TL, typename TR>
   1308 inline void SparseMatMul<TL, TR>::ComputeBlockSizes(
   1309     const typename SparseMatMul<TL, TR>::ConstMatrixMapL& left,
   1310     const typename SparseMatMul<TL, TR>::ConstMatrixMapR& right,
   1311     bool transpose_left, int num_threads, int* KR, int* NR, int* KL, int* JB,
   1312     int* IB) {
   1313   // Heuristics for calculating block sizes
   1314   // Assume two hyperthreads per core.
   1315   const int est_num_cores = std::max(1, (num_threads + 1) / 2);
   1316   // Use block of rhs with at most 128K floats per core.
   1317   const int mem = est_num_cores * 128 * 1024;
   1318   *KR = std::min(static_cast<int>(right.dimension(0)), mem / 256);
   1319   *NR = right.dimension(1);
   1320   if (*KR * *NR > mem) {
   1321     // 4096 may be enough to amortize the cost of writes.
   1322     *KR = std::min<int>(*KR, 4096);
   1323   }
   1324   // Use sizes that are multiples of K and 256.
   1325   *KR = std::max(1, *KR / K) * K;
   1326   *NR = std::max(1, *NR / 256) * 256;
   1327   if (*KR * *NR > mem) {
   1328     *NR = mem / *KR;
   1329   }
   1330   *NR = std::max(1, *NR / 256) * 256;
   1331 
   1332   const int left_dim0 = transpose_left ? left.dimension(1) : left.dimension(0);
   1333   const int left_dim1 = transpose_left ? left.dimension(0) : left.dimension(1);
   1334   for (*KL = 1024; *KL > K; *KL /= 2) {
   1335     if (*KR % *KL == 0 &&
   1336         std::max<int>(1, left_dim0 / 64) * (left_dim1 / *KL) > est_num_cores) {
   1337       break;
   1338     }
   1339   }
   1340   DCHECK_EQ(*KL % K, 0);
   1341   DCHECK_GE(*KR, *KL);
   1342   if (*KR < right.dimension(0)) {
   1343     CHECK_EQ(*KR % *KL, 0);
   1344   }
   1345 
   1346   *JB = std::max(1, static_cast<int>(sqrt(num_threads) / 2.0));
   1347   *IB = 8 * *JB;
   1348   DCHECK_EQ(N * sizeof(float) % 64, size_t{0});
   1349 }
   1350 
   1351 #ifdef TENSORFLOW_USE_LIBXSMM
   1352 
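        // Runs f(i) for every thread index i in [0, num_threads): f(0) on the
        // calling thread and the rest on the worker pool, blocking until all of
        // them have finished.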
   1353 template <typename F>
   1354 void do_on_all_threads(const DeviceBase::CpuWorkerThreads* thread_pool,
   1355                        const F& f) {
   1356   int num_threads = thread_pool->num_threads;
   1357   if (num_threads == 0) {
   1358     LOG(FATAL) << "Have 0 threads in thread pool";
   1359   } else if (num_threads == 1) {
   1360     f(0);
   1361   } else {
   1362     BlockingCounter counter(num_threads - 1);
   1363     for (int i = 1; i < num_threads; ++i) {
   1364       thread_pool->workers->Schedule([&, i]() {
   1365         f(i);
   1366         counter.DecrementCount();
   1367       });
   1368     }
   1369     f(0);
   1370     counter.Wait();
   1371   }
   1372 }
   1373 
   1374 template <typename T>
   1375 struct empty_type_wrapper {};
   1376 
   1377 // Copies of interface to libxsmm_spmdm_createSparseSlice_*_notrans_thread to
   1378 // allow overloading
   1379 void wrapper_libxsmm_spmdm_createSparseSlice_generic_thread(
   1380     empty_type_wrapper<float>, const libxsmm_spmdm_handle* handle, char transA,
   1381     const float* A, libxsmm_CSR_sparseslice* libxsmm_output_csr_a, int block_id,
   1382     int tid, int nthreads) {
   1383   return libxsmm_spmdm_createSparseSlice_fp32_thread(
   1384       handle, transA, A, libxsmm_output_csr_a, block_id, tid, nthreads);
   1385 }
   1386 void wrapper_libxsmm_spmdm_createSparseSlice_generic_thread(
   1387     empty_type_wrapper<bfloat16>, const libxsmm_spmdm_handle* handle,
   1388     char transA, const bfloat16* A,
   1389     libxsmm_CSR_sparseslice* libxsmm_output_csr_a, int block_id, int tid,
   1390     int nthreads) {
   1391   return libxsmm_spmdm_createSparseSlice_bfloat16_thread(
   1392       handle, transA, reinterpret_cast<const uint16*>(A), libxsmm_output_csr_a,
   1393       block_id, tid, nthreads);
   1394 }
   1395 
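        // Analogous overload wrappers for libxsmm_spmdm_compute_*_thread.  bfloat16
        // arguments are handed to libxsmm as their raw uint16 bit patterns.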
   1396 void wrapper_libxsmm_spmdm_compute_generic_thread(
   1397     empty_type_wrapper<bfloat16>, const libxsmm_spmdm_handle* handle,
   1398     char transA, char transB, const bfloat16* alpha,
   1399     libxsmm_CSR_sparseslice* A_sparse, const bfloat16* B, char transC,
   1400     const bfloat16* beta, float* C, int block_id, int tid, int nthreads) {
   1401   return libxsmm_spmdm_compute_bfloat16_thread(
   1402       handle, transA, transB, reinterpret_cast<const uint16*>(alpha), A_sparse,
   1403       reinterpret_cast<const uint16*>(B), transC,
   1404       reinterpret_cast<const uint16*>(beta), C, block_id, tid, nthreads);
   1405 }
   1406 void wrapper_libxsmm_spmdm_compute_generic_thread(
   1407     empty_type_wrapper<float>, const libxsmm_spmdm_handle* handle, char transA,
   1408     char transB, const float* alpha, libxsmm_CSR_sparseslice* A_sparse,
   1409     const float* B, char transC, const float* beta, float* C, int block_id,
   1410     int tid, int nthreads) {
   1411   return libxsmm_spmdm_compute_fp32_thread(handle, transA, transB, alpha,
   1412                                            A_sparse, B, transC, beta, C,
   1413                                            block_id, tid, nthreads);
   1414 }
   1415 
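        // Computes left * right using libxsmm.  Shapes that libxsmm handles poorly
        // (any dimension smaller than 32) fall back to the portable
        // SparseMatMul<TL, TR>::Compute.  Otherwise the left matrix is first
        // converted to CSR slices and the product is then computed block by block,
        // with worker threads claiming blocks dynamically.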
   1416 template <typename TL, typename TR>
   1417 inline void LibxsmmSparseMatMul<TL, TR>::Compute(
   1418     typename LibxsmmSparseMatMul<TL, TR>::TensorInfoCache* cache,
   1419     const typename LibxsmmSparseMatMul<TL, TR>::ConstMatrixMapL& left,
   1420     const typename LibxsmmSparseMatMul<TL, TR>::ConstMatrixMapR& right,
   1421     bool transpose_left, const DeviceBase::CpuWorkerThreads* thread_pool,
   1422     bool transpose_output, MatrixMap* output) {
   1423   if (false) {
   1424     // Not handled by libxsmm currently
   1425     SparseMatMul<TL, TR>::Compute(
   1426         nullptr /* Assumes no cached data for fallback */, left, right,
   1427         transpose_left, thread_pool, transpose_output, output);
   1428     return;
   1429   }
   1430   const int num_threads = thread_pool->num_threads;
   1431   const int left_dim0 = transpose_left ? left.dimension(1) : left.dimension(0);
   1432   const int left_dim1 = transpose_left ? left.dimension(0) : left.dimension(1);
   1433   const int right_dim0 = right.dimension(0);
   1434   const int right_dim1 = right.dimension(1);
   1435   CHECK_EQ(left_dim1, right_dim0);
   1436   CHECK_EQ(left_dim0,
   1437            (transpose_output ? output->dimension(1) : output->dimension(0)));
   1438   CHECK_EQ(right_dim1,
   1439            (transpose_output ? output->dimension(0) : output->dimension(1)));
   1440   if (left_dim0 < 32 || left_dim1 < 32 || right_dim1 < 32) {
   1441     // Causes problems in libxsmm
   1442     SparseMatMul<TL, TR>::Compute(
   1443         nullptr /* Assumes no cached data for fallback */, left, right,
   1444         transpose_left, thread_pool, transpose_output, output);
   1445     return;
   1446   }
   1447   auto left_data = left.data();
   1448   auto right_data = right.data();
   1449   auto output_data = output->data();
   1450   // Initialize libxsmm for this matrix; make sure another thread doesn't use
   1451   // this handle
   1452   auto entry =
   1453       cache->take_cache_entry(left_dim0, right_dim0, right_dim1, num_threads);
   1454   // Convert the left matrix to compressed sparse row (CSR) format
   1455   ptrdiff_t total_num_creation_blocks =
   1456       libxsmm_spmdm_get_num_createSparseSlice_blocks(&entry->handle);
   1457   std::atomic<int> cur_create_block_number;
   1458   cur_create_block_number.store(0);
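          // Worker threads claim creation blocks dynamically via the atomic counter.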
   1459   do_on_all_threads(thread_pool, [&](int i) {
   1460     while (true) {
   1461       int work_item = cur_create_block_number.fetch_add(1);
   1462       if (work_item >= total_num_creation_blocks) break;
   1463       wrapper_libxsmm_spmdm_createSparseSlice_generic_thread(
   1464           empty_type_wrapper<TL>{}, &entry->handle,
   1465           (transpose_left ? 'T' : 'N'), left_data, entry->output_csr, work_item,
   1466           i, num_threads);
   1467     }
   1468   });
   1469   // Do matrix-matrix multiplication
   1470   ptrdiff_t total_num_mult_blocks =
   1471       libxsmm_spmdm_get_num_compute_blocks(&entry->handle);
   1472   std::atomic<int> cur_mult_block_number;
   1473   cur_mult_block_number.store(0);
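          // Multiplication blocks are claimed the same way, via an atomic counter.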
   1474   do_on_all_threads(thread_pool, [&](int i) {
   1475     while (true) {
   1476       int work_item = cur_mult_block_number.fetch_add(1);
   1477       if (work_item >= total_num_mult_blocks) break;
   1478       const TL alpha(1.0);  // Stored in a variable so we can get a pointer
   1479       const TL beta(0.0);   // Stored in a variable so we can get a pointer
   1480       wrapper_libxsmm_spmdm_compute_generic_thread(
   1481           empty_type_wrapper<TL>{}, &entry->handle,
   1482           (transpose_left ? 'T' : 'N'), 'N', &alpha, entry->output_csr,
   1483           right_data, (transpose_output ? 'T' : 'N'), &beta, output_data,
   1484           work_item, i, num_threads);
   1485     }
   1486   });
   1487   // Put handle + CSR storage back into cache
   1488   cache->return_cache_entry(std::move(entry));
   1489 }
   1490 
   1491 #endif  // TENSORFLOW_USE_LIBXSMM
   1492 
   1493 // Here is an overview of the SparseMatMul code. Note that we assume that the
   1494 // left matrix is sparse.
   1495 //
   1496 // The matrix "left" is divided into a grid with blocksize of (M, KL). Each
   1497 // block is encoded as a SparseSlice. These grid elements are stored as
   1498 // std::vector<std::vector<SparseSlice>>. Each element of the outer vector
   1499 // represents M rows of the left matrix. Let's call these elements l_i and let's
   1500 // call each element of the inner vector L_mk.
   1501 //
   1502 // The matrix "right" is divided into a grid with block size KR * NR.  Let's
   1503 // denote the blocks on the right as R_kn. Note that we ensure that KL divides
   1504 // KR so that for each element R_kn, we don't need to multiply it with any
   1505 // partial L_mk blocks.
   1506 //
   1507 // We then multiply each right side block R_kn with the full "left" matrix and
   1508 // update the output. These iterations are run sequentially since the R_kn
   1509 // blocks are packed into the same underlying temporary buffer.
   1510 //
   1511 // In each iteration we do the following:
   1512 // 1. Create slices r_j of R_kn: We split R_kn into vertical blocks with N
   1513 //    (=128) columns and then concatenate these slices into a buffer. This is
   1514 //    done so that each slice r_j of R_kn is stored contiguously in memory. Note
   1515 //    that if R_kn has dimensions (KR, NR), we create NR / N slices, and the
   1516 //    buffer has dimensions (KR * NR / N, N) (assuming N divides NR).
   1517 // 2. For each (l_i, r_j), we compute the inner product using the GEPP function
   1518 //    and update the output block o_ij. These calls are further blocked to
   1519 //    reduce the working set size. In each iteration we take IB elements from
   1520 //    {l_i} and JB elements from {r_j} and compute the IB * JB inner products.
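        //
        // A rough comment-level sketch of the loop structure below (the names follow
        // the description above rather than the exact variables in the code):
        //
        //   for nb in blocks of NR columns of "right":
        //     for kb in blocks of KR rows of "right":
        //       shuffle R_kn into contiguous N-column slices r_j  (CreateDenseSlices)
        //       for each group of JB slices r_j and IB slices l_i:
        //         enqueue ComputeOutputBlock(l_i, r_j) -> o_ij
        //       wait for the shuffle, run the queued tasks, then free the r_j slices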
   1521 template <typename TL, typename TR>
   1522 inline void SparseMatMul<TL, TR>::Compute(
   1523     typename SparseMatMul<TL, TR>::TensorInfoCache* /*cache*/,
   1524     const typename SparseMatMul<TL, TR>::ConstMatrixMapL& left,
   1525     const typename SparseMatMul<TL, TR>::ConstMatrixMapR& right,
   1526     bool transpose_left, const DeviceBase::CpuWorkerThreads* thread_pool,
   1527     bool transpose_output, MatrixMap* output) {
   1528   const int num_threads = thread_pool->num_threads;
   1529   int KR, NR, KL, JB, IB;
   1530   ComputeBlockSizes(left, right, transpose_left, num_threads, &KR, &NR, &KL,
   1531                     &JB, &IB);
   1532   // Slice the left matrix
   1533   std::vector<std::vector<SparseSlice<TL>*>> left_slices;
   1534   std::unique_ptr<BlockingCounter> sparse_slice_counter =
   1535       CreateSparseSlices(ConstMatrixMapL(left.data(), left.dimensions()),
   1536                          transpose_left, M, K, KL, &left_slices, thread_pool);
   1537   const int num_left_slices = left_slices.size();
   1538 
   1539   const int right_dim0 = right.dimension(0);
   1540   const int right_dim1 = right.dimension(1);
   1541   // Allocate buffer for storing slices of right matrix.
   1542   // Note buffer needs enough space to hold at most a KR * NR matrix since that
   1543   // is the block size per iteration.
   1544   const int buffer_num_rows =
   1545       std::min(KR, right_dim0) * (std::min(NR, right_dim1) + N - 1) / N;
   1546   MatrixR buffer(buffer_num_rows, N);
   1547   std::vector<ConstMatrixMapR*> right_slices;
   1548 
   1549   std::vector<SparseSlice<TL>*> block_left_slices;
   1550   std::vector<std::function<void(void)>> tasks;
   1551   // Number of blocks based on block sizes of KR * NR.
   1552   const int num_k_blocks = (right_dim0 + KR - 1) / KR;
   1553   const int num_n_blocks = (right_dim1 + NR - 1) / NR;
   1554   std::unique_ptr<BlockingCounter> dense_slice_counter;
   1555 
   1556   for (int nb = 0; nb < num_n_blocks; ++nb) {
   1557     const int right_num_cols =
   1558         std::min(NR, static_cast<int>(right_dim1 - NR * nb));
   1559     for (int kb = 0; kb < num_k_blocks; ++kb) {
   1560       const int right_num_rows =
   1561           std::min(KR, static_cast<int>(right_dim0 - KR * kb));
   1562       dense_slice_counter = CreateDenseSlices(
   1563           right, kb * KR, right_num_rows, nb * NR, right_num_cols, thread_pool,
   1564           &buffer, &right_slices);
   1565       const int num_right_slices = right_slices.size();
   1566       tasks.reserve(num_left_slices * num_right_slices);
   1567       for (int j_outer = 0; j_outer < num_right_slices; j_outer += JB) {
   1568         for (int i_outer = 0; i_outer < num_left_slices; i_outer += IB) {
   1569           for (int j_inner = j_outer;
   1570                j_inner < std::min(num_right_slices, j_outer + JB); ++j_inner) {
   1571             const int num_cols = std::min(N, right_num_cols - N * j_inner);
   1572             for (int i_inner = i_outer;
   1573                  i_inner < std::min(num_left_slices, i_outer + IB); ++i_inner) {
   1574               block_left_slices.clear();
   1575               int begin = kb * KR / KL;
   1576               int end = std::min<int>((kb + 1) * KR / KL,
   1577                                       (right.dimension(0) + KL - 1) / KL);
   1578               DCHECK_LT(begin, end);
   1579               block_left_slices.insert(block_left_slices.begin(),
   1580                                        left_slices[i_inner].begin() + begin,
   1581                                        left_slices[i_inner].begin() + end);
   1582               tasks.push_back(std::bind(
   1583                   &ComputeOutputBlock, block_left_slices,
   1584                   std::ref(*right_slices[j_inner]), num_cols, M * i_inner,
   1585                   N * j_inner + nb * NR, kb == 0, transpose_output, output));
   1586             }
   1587           }
   1588         }
   1589       }
   1590       if (sparse_slice_counter) {
   1591         sparse_slice_counter->Wait();
   1592         sparse_slice_counter.reset(nullptr);
   1593       }
   1594       if (dense_slice_counter) {
   1595         dense_slice_counter->Wait();
   1596         dense_slice_counter.reset(nullptr);
   1597       }
   1598       BlockingCounter bc(tasks.size());
   1599       for (const auto& t : tasks) {
   1600         thread_pool->workers->Schedule([&bc, &t]() {
   1601           t();
   1602           bc.DecrementCount();
   1603         });
   1604       }
   1605       bc.Wait();
   1606       tasks.clear();
   1607       gtl::STLDeleteElements(&right_slices);
   1608       right_slices.clear();
   1609     }
   1610   }
   1611   for (auto& left_slice : left_slices) {
   1612     gtl::STLDeleteElements(&left_slice);
   1613   }
   1614 }
   1615 
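        // Kernel registrations.  The libxsmm-backed implementation is registered
        // only for the float/float combination (and only when TENSORFLOW_USE_LIBXSMM
        // is defined); every other Ta/Tb combination uses the portable SparseMatMul
        // implementation.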
   1616 #define REGISTER_SPARSE_MATMUL(TA, TB)                   \
   1617   REGISTER_KERNEL_BUILDER(Name("SparseMatMul")           \
   1618                               .Device(DEVICE_CPU)        \
   1619                               .TypeConstraint<TA>("Ta")  \
   1620                               .TypeConstraint<TB>("Tb"), \
   1621                           SparseMatMulOp<TA, TB, SparseMatMul>);
   1622 #ifdef TENSORFLOW_USE_LIBXSMM
   1623 #define REGISTER_SPARSE_MATMUL_LIBXSMM(TA, TB)           \
   1624   REGISTER_KERNEL_BUILDER(Name("SparseMatMul")           \
   1625                               .Device(DEVICE_CPU)        \
   1626                               .TypeConstraint<TA>("Ta")  \
   1627                               .TypeConstraint<TB>("Tb"), \
   1628                           SparseMatMulOp<TA, TB, LibxsmmSparseMatMul>);
   1629 #endif
   1630 
   1631 REGISTER_SPARSE_MATMUL(bfloat16, bfloat16);
   1632 
   1633 REGISTER_SPARSE_MATMUL(float, bfloat16);
   1634 
   1635 REGISTER_SPARSE_MATMUL(bfloat16, float);
   1636 
   1637 #ifdef TENSORFLOW_USE_LIBXSMM
   1638 REGISTER_SPARSE_MATMUL_LIBXSMM(float, float);
   1639 #else
   1640 REGISTER_SPARSE_MATMUL(float, float);
   1641 #endif
   1642 
   1643 #undef REGISTER_SPARSE_MATMUL
        #ifdef TENSORFLOW_USE_LIBXSMM
        #undef REGISTER_SPARSE_MATMUL_LIBXSMM
        #endif
   1644 
   1645 }  // end namespace tensorflow
   1646