Home | History | Annotate | Download | only in common
      1 /****************************************************************************
      2 * Copyright (C) 2017 Intel Corporation.   All Rights Reserved.
      3 *
      4 * Permission is hereby granted, free of charge, to any person obtaining a
      5 * copy of this software and associated documentation files (the "Software"),
      6 * to deal in the Software without restriction, including without limitation
      7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
      8 * and/or sell copies of the Software, and to permit persons to whom the
      9 * Software is furnished to do so, subject to the following conditions:
     10 *
     11 * The above copyright notice and this permission notice (including the next
     12 * paragraph) shall be included in all copies or substantial portions of the
     13 * Software.
     14 *
     15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
     16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
     17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
     18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
     19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
     20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
     21 * IN THE SOFTWARE.
     22 ****************************************************************************/
     23 #pragma once
     24 
     25 #if !defined(__cplusplus)
     26 #error C++ compilation required
     27 #endif
     28 
     29 #include <immintrin.h>
     30 #include <inttypes.h>
     31 #include <stdint.h>
     32 
     33 #define SIMD_ARCH_AVX       0
     34 #define SIMD_ARCH_AVX2      1
     35 #define SIMD_ARCH_AVX512    2
     36 
     37 #if !defined(SIMD_ARCH)
     38 #define SIMD_ARCH SIMD_ARCH_AVX
     39 #endif
     40 
     41 #if defined(_MSC_VER)
     42 #define SIMDCALL __vectorcall
     43 #define SIMDINLINE __forceinline
     44 #define SIMDALIGN(type_, align_) __declspec(align(align_)) type_
     45 #else
     46 #define SIMDCALL
     47 #define SIMDINLINE inline
     48 #define SIMDALIGN(type_, align_) type_ __attribute__((aligned(align_)))
     49 #endif
     50 
     51 // For documentation, please see the following include...
     52 // #include "simdlib_interface.hpp"
     53 
     54 namespace SIMDImpl
     55 {
     56     enum class CompareType
     57     {
     58         EQ_OQ      = 0x00, // Equal (ordered, nonsignaling)
     59         LT_OS      = 0x01, // Less-than (ordered, signaling)
     60         LE_OS      = 0x02, // Less-than-or-equal (ordered, signaling)
     61         UNORD_Q    = 0x03, // Unordered (nonsignaling)
     62         NEQ_UQ     = 0x04, // Not-equal (unordered, nonsignaling)
     63         NLT_US     = 0x05, // Not-less-than (unordered, signaling)
     64         NLE_US     = 0x06, // Not-less-than-or-equal (unordered, signaling)
     65         ORD_Q      = 0x07, // Ordered (nonsignaling)
     66         EQ_UQ      = 0x08, // Equal (unordered, non-signaling)
     67         NGE_US     = 0x09, // Not-greater-than-or-equal (unordered, signaling)
     68         NGT_US     = 0x0A, // Not-greater-than (unordered, signaling)
     69         FALSE_OQ   = 0x0B, // False (ordered, nonsignaling)
     70         NEQ_OQ     = 0x0C, // Not-equal (ordered, non-signaling)
     71         GE_OS      = 0x0D, // Greater-than-or-equal (ordered, signaling)
     72         GT_OS      = 0x0E, // Greater-than (ordered, signaling)
     73         TRUE_UQ    = 0x0F, // True (unordered, non-signaling)
     74         EQ_OS      = 0x10, // Equal (ordered, signaling)
     75         LT_OQ      = 0x11, // Less-than (ordered, nonsignaling)
     76         LE_OQ      = 0x12, // Less-than-or-equal (ordered, nonsignaling)
     77         UNORD_S    = 0x13, // Unordered (signaling)
     78         NEQ_US     = 0x14, // Not-equal (unordered, signaling)
     79         NLT_UQ     = 0x15, // Not-less-than (unordered, nonsignaling)
     80         NLE_UQ     = 0x16, // Not-less-than-or-equal (unordered, nonsignaling)
     81         ORD_S      = 0x17, // Ordered (signaling)
     82         EQ_US      = 0x18, // Equal (unordered, signaling)
     83         NGE_UQ     = 0x19, // Not-greater-than-or-equal (unordered, nonsignaling)
     84         NGT_UQ     = 0x1A, // Not-greater-than (unordered, nonsignaling)
     85         FALSE_OS   = 0x1B, // False (ordered, signaling)
     86         NEQ_OS     = 0x1C, // Not-equal (ordered, signaling)
     87         GE_OQ      = 0x1D, // Greater-than-or-equal (ordered, nonsignaling)
     88         GT_OQ      = 0x1E, // Greater-than (ordered, nonsignaling)
     89         TRUE_US    = 0x1F, // True (unordered, signaling)
     90     };
     91 
     92 #if SIMD_ARCH >= SIMD_ARCH_AVX512
     93     enum class CompareTypeInt
     94     {
     95         EQ  = _MM_CMPINT_EQ,    // Equal
     96         LT  = _MM_CMPINT_LT,    // Less than
     97         LE  = _MM_CMPINT_LE,    // Less than or Equal
     98         NE  = _MM_CMPINT_NE,    // Not Equal
     99         GE  = _MM_CMPINT_GE,    // Greater than or Equal
    100         GT  = _MM_CMPINT_GT,    // Greater than
    101     };
    102 #endif // SIMD_ARCH >= SIMD_ARCH_AVX512
    103 
    // Multiplier applied to an offset. Each enumerator's numeric value is
    // the multiplier itself.
    enum class ScaleFactor
    {
        // Offset used as-is (no scaling).
        SF_1 = 1,
        // Offset scaled by 2.
        SF_2 = 2,
        // Offset scaled by 4.
        SF_4 = 4,
        // Offset scaled by 8.
        SF_8 = 8,
    };
    111 
    // Rounding control. A complete mode is one direction value (low bits)
    // OR-ed with one exception-control value (bit 3); the named
    // combinations at the bottom are the ones callers normally use.
    enum class RoundMode
    {
        // --- rounding direction ---
        // NOTE(review): this entry used to be documented as
        // "TRUNCATE(value + 0.5)". If these values are forwarded as the
        // rounding immediate of the SSE4.1/AVX round instructions, mode 0
        // is round-to-nearest with ties to even, which differs for halfway
        // and negative inputs - confirm against the callers.
        TO_NEAREST_INT = 0x00, // nearest integer
        TO_NEG_INF     = 0x01, // toward negative infinity
        TO_POS_INF     = 0x02, // toward positive infinity
        TO_ZERO        = 0x03, // toward zero, i.e. truncate
        CUR_DIRECTION  = 0x04, // direction currently set in the MXCSR register

        // --- exception control ---
        RAISE_EXC = 0x00, // exceptions are raised (not suppressed)
        NO_EXC    = 0x08, // exceptions are suppressed

        // --- direction | exception-control combinations ---
        NINT        = static_cast<int>(TO_NEAREST_INT) | static_cast<int>(RAISE_EXC),
        NINT_NOEXC  = static_cast<int>(TO_NEAREST_INT) | static_cast<int>(NO_EXC),

        FLOOR       = static_cast<int>(TO_NEG_INF) | static_cast<int>(RAISE_EXC),
        FLOOR_NOEXC = static_cast<int>(TO_NEG_INF) | static_cast<int>(NO_EXC),

        CEIL        = static_cast<int>(TO_POS_INF) | static_cast<int>(RAISE_EXC),
        CEIL_NOEXC  = static_cast<int>(TO_POS_INF) | static_cast<int>(NO_EXC),

        TRUNC       = static_cast<int>(TO_ZERO) | static_cast<int>(RAISE_EXC),
        TRUNC_NOEXC = static_cast<int>(TO_ZERO) | static_cast<int>(NO_EXC),

        RINT        = static_cast<int>(CUR_DIRECTION) | static_cast<int>(RAISE_EXC),
        NEARBYINT   = static_cast<int>(CUR_DIRECTION) | static_cast<int>(NO_EXC),
    };
    134 
    135     struct Traits
    136     {
    137         using CompareType = SIMDImpl::CompareType;
    138         using ScaleFactor = SIMDImpl::ScaleFactor;
    139         using RoundMode   = SIMDImpl::RoundMode;
    140     };
    141 
    142     // Attribute, 4-dimensional attribute in SIMD SOA layout
    143     template<typename Float, typename Integer, typename Double>
    144     union Vec4
    145     {
    146         Float   v[4];
    147         Integer vi[4];
    148         Double  vd[4];
    149         struct
    150         {
    151             Float  x;
    152             Float  y;
    153             Float  z;
    154             Float  w;
    155         };
    156         SIMDINLINE Float& SIMDCALL operator[] (const int i) { return v[i]; }
    157         SIMDINLINE Float const & SIMDCALL operator[] (const int i) const { return v[i]; }
    158         SIMDINLINE Vec4& SIMDCALL operator=(Vec4 const & in)
    159         {
    160             v[0] = in.v[0];
    161             v[1] = in.v[1];
    162             v[2] = in.v[2];
    163             v[3] = in.v[3];
    164             return *this;
    165         }
    166     };
    167 
    168     namespace SIMD128Impl
    169     {
    170         union Float
    171         {
    172             SIMDINLINE Float() = default;
    173             SIMDINLINE Float(__m128 in) : v(in) {}
    174             SIMDINLINE Float& SIMDCALL operator=(__m128 in) { v = in; return *this; }
    175             SIMDINLINE Float& SIMDCALL operator=(Float const & in) { v = in.v; return *this; }
    176             SIMDINLINE SIMDCALL operator __m128() const { return v; }
    177 
    178             SIMDALIGN(__m128, 16) v;
    179         };
    180 
    181         union Integer
    182         {
    183             SIMDINLINE Integer() = default;
    184             SIMDINLINE Integer(__m128i in) : v(in) {}
    185             SIMDINLINE Integer& SIMDCALL operator=(__m128i in) { v = in; return *this; }
    186             SIMDINLINE Integer& SIMDCALL operator=(Integer const & in) { v = in.v; return *this; }
    187             SIMDINLINE SIMDCALL operator __m128i() const { return v; }
    188 
    189             SIMDALIGN(__m128i, 16) v;
    190         };
    191 
    192         union Double
    193         {
    194             SIMDINLINE Double() = default;
    195             SIMDINLINE Double(__m128d in) : v(in) {}
    196             SIMDINLINE Double& SIMDCALL operator=(__m128d in) { v = in; return *this; }
    197             SIMDINLINE Double& SIMDCALL operator=(Double const & in) { v = in.v; return *this; }
    198             SIMDINLINE SIMDCALL operator __m128d() const { return v; }
    199 
    200             SIMDALIGN(__m128d, 16) v;
    201         };
    202 
    203         using Vec4 = SIMDImpl::Vec4<Float, Integer, Double>;
    204         using Mask = uint8_t;
    205 
    206         static const uint32_t SIMD_WIDTH = 4;
    207     } // ns SIMD128Impl
    208 
    209     namespace SIMD256Impl
    210     {
    211         union Float
    212         {
    213             SIMDINLINE Float() = default;
    214             SIMDINLINE Float(__m256 in) : v(in) {}
    215             SIMDINLINE Float(SIMD128Impl::Float const &in_lo, SIMD128Impl::Float const &in_hi = _mm_setzero_ps())
    216             {
    217                 v = _mm256_insertf128_ps(_mm256_castps128_ps256(in_lo), in_hi, 0x1);
    218             }
    219             SIMDINLINE Float& SIMDCALL operator=(__m256 in) { v = in; return *this; }
    220             SIMDINLINE Float& SIMDCALL operator=(Float const & in) { v = in.v; return *this; }
    221             SIMDINLINE SIMDCALL operator __m256() const { return v; }
    222 
    223             SIMDALIGN(__m256, 32) v;
    224             SIMD128Impl::Float v4[2];
    225         };
    226 
    227         union Integer
    228         {
    229             SIMDINLINE Integer() = default;
    230             SIMDINLINE Integer(__m256i in) : v(in) {}
    231             SIMDINLINE Integer(SIMD128Impl::Integer const &in_lo, SIMD128Impl::Integer const &in_hi = _mm_setzero_si128())
    232             {
    233                 v = _mm256_insertf128_si256(_mm256_castsi128_si256(in_lo), in_hi, 0x1);
    234             }
    235             SIMDINLINE Integer& SIMDCALL operator=(__m256i in) { v = in; return *this; }
    236             SIMDINLINE Integer& SIMDCALL operator=(Integer const & in) { v = in.v; return *this; }
    237             SIMDINLINE SIMDCALL operator __m256i() const { return v; }
    238 
    239             SIMDALIGN(__m256i, 32) v;
    240             SIMD128Impl::Integer v4[2];
    241         };
    242 
    243         union Double
    244         {
    245             SIMDINLINE Double() = default;
    246             SIMDINLINE Double(__m256d const &in) : v(in) {}
    247             SIMDINLINE Double(SIMD128Impl::Double const &in_lo, SIMD128Impl::Double const &in_hi = _mm_setzero_pd())
    248             {
    249                 v = _mm256_insertf128_pd(_mm256_castpd128_pd256(in_lo), in_hi, 0x1);
    250             }
    251             SIMDINLINE Double& SIMDCALL operator=(__m256d in) { v = in; return *this; }
    252             SIMDINLINE Double& SIMDCALL operator=(Double const & in) { v = in.v; return *this; }
    253             SIMDINLINE SIMDCALL operator __m256d() const { return v; }
    254 
    255             SIMDALIGN(__m256d, 32) v;
    256             SIMD128Impl::Double v4[2];
    257         };
    258 
    259         using Vec4 = SIMDImpl::Vec4<Float, Integer, Double>;
    260         using Mask = uint8_t;
    261 
    262         static const uint32_t SIMD_WIDTH = 8;
    263     } // ns SIMD256Impl
    264 
    265     namespace SIMD512Impl
    266     {
    267 #if !(defined(__AVX512F__) || defined(_MM_K0_REG))
    268         // Define AVX512 types if not included via immintrin.h.
    269         // All data members of these types are ONLY to viewed
    270         // in a debugger.  Do NOT access them via code!
    271         union __m512
    272         {
    273         private:
    274             float m512_f32[16];
    275         };
    276         struct __m512d
    277         {
    278         private:
    279             double m512d_f64[8];
    280         };
    281 
    282         union __m512i
    283         {
    284         private:
    285             int8_t              m512i_i8[64];
    286             int16_t             m512i_i16[32];
    287             int32_t             m512i_i32[16];
    288             int64_t             m512i_i64[8];
    289             uint8_t             m512i_u8[64];
    290             uint16_t            m512i_u16[32];
    291             uint32_t            m512i_u32[16];
    292             uint64_t            m512i_u64[8];
    293         };
    294 
    295         using __mmask16 = uint16_t;
    296 #endif
    297 
    298 #if defined(__INTEL_COMPILER) || (SIMD_ARCH >= SIMD_ARCH_AVX512)
    299 #define SIMD_ALIGNMENT_BYTES 64
    300 #else
    301 #define SIMD_ALIGNMENT_BYTES 32
    302 #endif
    303 
    304         union Float
    305         {
    306             SIMDINLINE Float() = default;
    307             SIMDINLINE Float(__m512 in) : v(in) {}
    308             SIMDINLINE Float(SIMD256Impl::Float const &in_lo, SIMD256Impl::Float const &in_hi = _mm256_setzero_ps()) { v8[0] = in_lo; v8[1] = in_hi; }
    309             SIMDINLINE Float& SIMDCALL operator=(__m512 in) { v = in; return *this; }
    310             SIMDINLINE Float& SIMDCALL operator=(Float const & in)
    311             {
    312 #if SIMD_ARCH >= SIMD_ARCH_AVX512
    313                 v = in.v;
    314 #else
    315                 v8[0] = in.v8[0];
    316                 v8[1] = in.v8[1];
    317 #endif
    318                 return *this;
    319             }
    320             SIMDINLINE SIMDCALL operator __m512() const { return v; }
    321 
    322             SIMDALIGN(__m512, SIMD_ALIGNMENT_BYTES) v;
    323             SIMD256Impl::Float v8[2];
    324         };
    325 
    326         union Integer
    327         {
    328             SIMDINLINE Integer() = default;
    329             SIMDINLINE Integer(__m512i in) : v(in) {}
    330             SIMDINLINE Integer(SIMD256Impl::Integer const &in_lo, SIMD256Impl::Integer const &in_hi = _mm256_setzero_si256()) { v8[0] = in_lo; v8[1] = in_hi; }
    331             SIMDINLINE Integer& SIMDCALL operator=(__m512i in) { v = in; return *this; }
    332             SIMDINLINE Integer& SIMDCALL operator=(Integer const & in)
    333             {
    334 #if SIMD_ARCH >= SIMD_ARCH_AVX512
    335                 v = in.v;
    336 #else
    337                 v8[0] = in.v8[0];
    338                 v8[1] = in.v8[1];
    339 #endif
    340                 return *this;
    341             }
    342 
    343             SIMDINLINE SIMDCALL operator __m512i() const { return v; }
    344 
    345             SIMDALIGN(__m512i, SIMD_ALIGNMENT_BYTES) v;
    346             SIMD256Impl::Integer v8[2];
    347         };
    348 
    349         union Double
    350         {
    351             SIMDINLINE Double() = default;
    352             SIMDINLINE Double(__m512d in) : v(in) {}
    353             SIMDINLINE Double(SIMD256Impl::Double const &in_lo, SIMD256Impl::Double const &in_hi = _mm256_setzero_pd()) { v8[0] = in_lo; v8[1] = in_hi; }
    354             SIMDINLINE Double& SIMDCALL operator=(__m512d in) { v = in; return *this; }
    355             SIMDINLINE Double& SIMDCALL operator=(Double const & in)
    356             {
    357 #if SIMD_ARCH >= SIMD_ARCH_AVX512
    358                 v = in.v;
    359 #else
    360                 v8[0] = in.v8[0];
    361                 v8[1] = in.v8[1];
    362 #endif
    363                 return *this;
    364             }
    365 
    366             SIMDINLINE SIMDCALL operator __m512d() const { return v; }
    367 
    368             SIMDALIGN(__m512d, SIMD_ALIGNMENT_BYTES) v;
    369             SIMD256Impl::Double v8[2];
    370         };
    371 
    372         typedef SIMDImpl::Vec4<Float, Integer, Double> SIMDALIGN(Vec4, 64);
    373         using Mask = __mmask16;
    374 
    375         static const uint32_t SIMD_WIDTH = 16;
    376 
    377 #undef SIMD_ALIGNMENT_BYTES
    378     } // ns SIMD512Impl
    379 } // ns SIMDImpl
    380