// kernels/sparse_matmul_op.h
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_KERNELS_SPARSE_MATMUL_OP_H_
     17 #define TENSORFLOW_KERNELS_SPARSE_MATMUL_OP_H_
     18 
     19 #include "third_party/eigen3/Eigen/Core"
     20 #include "tensorflow/core/platform/types.h"
     21 
     22 #if defined(PLATFORM_WINDOWS)
     23 #include "tensorflow/core/platform/windows/cpu_info.h"
     24 #include "tensorflow/core/platform/windows/intrinsics_port.h"
     25 #endif
     26 
     27 namespace Eigen {
     28 namespace internal {
     29 
     30 // Return the float representation of the bfloat16 value
     31 // in the lower 16-bits of input
     32 template <typename Packet>
     33 EIGEN_DEVICE_FUNC inline Packet pexpand_bf16_l(const Packet& from) {
     34   tensorflow::uint32 tmp;
     35 #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
     36   tmp = (reinterpret_cast<const tensorflow::uint32&>(from)) & 0xffff0000;
     37 #else
     38   tmp = (reinterpret_cast<const tensorflow::uint32&>(from) << 16) & 0xffff0000;
     39 #endif
     40   return reinterpret_cast<const float&>(tmp);
     41 }
     42 
     43 // Return the float representation of the bfloat16 value
     44 // in the upper 16-bits of input
     45 template <typename Packet>
     46 EIGEN_DEVICE_FUNC inline Packet pexpand_bf16_u(const Packet& from) {
     47   tensorflow::uint32 tmp;
     48 #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
     49   tmp = (reinterpret_cast<const tensorflow::uint32&>(from) << 16) & 0xffff0000;
     50 #else
     51   tmp = (reinterpret_cast<const tensorflow::uint32&>(from)) & 0xffff0000;
     52 #endif
     53   return reinterpret_cast<const float&>(tmp);
     54 }
     55 
     56 // Specialization non-scalar version on non-sse.
     57 // Enable vectorization on z13 and higher
     58 #if defined(EIGEN_VECTORIZE_ALTIVEC) || defined(EIGEN_VECTORIZE_VSX) || \
     59     defined(EIGEN_VECTORIZE_NEON) || defined(EIGEN_VECTORIZE_ZVECTOR)
     60 template <typename Packet>
     61 EIGEN_DEVICE_FUNC inline Packet4f pexpand_bf16_l(const Packet4f& from) {
     62   float r[4];
     63   tensorflow::uint32 p[4];
     64   pstoreu(r, from);
     65   tensorflow::uint32* ir = reinterpret_cast<tensorflow::uint32*>(r);
     66   p[0] = (ir[0] << 16) & 0xffff0000;
     67   p[1] = ir[0] & 0xffff0000;
     68   p[2] = (ir[1] << 16) & 0xffff0000;
     69   p[3] = ir[1] & 0xffff0000;
     70   return ploadu<Packet4f>(reinterpret_cast<float*>(p));
     71 }
     72 
     73 template <typename Packet>
     74 EIGEN_DEVICE_FUNC inline Packet4f pexpand_bf16_u(const Packet4f& from) {
     75   float r[4];
     76   tensorflow::uint32 p[4];
     77   pstoreu(r, from);
     78   tensorflow::uint32* ir = reinterpret_cast<tensorflow::uint32*>(r);
     79   p[0] = (ir[2] << 16) & 0xffff0000;
     80   p[1] = ir[2] & 0xffff0000;
     81   p[2] = (ir[3] << 16) & 0xffff0000;
     82   p[3] = ir[3] & 0xffff0000;
     83   return ploadu<Packet4f>(reinterpret_cast<float*>(p));
     84 }
     85 #endif
     86 
// Default (scalar) case: nothing to interleave, return the input unchanged.
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pinterleave4x64(const Packet& from) {
  return from;
}
     91 
// Default (scalar) case: the first value of a scalar is the scalar itself.
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pbroadcast_first(const Packet& a) {
  return a;
}
     96 
// Scalar fallback: a scalar has no second element, so this must never be
// reached (with NDEBUG the assert compiles away and `a` is returned as-is).
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pbroadcast_second(const Packet& a) {
  assert(false && "Not applicable to Scalar Values");
  return a;
}
    102 
// Scalar fallback: a scalar has no third element, so this must never be
// reached (with NDEBUG the assert compiles away and `a` is returned as-is).
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pbroadcast_third(const Packet& a) {
  assert(false && "Not applicable to Scalar Values");
  return a;
}
    108 
// Scalar fallback: a scalar has no fourth element, so this must never be
// reached (with NDEBUG the assert compiles away and `a` is returned as-is).
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pbroadcast_fourth(const Packet& a) {
  assert(false && "Not applicable to Scalar Values");
  return a;
}
    114 
// Scalar fallback for loading 4 bfloat16 values as 4 floats; only the
// vectorized specializations below are usable, so this must never be reached.
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pload4bf16(
    const typename unpacket_traits<Packet>::type* from) {
  assert(false && "Not applicable to Scalar Values");
  return Packet();
}
    121 
// Scalar fallback for loading 2 bfloat16 values as floats; only the
// vectorized specializations below are usable, so this must never be reached.
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pload2bf16(
    const typename unpacket_traits<Packet>::type* from) {
  assert(false && "Not applicable to Scalar Values");
  return Packet();
}
    128 
    129 // Specialization for pload4bf16 and pload2bf16 for non-sse.
    130 // Enable vectorization on z13 and higher.
    131 #if defined(EIGEN_VECTORIZE_ALTIVEC) || defined(EIGEN_VECTORIZE_VSX) || \
    132     defined(EIGEN_VECTORIZE_NEON) || defined(EIGEN_VECTORIZE_ZVECTOR)
    133 template <>
    134 EIGEN_STRONG_INLINE Packet4f pload4bf16<Packet4f>(const float* from) {
    135   tensorflow::uint32 p[4];
    136   const tensorflow::uint32* ir =
    137       reinterpret_cast<const tensorflow::uint32*>(from);
    138   p[0] = (ir[0] << 16) & 0xffff0000;
    139   p[1] = ir[0] & 0xffff0000;
    140   p[2] = (ir[1] << 16) & 0xffff0000;
    141   p[3] = ir[1] & 0xffff0000;
    142   return ploadu<Packet4f>(reinterpret_cast<float*>(p));
    143 }
    144 
    145 template <>
    146 EIGEN_STRONG_INLINE Packet4f pload2bf16<Packet4f>(const float* from) {
    147   tensorflow::uint32 p[4];
    148   const tensorflow::uint32* ir =
    149       reinterpret_cast<const tensorflow::uint32*>(from);
    150   p[0] = (ir[0] << 16) & 0xffff0000;
    151   p[1] = ir[0] & 0xffff0000;
    152   p[2] = (ir[0] << 16) & 0xffff0000;
    153   p[3] = ir[0] & 0xffff0000;
    154   return ploadu<Packet4f>(reinterpret_cast<float*>(p));
    155 }
    156 #endif
    157 
    158 #if defined(EIGEN_VECTORIZE_ALTIVEC) || defined(EIGEN_VECTORIZE_VSX)
// Return a packet with the first value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_first<Packet4f>(const Packet4f& a) {
  return vec_splat(a, 0);  // AltiVec/VSX splat of lane 0.
}
    164 
// Return a packet with the second value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_second<Packet4f>(const Packet4f& a) {
  return vec_splat(a, 1);  // AltiVec/VSX splat of lane 1.
}
    170 
// Return a packet with the third value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_third<Packet4f>(const Packet4f& a) {
  return vec_splat(a, 2);  // AltiVec/VSX splat of lane 2.
}
    176 
// Return a packet with the fourth value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_fourth<Packet4f>(const Packet4f& a) {
  return vec_splat(a, 3);  // AltiVec/VSX splat of lane 3.
}
    182 #endif
    183 
    184 #ifdef EIGEN_VECTORIZE_SSE2
// For PacketSize of 4 floats the Packet is not modified
// (a 128-bit packet holds only two 64-bit quadwords, so there is no
// 2nd/3rd quadword pair to swap).
template <>
EIGEN_STRONG_INLINE Packet4f pinterleave4x64<Packet4f>(const Packet4f& from) {
  return from;
}
    190 
    191 // Return a Packet with 4 floats loaded from 4 bfloat16 values
    192 template <>
    193 EIGEN_STRONG_INLINE Packet4f pload4bf16<Packet4f>(const float* from) {
    194   __m128i zero = _mm_setzero_si128();
    195   __m128i tmp = _mm_castpd_si128(_mm_load_pd1((const double*)from));
    196   return _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp));
    197 }
    198 
// Return a Packet with 2 floats loaded from 2 bfloat16 values
template <>
EIGEN_STRONG_INLINE Packet4f pload2bf16<Packet4f>(const float* from) {
  __m128i zero = _mm_setzero_si128();
  // load_ps1 replicates the 32 bits holding both bfloat16 values.
  __m128i tmp = _mm_castps_si128(_mm_load_ps1(from));
  // Interleave with zero: each bfloat16 becomes the high half of a float32.
  return _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp));
}
    206 
// Return a Packet with 4 floats expanded from 4 bfloat16 values
// in the lower half of the 128-bit lane
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet4f pexpand_bf16_l(const Packet4f& from) {
  __m128i zero = _mm_setzero_si128();
  __m128i tmp = _mm_castps_si128(from);
  // Zero fills the low 16 bits of each lane; the bfloat16 lands in the high.
  return _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp));
}
    215 
// Return a Packet with 4 floats expanded from 4 bfloat16 values
// in the upper half of the 128-bit lane
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet4f pexpand_bf16_u(const Packet4f& from) {
  __m128i zero = _mm_setzero_si128();
  __m128i tmp = _mm_castps_si128(from);
  // Same as pexpand_bf16_l but taking the upper 64 bits of the register.
  return _mm_castsi128_ps(_mm_unpackhi_epi16(zero, tmp));
}
    224 
// Return a packet with the first value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_first<Packet4f>(const Packet4f& a) {
  // pfirst extracts lane 0; set1 replicates it across the packet.
  return _mm_set1_ps(pfirst<Packet4f>(a));
}
    230 
// Return a packet with the second value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_second<Packet4f>(const Packet4f& a) {
  // Shuffle lane 1 into lane 0, extract it, then replicate.
  return _mm_set1_ps(_mm_cvtss_f32(_mm_shuffle_ps(a, a, 1)));
}
    236 
// Return a packet with the third value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_third<Packet4f>(const Packet4f& a) {
  // Shuffle lane 2 into lane 0, extract it, then replicate.
  return _mm_set1_ps(_mm_cvtss_f32(_mm_shuffle_ps(a, a, 2)));
}
    242 
// Return a packet with the fourth value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet4f pbroadcast_fourth<Packet4f>(const Packet4f& a) {
  // Shuffle lane 3 into lane 0, extract it, then replicate.
  return _mm_set1_ps(_mm_cvtss_f32(_mm_shuffle_ps(a, a, 3)));
}
    248 
    249 #endif
    250 
    251 #ifdef EIGEN_VECTORIZE_AVX512
// Broadcast float lane 0 of the 512-bit packet to all 16 lanes.
template <>
EIGEN_STRONG_INLINE Packet16f
pbroadcast_first<Packet16f>(const Packet16f& a_in) {
  Packet4f a = _mm512_castps512_ps128(a_in);
  return _mm512_broadcastss_ps(a);
}
// Broadcast float lane 1 of the 512-bit packet to all 16 lanes.
template <>
EIGEN_STRONG_INLINE Packet16f
pbroadcast_second<Packet16f>(const Packet16f& a_in) {
  Packet4f a = _mm512_castps512_ps128(a_in);
  // Move lane 1 into lane 0 before the broadcast.
  return _mm512_broadcastss_ps(_mm_shuffle_ps(a, a, _MM_SHUFFLE(1, 1, 1, 1)));
}
// Broadcast float lane 2 of the 512-bit packet to all 16 lanes.
template <>
EIGEN_STRONG_INLINE Packet16f
pbroadcast_third<Packet16f>(const Packet16f& a_in) {
  Packet4f a = _mm512_castps512_ps128(a_in);
  // Move lane 2 into lane 0 before the broadcast.
  return _mm512_broadcastss_ps(_mm_shuffle_ps(a, a, _MM_SHUFFLE(2, 2, 2, 2)));
}
// Broadcast float lane 3 of the 512-bit packet to all 16 lanes.
template <>
EIGEN_STRONG_INLINE Packet16f
pbroadcast_fourth<Packet16f>(const Packet16f& a_in) {
  Packet4f a = _mm512_castps512_ps128(a_in);
  // Move lane 3 into lane 0 before the broadcast.
  return _mm512_broadcastss_ps(_mm_shuffle_ps(a, a, _MM_SHUFFLE(3, 3, 3, 3)));
}
// Broadcast double lane 0 of the 512-bit packet to all 8 lanes.
template <>
EIGEN_STRONG_INLINE Packet8d pbroadcast_first<Packet8d>(const Packet8d& a_in) {
  Packet2d a = _mm512_castpd512_pd128(a_in);
  return _mm512_broadcastsd_pd(a);
}
// Broadcast double lane 1 of the 512-bit packet to all 8 lanes.
template <>
EIGEN_STRONG_INLINE Packet8d pbroadcast_second<Packet8d>(const Packet8d& a_in) {
  // permute_pd with mask 3 moves lane 1 into lane 0.
  Packet2d a = _mm_permute_pd(_mm512_castpd512_pd128(a_in), 3);
  return _mm512_broadcastsd_pd(a);
}
// Broadcast double lane 2 of the 512-bit packet to all 8 lanes.
template <>
EIGEN_STRONG_INLINE Packet8d pbroadcast_third<Packet8d>(const Packet8d& a_in) {
  // Lane 2 is the low double of the upper 128 bits of the low 256 bits.
  Packet2d a = _mm256_extractf128_pd(_mm512_castpd512_pd256(a_in), 1);
  return _mm512_broadcastsd_pd(a);
}
// Broadcast double lane 3 of the 512-bit packet to all 8 lanes.
template <>
EIGEN_STRONG_INLINE Packet8d pbroadcast_fourth<Packet8d>(const Packet8d& a_in) {
  // Extract the upper 128 bits of the low 256, then move its high double
  // into lane 0.
  Packet2d a =
      _mm_permute_pd(_mm256_extractf128_pd(_mm512_castpd512_pd256(a_in), 1), 3);
  return _mm512_broadcastsd_pd(a);
}
// Broadcast int32 lane 0 of the 512-bit packet to all 16 lanes.
template <>
EIGEN_STRONG_INLINE Packet16i
pbroadcast_first<Packet16i>(const Packet16i& a_in) {
  Packet4i a = _mm512_castsi512_si128(a_in);
  return _mm512_broadcastd_epi32(a);
}
// Broadcast int32 lane 1 of the 512-bit packet to all 16 lanes.
template <>
EIGEN_STRONG_INLINE Packet16i
pbroadcast_second<Packet16i>(const Packet16i& a_in) {
  Packet4i a = _mm512_castsi512_si128(a_in);
  // Move lane 1 into lane 0 before the broadcast.
  return _mm512_broadcastd_epi32(_mm_shuffle_epi32(a, _MM_SHUFFLE(1, 1, 1, 1)));
}
// Broadcast int32 lane 2 of the 512-bit packet to all 16 lanes.
template <>
EIGEN_STRONG_INLINE Packet16i
pbroadcast_third<Packet16i>(const Packet16i& a_in) {
  Packet4i a = _mm512_castsi512_si128(a_in);
  // Move lane 2 into lane 0 before the broadcast.
  return _mm512_broadcastd_epi32(_mm_shuffle_epi32(a, _MM_SHUFFLE(2, 2, 2, 2)));
}
// Broadcast int32 lane 3 of the 512-bit packet to all 16 lanes.
template <>
EIGEN_STRONG_INLINE Packet16i
pbroadcast_fourth<Packet16i>(const Packet16i& a_in) {
  Packet4i a = _mm512_castsi512_si128(a_in);
  // Move lane 3 into lane 0 before the broadcast.
  return _mm512_broadcastd_epi32(_mm_shuffle_epi32(a, _MM_SHUFFLE(3, 3, 3, 3)));
}
    321 #endif
    322 
    323 #ifdef EIGEN_VECTORIZE_AVX
// For a Packet of Size 8 floats(256-bits), swap the 2nd and 3rd quadwords
template <>
EIGEN_STRONG_INLINE Packet8f pinterleave4x64<Packet8f>(const Packet8f& from) {
#ifdef EIGEN_VECTORIZE_AVX2
  // Single cross-lane 64-bit permute into quadword order 0,2,1,3.
  return _mm256_castsi256_ps(_mm256_permute4x64_epi64(_mm256_castps_si256(from),
                                                      _MM_SHUFFLE(3, 1, 2, 0)));
#else
  // AVX1 has no cross-lane 64-bit permute; swap quadwords 1 and 2 by
  // exchanging their 32-bit halves (lanes 2,3 <-> lanes 4,5).
  auto tmp1 = _mm256_extract_epi32(_mm256_castps_si256(from), 2);
  auto tmp2 = _mm256_extract_epi32(_mm256_castps_si256(from), 3);
  auto tmp3 = _mm256_extract_epi32(_mm256_castps_si256(from), 4);
  auto tmp4 = _mm256_extract_epi32(_mm256_castps_si256(from), 5);
  auto tmp5 = _mm256_insert_epi32(_mm256_castps_si256(from), tmp1, 4);
  tmp5 = _mm256_insert_epi32(tmp5, tmp2, 5);
  tmp5 = _mm256_insert_epi32(tmp5, tmp3, 2);
  tmp5 = _mm256_insert_epi32(tmp5, tmp4, 3);
  return _mm256_castsi256_ps(tmp5);
#endif
}
    342 // Return a Packet with 4 floats loaded from 4 bfloat16 values
    343 template <>
    344 EIGEN_STRONG_INLINE Packet8f pload4bf16<Packet8f>(const float* from) {
    345   __m128i zero = _mm_setzero_si128();
    346   __m128i tmp = _mm_castpd_si128(_mm_load_pd1((const double*)from));
    347   return _mm256_castps128_ps256(
    348       _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp)));
    349 }
// Return a Packet with 2 floats loaded from 2 bfloat16 values
template <>
EIGEN_STRONG_INLINE Packet8f pload2bf16<Packet8f>(const float* from) {
  __m128i zero = _mm_setzero_si128();
  // load_ps1 replicates the 32 bits holding both bfloat16 values.
  __m128i tmp = _mm_castps_si128(_mm_load_ps1(from));
  // Only the lower 128 bits of the result carry data.
  return _mm256_castps128_ps256(
      _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp)));
}
    358 
    359 #ifdef EIGEN_VECTORIZE_AVX512
    360 // Return a Packet with 4 floats loaded from 4 bfloat16 values
    361 template <>
    362 EIGEN_STRONG_INLINE Packet16f pload4bf16<Packet16f>(const float* from) {
    363   __m128i zero = _mm_setzero_si128();
    364   __m128i tmp = _mm_castpd_si128(_mm_load_pd1((const double*)from));
    365   return _mm512_castps128_ps512(
    366       _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp)));
    367 }
// Return a Packet with 2 floats loaded from 2 bfloat16 values
template <>
EIGEN_STRONG_INLINE Packet16f pload2bf16<Packet16f>(const float* from) {
  __m128i zero = _mm_setzero_si128();
  // load_ps1 replicates the 32 bits holding both bfloat16 values.
  __m128i tmp = _mm_castps_si128(_mm_load_ps1(from));
  // Only the lower 128 bits of the result carry data.
  return _mm512_castps128_ps512(
      _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp)));
}
    376 #endif
    377 
// For each 128-bit lane convert 4 bfloat to 4 float values from the lower half
// of the 128-bit lane
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet8f pexpand_bf16_l(const Packet8f& from) {
#ifdef EIGEN_VECTORIZE_AVX2
  __m256i zero = _mm256_setzero_si256();
  __m256i tmp = _mm256_castps_si256(from);
  // unpacklo operates per 128-bit lane: each bfloat16 in the lower half of
  // a lane becomes the high 16 bits of a float32.
  return _mm256_castsi256_ps(_mm256_unpacklo_epi16(zero, tmp));
#else
  // AVX1 fallback: process the two 128-bit halves separately with SSE.
  __m128i zero = _mm_setzero_si128();
  __m128i low = _mm_castps_si128(_mm256_extractf128_ps(from, 0));
  __m128i res_l = _mm_unpacklo_epi16(zero, low);
  __m128i high = _mm_castps_si128(_mm256_extractf128_ps(from, 1));
  __m128i res_h = _mm_unpacklo_epi16(zero, high);
  __m256 res = _mm256_castps128_ps256(_mm_castsi128_ps(res_l));
  res = _mm256_insertf128_ps(res, _mm_castsi128_ps(res_h), 1);
  return res;
#endif
}
    397 
// For each 128-bit lane convert 4 bfloat to 4 float values from the upper half
// of the 128-bit lane
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet8f pexpand_bf16_u(const Packet8f& from) {
#ifdef EIGEN_VECTORIZE_AVX2
  __m256i zero = _mm256_setzero_si256();
  __m256i tmp = _mm256_castps_si256(from);
  // Same as pexpand_bf16_l but taking the upper half of each 128-bit lane.
  return _mm256_castsi256_ps(_mm256_unpackhi_epi16(zero, tmp));
#else
  // AVX1 fallback: process the two 128-bit halves separately with SSE.
  __m128i zero = _mm_setzero_si128();
  __m128i low = _mm_castps_si128(_mm256_extractf128_ps(from, 0));
  __m128i res_l = _mm_unpackhi_epi16(zero, low);
  __m128i high = _mm_castps_si128(_mm256_extractf128_ps(from, 1));
  __m128i res_h = _mm_unpackhi_epi16(zero, high);
  __m256 res = _mm256_castps128_ps256(_mm_castsi128_ps(res_l));
  res = _mm256_insertf128_ps(res, _mm_castsi128_ps(res_h), 1);
  return res;
#endif
}
    417 
// Return a packet with the first value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet8f pbroadcast_first<Packet8f>(const Packet8f& a) {
  // pfirst extracts lane 0; set1 replicates it across all 8 lanes.
  return _mm256_set1_ps(pfirst<Packet8f>(a));
}
    423 
// Return a packet with the second value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet8f pbroadcast_second<Packet8f>(const Packet8f& a) {
  // In-lane permute moves lane 1 into lane 0; extract it and replicate.
  return _mm256_set1_ps(
      _mm_cvtss_f32(_mm256_castps256_ps128(_mm256_permute_ps(a, 1))));
}
    430 
// Return a packet with the third value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet8f pbroadcast_third<Packet8f>(const Packet8f& a) {
  // In-lane permute moves lane 2 into lane 0; extract it and replicate.
  return _mm256_set1_ps(
      _mm_cvtss_f32(_mm256_castps256_ps128(_mm256_permute_ps(a, 2))));
}
    437 
// Return a packet with the fourth value of the input Packet replicated
template <>
EIGEN_STRONG_INLINE Packet8f pbroadcast_fourth<Packet8f>(const Packet8f& a) {
  // In-lane permute moves lane 3 into lane 0; extract it and replicate.
  return _mm256_set1_ps(
      _mm_cvtss_f32(_mm256_castps256_ps128(_mm256_permute_ps(a, 3))));
}
    444 
    445 #endif
    446 
    447 #ifdef EIGEN_VECTORIZE_AVX512
    448 
// Expand the 16 bfloat16 values in the lower 256 bits of `from` into 16
// float32 values.
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet16f pexpand_bf16_l(const Packet16f& from) {
  // Zero-extend each 16-bit value to 32 bits, then shift the bfloat16
  // payload into the float32 high half.
  return _mm512_castsi512_ps(_mm512_slli_epi32(
      _mm512_cvtepu16_epi32(_mm512_castsi512_si256(_mm512_castps_si512(from))),
      16));
}
    455 
// Expand the 16 bfloat16 values in the upper 256 bits of `from` into 16
// float32 values.
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet16f pexpand_bf16_u(const Packet16f& from) {
  Packet16i tmp = _mm512_castps_si512(from);
  // Rotate the upper 256 bits into the lower half, then widen and shift
  // exactly as in pexpand_bf16_l.
  Packet16i tmp2 = _mm512_alignr_epi32(tmp, tmp, 8);
  return _mm512_castsi512_ps(_mm512_slli_epi32(
      _mm512_cvtepu16_epi32(_mm512_castsi512_si256(tmp2)), 16));
}
    463 
    464 #endif
    465 }  // namespace internal
    466 }  // namespace Eigen
    467 #endif
    468