1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_KERNELS_SPARSE_MATMUL_OP_H_ 17 #define TENSORFLOW_KERNELS_SPARSE_MATMUL_OP_H_ 18 19 #include "third_party/eigen3/Eigen/Core" 20 #include "tensorflow/core/platform/types.h" 21 22 #if defined(PLATFORM_WINDOWS) 23 #include "tensorflow/core/platform/windows/cpu_info.h" 24 #include "tensorflow/core/platform/windows/intrinsics_port.h" 25 #endif 26 27 namespace Eigen { 28 namespace internal { 29 30 // Return the float representation of the bfloat16 value 31 // in the lower 16-bits of input 32 template <typename Packet> 33 EIGEN_DEVICE_FUNC inline Packet pexpand_bf16_l(const Packet& from) { 34 tensorflow::uint32 tmp; 35 #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ 36 tmp = (reinterpret_cast<const tensorflow::uint32&>(from)) & 0xffff0000; 37 #else 38 tmp = (reinterpret_cast<const tensorflow::uint32&>(from) << 16) & 0xffff0000; 39 #endif 40 return reinterpret_cast<const float&>(tmp); 41 } 42 43 // Return the float representation of the bfloat16 value 44 // in the upper 16-bits of input 45 template <typename Packet> 46 EIGEN_DEVICE_FUNC inline Packet pexpand_bf16_u(const Packet& from) { 47 tensorflow::uint32 tmp; 48 #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ 49 tmp = (reinterpret_cast<const tensorflow::uint32&>(from) << 16) & 0xffff0000; 50 #else 51 tmp = (reinterpret_cast<const tensorflow::uint32&>(from)) & 0xffff0000; 52 #endif 53 return reinterpret_cast<const float&>(tmp); 54 } 55 56 // Specialization non-scalar version on non-sse. 57 // Enable vectorization on z13 and higher 58 #if defined(EIGEN_VECTORIZE_ALTIVEC) || defined(EIGEN_VECTORIZE_VSX) || \ 59 defined(EIGEN_VECTORIZE_NEON) || defined(EIGEN_VECTORIZE_ZVECTOR) 60 template <typename Packet> 61 EIGEN_DEVICE_FUNC inline Packet4f pexpand_bf16_l(const Packet4f& from) { 62 float r[4]; 63 tensorflow::uint32 p[4]; 64 pstoreu(r, from); 65 tensorflow::uint32* ir = reinterpret_cast<tensorflow::uint32*>(r); 66 p[0] = (ir[0] << 16) & 0xffff0000; 67 p[1] = ir[0] & 0xffff0000; 68 p[2] = (ir[1] << 16) & 0xffff0000; 69 p[3] = ir[1] & 0xffff0000; 70 return ploadu<Packet4f>(reinterpret_cast<float*>(p)); 71 } 72 73 template <typename Packet> 74 EIGEN_DEVICE_FUNC inline Packet4f pexpand_bf16_u(const Packet4f& from) { 75 float r[4]; 76 tensorflow::uint32 p[4]; 77 pstoreu(r, from); 78 tensorflow::uint32* ir = reinterpret_cast<tensorflow::uint32*>(r); 79 p[0] = (ir[2] << 16) & 0xffff0000; 80 p[1] = ir[2] & 0xffff0000; 81 p[2] = (ir[3] << 16) & 0xffff0000; 82 p[3] = ir[3] & 0xffff0000; 83 return ploadu<Packet4f>(reinterpret_cast<float*>(p)); 84 } 85 #endif 86 87 template <typename Packet> 88 EIGEN_DEVICE_FUNC inline Packet pinterleave4x64(const Packet& from) { 89 return from; 90 } 91 92 template <typename Packet> 93 EIGEN_DEVICE_FUNC inline Packet pbroadcast_first(const Packet& a) { 94 return a; 95 } 96 97 template <typename Packet> 98 EIGEN_DEVICE_FUNC inline Packet pbroadcast_second(const Packet& a) { 99 assert(false && "Not applicable to Scalar Values"); 100 return a; 101 } 102 103 template <typename Packet> 104 EIGEN_DEVICE_FUNC inline Packet pbroadcast_third(const Packet& a) { 105 assert(false && "Not applicable to Scalar Values"); 106 return a; 107 } 108 109 template <typename Packet> 110 EIGEN_DEVICE_FUNC inline Packet pbroadcast_fourth(const Packet& a) { 111 assert(false && "Not applicable to Scalar Values"); 112 return a; 113 } 114 115 template <typename Packet> 116 EIGEN_DEVICE_FUNC inline Packet pload4bf16( 117 const typename unpacket_traits<Packet>::type* from) { 118 assert(false && "Not applicable to Scalar Values"); 119 return Packet(); 120 } 121 122 template <typename Packet> 123 EIGEN_DEVICE_FUNC inline Packet pload2bf16( 124 const typename unpacket_traits<Packet>::type* from) { 125 assert(false && "Not applicable to Scalar Values"); 126 return Packet(); 127 } 128 129 // Specialization for pload4bf16 and pload2bf16 for non-sse. 130 // Enable vectorization on z13 and higher. 131 #if defined(EIGEN_VECTORIZE_ALTIVEC) || defined(EIGEN_VECTORIZE_VSX) || \ 132 defined(EIGEN_VECTORIZE_NEON) || defined(EIGEN_VECTORIZE_ZVECTOR) 133 template <> 134 EIGEN_STRONG_INLINE Packet4f pload4bf16<Packet4f>(const float* from) { 135 tensorflow::uint32 p[4]; 136 const tensorflow::uint32* ir = 137 reinterpret_cast<const tensorflow::uint32*>(from); 138 p[0] = (ir[0] << 16) & 0xffff0000; 139 p[1] = ir[0] & 0xffff0000; 140 p[2] = (ir[1] << 16) & 0xffff0000; 141 p[3] = ir[1] & 0xffff0000; 142 return ploadu<Packet4f>(reinterpret_cast<float*>(p)); 143 } 144 145 template <> 146 EIGEN_STRONG_INLINE Packet4f pload2bf16<Packet4f>(const float* from) { 147 tensorflow::uint32 p[4]; 148 const tensorflow::uint32* ir = 149 reinterpret_cast<const tensorflow::uint32*>(from); 150 p[0] = (ir[0] << 16) & 0xffff0000; 151 p[1] = ir[0] & 0xffff0000; 152 p[2] = (ir[0] << 16) & 0xffff0000; 153 p[3] = ir[0] & 0xffff0000; 154 return ploadu<Packet4f>(reinterpret_cast<float*>(p)); 155 } 156 #endif 157 158 #if defined(EIGEN_VECTORIZE_ALTIVEC) || defined(EIGEN_VECTORIZE_VSX) 159 // Return a packet with the first value of the input Packet replicated 160 template <> 161 EIGEN_STRONG_INLINE Packet4f pbroadcast_first<Packet4f>(const Packet4f& a) { 162 return vec_splat(a, 0); 163 } 164 165 // Return a packet with the second value of the input Packet replicated 166 template <> 167 EIGEN_STRONG_INLINE Packet4f pbroadcast_second<Packet4f>(const Packet4f& a) { 168 return vec_splat(a, 1); 169 } 170 171 // Return a packet with the third value of the input Packet replicated 172 template <> 173 EIGEN_STRONG_INLINE Packet4f pbroadcast_third<Packet4f>(const Packet4f& a) { 174 return vec_splat(a, 2); 175 } 176 177 // Return a packet with the fourth value of the input Packet replicated 178 template <> 179 EIGEN_STRONG_INLINE Packet4f pbroadcast_fourth<Packet4f>(const Packet4f& a) { 180 return vec_splat(a, 3); 181 } 182 #endif 183 184 #ifdef EIGEN_VECTORIZE_SSE2 185 // For PacketSize of 4 floats the Packet is not modified 186 template <> 187 EIGEN_STRONG_INLINE Packet4f pinterleave4x64<Packet4f>(const Packet4f& from) { 188 return from; 189 } 190 191 // Return a Packet with 4 floats loaded from 4 bfloat16 values 192 template <> 193 EIGEN_STRONG_INLINE Packet4f pload4bf16<Packet4f>(const float* from) { 194 __m128i zero = _mm_setzero_si128(); 195 __m128i tmp = _mm_castpd_si128(_mm_load_pd1((const double*)from)); 196 return _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp)); 197 } 198 199 // Return a Packet with 2 floats loaded from 2 bfloat16 values 200 template <> 201 EIGEN_STRONG_INLINE Packet4f pload2bf16<Packet4f>(const float* from) { 202 __m128i zero = _mm_setzero_si128(); 203 __m128i tmp = _mm_castps_si128(_mm_load_ps1(from)); 204 return _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp)); 205 } 206 207 // Return a Packet with 4 floats expanded from 4 bfloat16 values 208 // in the lower half of the 128-bit lane 209 template <typename Packet> 210 EIGEN_DEVICE_FUNC inline Packet4f pexpand_bf16_l(const Packet4f& from) { 211 __m128i zero = _mm_setzero_si128(); 212 __m128i tmp = _mm_castps_si128(from); 213 return _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp)); 214 } 215 216 // Return a Packet with 4 floats expanded from 4 bfloat16 values 217 // in the upper half of the 128-bit lane 218 template <typename Packet> 219 EIGEN_DEVICE_FUNC inline Packet4f pexpand_bf16_u(const Packet4f& from) { 220 __m128i zero = _mm_setzero_si128(); 221 __m128i tmp = _mm_castps_si128(from); 222 return _mm_castsi128_ps(_mm_unpackhi_epi16(zero, tmp)); 223 } 224 225 // Return a packet with the first value of the input Packet replicated 226 template <> 227 EIGEN_STRONG_INLINE Packet4f pbroadcast_first<Packet4f>(const Packet4f& a) { 228 return _mm_set1_ps(pfirst<Packet4f>(a)); 229 } 230 231 // Return a packet with the second value of the input Packet replicated 232 template <> 233 EIGEN_STRONG_INLINE Packet4f pbroadcast_second<Packet4f>(const Packet4f& a) { 234 return _mm_set1_ps(_mm_cvtss_f32(_mm_shuffle_ps(a, a, 1))); 235 } 236 237 // Return a packet with the third value of the input Packet replicated 238 template <> 239 EIGEN_STRONG_INLINE Packet4f pbroadcast_third<Packet4f>(const Packet4f& a) { 240 return _mm_set1_ps(_mm_cvtss_f32(_mm_shuffle_ps(a, a, 2))); 241 } 242 243 // Return a packet with the fourth value of the input Packet replicated 244 template <> 245 EIGEN_STRONG_INLINE Packet4f pbroadcast_fourth<Packet4f>(const Packet4f& a) { 246 return _mm_set1_ps(_mm_cvtss_f32(_mm_shuffle_ps(a, a, 3))); 247 } 248 249 #endif 250 251 #ifdef EIGEN_VECTORIZE_AVX512 252 template <> 253 EIGEN_STRONG_INLINE Packet16f 254 pbroadcast_first<Packet16f>(const Packet16f& a_in) { 255 Packet4f a = _mm512_castps512_ps128(a_in); 256 return _mm512_broadcastss_ps(a); 257 } 258 template <> 259 EIGEN_STRONG_INLINE Packet16f 260 pbroadcast_second<Packet16f>(const Packet16f& a_in) { 261 Packet4f a = _mm512_castps512_ps128(a_in); 262 return _mm512_broadcastss_ps(_mm_shuffle_ps(a, a, _MM_SHUFFLE(1, 1, 1, 1))); 263 } 264 template <> 265 EIGEN_STRONG_INLINE Packet16f 266 pbroadcast_third<Packet16f>(const Packet16f& a_in) { 267 Packet4f a = _mm512_castps512_ps128(a_in); 268 return _mm512_broadcastss_ps(_mm_shuffle_ps(a, a, _MM_SHUFFLE(2, 2, 2, 2))); 269 } 270 template <> 271 EIGEN_STRONG_INLINE Packet16f 272 pbroadcast_fourth<Packet16f>(const Packet16f& a_in) { 273 Packet4f a = _mm512_castps512_ps128(a_in); 274 return _mm512_broadcastss_ps(_mm_shuffle_ps(a, a, _MM_SHUFFLE(3, 3, 3, 3))); 275 } 276 template <> 277 EIGEN_STRONG_INLINE Packet8d pbroadcast_first<Packet8d>(const Packet8d& a_in) { 278 Packet2d a = _mm512_castpd512_pd128(a_in); 279 return _mm512_broadcastsd_pd(a); 280 } 281 template <> 282 EIGEN_STRONG_INLINE Packet8d pbroadcast_second<Packet8d>(const Packet8d& a_in) { 283 Packet2d a = _mm_permute_pd(_mm512_castpd512_pd128(a_in), 3); 284 return _mm512_broadcastsd_pd(a); 285 } 286 template <> 287 EIGEN_STRONG_INLINE Packet8d pbroadcast_third<Packet8d>(const Packet8d& a_in) { 288 Packet2d a = _mm256_extractf128_pd(_mm512_castpd512_pd256(a_in), 1); 289 return _mm512_broadcastsd_pd(a); 290 } 291 template <> 292 EIGEN_STRONG_INLINE Packet8d pbroadcast_fourth<Packet8d>(const Packet8d& a_in) { 293 Packet2d a = 294 _mm_permute_pd(_mm256_extractf128_pd(_mm512_castpd512_pd256(a_in), 1), 3); 295 return _mm512_broadcastsd_pd(a); 296 } 297 template <> 298 EIGEN_STRONG_INLINE Packet16i 299 pbroadcast_first<Packet16i>(const Packet16i& a_in) { 300 Packet4i a = _mm512_castsi512_si128(a_in); 301 return _mm512_broadcastd_epi32(a); 302 } 303 template <> 304 EIGEN_STRONG_INLINE Packet16i 305 pbroadcast_second<Packet16i>(const Packet16i& a_in) { 306 Packet4i a = _mm512_castsi512_si128(a_in); 307 return _mm512_broadcastd_epi32(_mm_shuffle_epi32(a, _MM_SHUFFLE(1, 1, 1, 1))); 308 } 309 template <> 310 EIGEN_STRONG_INLINE Packet16i 311 pbroadcast_third<Packet16i>(const Packet16i& a_in) { 312 Packet4i a = _mm512_castsi512_si128(a_in); 313 return _mm512_broadcastd_epi32(_mm_shuffle_epi32(a, _MM_SHUFFLE(2, 2, 2, 2))); 314 } 315 template <> 316 EIGEN_STRONG_INLINE Packet16i 317 pbroadcast_fourth<Packet16i>(const Packet16i& a_in) { 318 Packet4i a = _mm512_castsi512_si128(a_in); 319 return _mm512_broadcastd_epi32(_mm_shuffle_epi32(a, _MM_SHUFFLE(3, 3, 3, 3))); 320 } 321 #endif 322 323 #ifdef EIGEN_VECTORIZE_AVX 324 // For a Packet of Size 8 floats(256-bits), swap the 2nd and 3rd quadwords 325 template <> 326 EIGEN_STRONG_INLINE Packet8f pinterleave4x64<Packet8f>(const Packet8f& from) { 327 #ifdef EIGEN_VECTORIZE_AVX2 328 return _mm256_castsi256_ps(_mm256_permute4x64_epi64(_mm256_castps_si256(from), 329 _MM_SHUFFLE(3, 1, 2, 0))); 330 #else 331 auto tmp1 = _mm256_extract_epi32(_mm256_castps_si256(from), 2); 332 auto tmp2 = _mm256_extract_epi32(_mm256_castps_si256(from), 3); 333 auto tmp3 = _mm256_extract_epi32(_mm256_castps_si256(from), 4); 334 auto tmp4 = _mm256_extract_epi32(_mm256_castps_si256(from), 5); 335 auto tmp5 = _mm256_insert_epi32(_mm256_castps_si256(from), tmp1, 4); 336 tmp5 = _mm256_insert_epi32(tmp5, tmp2, 5); 337 tmp5 = _mm256_insert_epi32(tmp5, tmp3, 2); 338 tmp5 = _mm256_insert_epi32(tmp5, tmp4, 3); 339 return _mm256_castsi256_ps(tmp5); 340 #endif 341 } 342 // Return a Packet with 4 floats loaded from 4 bfloat16 values 343 template <> 344 EIGEN_STRONG_INLINE Packet8f pload4bf16<Packet8f>(const float* from) { 345 __m128i zero = _mm_setzero_si128(); 346 __m128i tmp = _mm_castpd_si128(_mm_load_pd1((const double*)from)); 347 return _mm256_castps128_ps256( 348 _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp))); 349 } 350 // Return a Packet with 2 floats loaded from 2 bfloat16 values 351 template <> 352 EIGEN_STRONG_INLINE Packet8f pload2bf16<Packet8f>(const float* from) { 353 __m128i zero = _mm_setzero_si128(); 354 __m128i tmp = _mm_castps_si128(_mm_load_ps1(from)); 355 return _mm256_castps128_ps256( 356 _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp))); 357 } 358 359 #ifdef EIGEN_VECTORIZE_AVX512 360 // Return a Packet with 4 floats loaded from 4 bfloat16 values 361 template <> 362 EIGEN_STRONG_INLINE Packet16f pload4bf16<Packet16f>(const float* from) { 363 __m128i zero = _mm_setzero_si128(); 364 __m128i tmp = _mm_castpd_si128(_mm_load_pd1((const double*)from)); 365 return _mm512_castps128_ps512( 366 _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp))); 367 } 368 // Return a Packet with 2 floats loaded from 2 bfloat16 values 369 template <> 370 EIGEN_STRONG_INLINE Packet16f pload2bf16<Packet16f>(const float* from) { 371 __m128i zero = _mm_setzero_si128(); 372 __m128i tmp = _mm_castps_si128(_mm_load_ps1(from)); 373 return _mm512_castps128_ps512( 374 _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp))); 375 } 376 #endif 377 378 // For each 128-bit lane convert 4 bfloat to 4 float values from the lower half 379 // of the 128-bit lane 380 template <typename Packet> 381 EIGEN_DEVICE_FUNC inline Packet8f pexpand_bf16_l(const Packet8f& from) { 382 #ifdef EIGEN_VECTORIZE_AVX2 383 __m256i zero = _mm256_setzero_si256(); 384 __m256i tmp = _mm256_castps_si256(from); 385 return _mm256_castsi256_ps(_mm256_unpacklo_epi16(zero, tmp)); 386 #else 387 __m128i zero = _mm_setzero_si128(); 388 __m128i low = _mm_castps_si128(_mm256_extractf128_ps(from, 0)); 389 __m128i res_l = _mm_unpacklo_epi16(zero, low); 390 __m128i high = _mm_castps_si128(_mm256_extractf128_ps(from, 1)); 391 __m128i res_h = _mm_unpacklo_epi16(zero, high); 392 __m256 res = _mm256_castps128_ps256(_mm_castsi128_ps(res_l)); 393 res = _mm256_insertf128_ps(res, _mm_castsi128_ps(res_h), 1); 394 return res; 395 #endif 396 } 397 398 // For each 128-bit lane convert 4 bfloat to 4 float values from the upper half 399 // of the 128-bit lane 400 template <typename Packet> 401 EIGEN_DEVICE_FUNC inline Packet8f pexpand_bf16_u(const Packet8f& from) { 402 #ifdef EIGEN_VECTORIZE_AVX2 403 __m256i zero = _mm256_setzero_si256(); 404 __m256i tmp = _mm256_castps_si256(from); 405 return _mm256_castsi256_ps(_mm256_unpackhi_epi16(zero, tmp)); 406 #else 407 __m128i zero = _mm_setzero_si128(); 408 __m128i low = _mm_castps_si128(_mm256_extractf128_ps(from, 0)); 409 __m128i res_l = _mm_unpackhi_epi16(zero, low); 410 __m128i high = _mm_castps_si128(_mm256_extractf128_ps(from, 1)); 411 __m128i res_h = _mm_unpackhi_epi16(zero, high); 412 __m256 res = _mm256_castps128_ps256(_mm_castsi128_ps(res_l)); 413 res = _mm256_insertf128_ps(res, _mm_castsi128_ps(res_h), 1); 414 return res; 415 #endif 416 } 417 418 // Return a packet with the first value of the input Packet replicated 419 template <> 420 EIGEN_STRONG_INLINE Packet8f pbroadcast_first<Packet8f>(const Packet8f& a) { 421 return _mm256_set1_ps(pfirst<Packet8f>(a)); 422 } 423 424 // Return a packet with the second value of the input Packet replicated 425 template <> 426 EIGEN_STRONG_INLINE Packet8f pbroadcast_second<Packet8f>(const Packet8f& a) { 427 return _mm256_set1_ps( 428 _mm_cvtss_f32(_mm256_castps256_ps128(_mm256_permute_ps(a, 1)))); 429 } 430 431 // Return a packet with the third value of the input Packet replicated 432 template <> 433 EIGEN_STRONG_INLINE Packet8f pbroadcast_third<Packet8f>(const Packet8f& a) { 434 return _mm256_set1_ps( 435 _mm_cvtss_f32(_mm256_castps256_ps128(_mm256_permute_ps(a, 2)))); 436 } 437 438 // Return a packet with the fourth value of the input Packet replicated 439 template <> 440 EIGEN_STRONG_INLINE Packet8f pbroadcast_fourth<Packet8f>(const Packet8f& a) { 441 return _mm256_set1_ps( 442 _mm_cvtss_f32(_mm256_castps256_ps128(_mm256_permute_ps(a, 3)))); 443 } 444 445 #endif 446 447 #ifdef EIGEN_VECTORIZE_AVX512 448 449 template <typename Packet> 450 EIGEN_DEVICE_FUNC inline Packet16f pexpand_bf16_l(const Packet16f& from) { 451 return _mm512_castsi512_ps(_mm512_slli_epi32( 452 _mm512_cvtepu16_epi32(_mm512_castsi512_si256(_mm512_castps_si512(from))), 453 16)); 454 } 455 456 template <typename Packet> 457 EIGEN_DEVICE_FUNC inline Packet16f pexpand_bf16_u(const Packet16f& from) { 458 Packet16i tmp = _mm512_castps_si512(from); 459 Packet16i tmp2 = _mm512_alignr_epi32(tmp, tmp, 8); 460 return _mm512_castsi512_ps(_mm512_slli_epi32( 461 _mm512_cvtepu16_epi32(_mm512_castsi512_si256(tmp2)), 16)); 462 } 463 464 #endif 465 } // namespace internal 466 } // namespace Eigen 467 #endif 468