/*
 * Copyright © 2016 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "vtn_private.h"

/*
 * Normally, column vectors in SPIR-V correspond to a single NIR SSA
 * definition. But for matrix multiplies, we want a single routine that
 * multiplies a matrix by a matrix, and we pretend that vectors are matrices
 * with one column. So we "wrap" these things, and unwrap the result before
 * we send it off.
 */
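
/*
 * For example, OpMatrixTimesVector on a mat4 M and a vec4 v goes through
 * matrix_multiply(b, M, v): v is wrapped as a one-column matrix, the
 * general column-by-column multiply loop runs, and the result is unwrapped
 * back into a plain vec4 before it is stored.
 */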

static struct vtn_ssa_value *
wrap_matrix(struct vtn_builder *b, struct vtn_ssa_value *val)
{
   if (val == NULL)
      return NULL;

   if (glsl_type_is_matrix(val->type))
      return val;

   struct vtn_ssa_value *dest = rzalloc(b, struct vtn_ssa_value);
   dest->type = val->type;
   dest->elems = ralloc_array(b, struct vtn_ssa_value *, 1);
   dest->elems[0] = val;

   return dest;
}

static struct vtn_ssa_value *
unwrap_matrix(struct vtn_ssa_value *val)
{
   if (glsl_type_is_matrix(val->type))
      return val;

   return val->elems[0];
}

static struct vtn_ssa_value *
matrix_multiply(struct vtn_builder *b,
                struct vtn_ssa_value *_src0, struct vtn_ssa_value *_src1)
{
   struct vtn_ssa_value *src0 = wrap_matrix(b, _src0);
   struct vtn_ssa_value *src1 = wrap_matrix(b, _src1);
   struct vtn_ssa_value *src0_transpose = wrap_matrix(b, _src0->transposed);
   struct vtn_ssa_value *src1_transpose = wrap_matrix(b, _src1->transposed);

   unsigned src0_rows = glsl_get_vector_elements(src0->type);
   unsigned src0_columns = glsl_get_matrix_columns(src0->type);
   unsigned src1_columns = glsl_get_matrix_columns(src1->type);

   const struct glsl_type *dest_type;
   if (src1_columns > 1) {
      dest_type = glsl_matrix_type(glsl_get_base_type(src0->type),
                                   src0_rows, src1_columns);
   } else {
      dest_type = glsl_vector_type(glsl_get_base_type(src0->type), src0_rows);
   }
   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);

   dest = wrap_matrix(b, dest);

   bool transpose_result = false;
   if (src0_transpose && src1_transpose) {
      /* transpose(A) * transpose(B) = transpose(B * A) */
      src1 = src0_transpose;
      src0 = src1_transpose;
      src0_transpose = NULL;
      src1_transpose = NULL;
      transpose_result = true;
   }

   if (src0_transpose && !src1_transpose &&
       glsl_get_base_type(src0->type) == GLSL_TYPE_FLOAT) {
      /* We already have the rows of src0 and the columns of src1 available,
       * so we can just take the dot product of each row with each column to
       * get the result.
       */
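      /* Concretely:
       *
       *    dest[i][j] = dot(src0_transpose->elems[j], src1->elems[i])
       *
       * i.e. the dot product of row j of src0 with column i of src1.
       */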

      for (unsigned i = 0; i < src1_columns; i++) {
         nir_ssa_def *vec_src[4];
         for (unsigned j = 0; j < src0_rows; j++) {
            vec_src[j] = nir_fdot(&b->nb, src0_transpose->elems[j]->def,
                                          src1->elems[i]->def);
         }
         dest->elems[i]->def = nir_vec(&b->nb, vec_src, src0_rows);
      }
   } else {
      /* We don't handle the case where src1 is transposed but not src0,
       * since the general case only uses individual components of src1, so
       * the optimizer should chew through the transpose we emitted for
       * src1.
       */

      for (unsigned i = 0; i < src1_columns; i++) {
         /* dest[i] = sum(src0[j] * src1[i][j] for all j) */
         dest->elems[i]->def =
            nir_fmul(&b->nb, src0->elems[0]->def,
                     nir_channel(&b->nb, src1->elems[i]->def, 0));
         for (unsigned j = 1; j < src0_columns; j++) {
            dest->elems[i]->def =
               nir_fadd(&b->nb, dest->elems[i]->def,
                        nir_fmul(&b->nb, src0->elems[j]->def,
                                 nir_channel(&b->nb, src1->elems[i]->def, j)));
         }
      }
   }

   dest = unwrap_matrix(dest);

   if (transpose_result)
      dest = vtn_ssa_transpose(b, dest);

   return dest;
}

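/* Scale every column of `mat` by `scalar`, using nir_fmul or nir_imul
 * depending on the matrix base type.
 */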
static struct vtn_ssa_value *
mat_times_scalar(struct vtn_builder *b,
                 struct vtn_ssa_value *mat,
                 nir_ssa_def *scalar)
{
   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, mat->type);
   for (unsigned i = 0; i < glsl_get_matrix_columns(mat->type); i++) {
      if (glsl_get_base_type(mat->type) == GLSL_TYPE_FLOAT)
         dest->elems[i]->def = nir_fmul(&b->nb, mat->elems[i]->def, scalar);
      else
         dest->elems[i]->def = nir_imul(&b->nb, mat->elems[i]->def, scalar);
   }

   return dest;
}

static void
vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
                      struct vtn_value *dest,
                      struct vtn_ssa_value *src0, struct vtn_ssa_value *src1)
{
   switch (opcode) {
   case SpvOpFNegate: {
      dest->ssa = vtn_create_ssa_value(b, src0->type);
      unsigned cols = glsl_get_matrix_columns(src0->type);
      for (unsigned i = 0; i < cols; i++)
         dest->ssa->elems[i]->def = nir_fneg(&b->nb, src0->elems[i]->def);
      break;
   }

   case SpvOpFAdd: {
      dest->ssa = vtn_create_ssa_value(b, src0->type);
      unsigned cols = glsl_get_matrix_columns(src0->type);
      for (unsigned i = 0; i < cols; i++)
         dest->ssa->elems[i]->def =
            nir_fadd(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
      break;
   }

   case SpvOpFSub: {
      dest->ssa = vtn_create_ssa_value(b, src0->type);
      unsigned cols = glsl_get_matrix_columns(src0->type);
      for (unsigned i = 0; i < cols; i++)
         dest->ssa->elems[i]->def =
            nir_fsub(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
      break;
   }

   case SpvOpTranspose:
      dest->ssa = vtn_ssa_transpose(b, src0);
      break;

   case SpvOpMatrixTimesScalar:
      if (src0->transposed) {
         dest->ssa = vtn_ssa_transpose(b, mat_times_scalar(b, src0->transposed,
                                                           src1->def));
      } else {
         dest->ssa = mat_times_scalar(b, src0, src1->def);
      }
      break;

   case SpvOpVectorTimesMatrix:
   case SpvOpMatrixTimesVector:
   case SpvOpMatrixTimesMatrix:
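      /* v * M is the same as transpose(M) * v, so OpVectorTimesMatrix can
       * reuse the matrix path with its operands transposed and swapped.
       */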
      if (opcode == SpvOpVectorTimesMatrix) {
         dest->ssa = matrix_multiply(b, vtn_ssa_transpose(b, src1), src0);
      } else {
         dest->ssa = matrix_multiply(b, src0, src1);
      }
      break;

   default:
      unreachable("unknown matrix opcode");
   }
}

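/* Map a SPIR-V ALU opcode to the corresponding NIR opcode.  On return,
 * *swap tells the caller to exchange the first two sources before emitting
 * the instruction, e.g.:
 *
 *    bool swap;
 *    nir_op op = vtn_nir_alu_op_for_spirv_opcode(opcode, &swap, src, dst);
 *    if (swap) {
 *       nir_ssa_def *tmp = srcs[0];
 *       srcs[0] = srcs[1];
 *       srcs[1] = tmp;
 *    }
 *
 * (vtn_handle_alu below follows exactly this pattern.)
 */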
nir_op
vtn_nir_alu_op_for_spirv_opcode(SpvOp opcode, bool *swap,
                                nir_alu_type src, nir_alu_type dst)
{
   /* Indicates that the first two arguments should be swapped.  This is
    * used for implementing greater-than and less-than-or-equal.
    */
   *swap = false;

   switch (opcode) {
   case SpvOpSNegate:            return nir_op_ineg;
   case SpvOpFNegate:            return nir_op_fneg;
   case SpvOpNot:                return nir_op_inot;
   case SpvOpIAdd:               return nir_op_iadd;
   case SpvOpFAdd:               return nir_op_fadd;
   case SpvOpISub:               return nir_op_isub;
   case SpvOpFSub:               return nir_op_fsub;
   case SpvOpIMul:               return nir_op_imul;
   case SpvOpFMul:               return nir_op_fmul;
   case SpvOpUDiv:               return nir_op_udiv;
   case SpvOpSDiv:               return nir_op_idiv;
   case SpvOpFDiv:               return nir_op_fdiv;
   case SpvOpUMod:               return nir_op_umod;
   case SpvOpSMod:               return nir_op_imod;
   case SpvOpFMod:               return nir_op_fmod;
   case SpvOpSRem:               return nir_op_irem;
   case SpvOpFRem:               return nir_op_frem;

   case SpvOpShiftRightLogical:     return nir_op_ushr;
   case SpvOpShiftRightArithmetic:  return nir_op_ishr;
   case SpvOpShiftLeftLogical:      return nir_op_ishl;
   case SpvOpLogicalOr:             return nir_op_ior;
   case SpvOpLogicalEqual:          return nir_op_ieq;
   case SpvOpLogicalNotEqual:       return nir_op_ine;
   case SpvOpLogicalAnd:            return nir_op_iand;
   case SpvOpLogicalNot:            return nir_op_inot;
   case SpvOpBitwiseOr:             return nir_op_ior;
   case SpvOpBitwiseXor:            return nir_op_ixor;
   case SpvOpBitwiseAnd:            return nir_op_iand;
   case SpvOpSelect:                return nir_op_bcsel;
   case SpvOpIEqual:                return nir_op_ieq;

   case SpvOpBitFieldInsert:        return nir_op_bitfield_insert;
   case SpvOpBitFieldSExtract:      return nir_op_ibitfield_extract;
   case SpvOpBitFieldUExtract:      return nir_op_ubitfield_extract;
   case SpvOpBitReverse:            return nir_op_bitfield_reverse;
   case SpvOpBitCount:              return nir_op_bit_count;

   /* The ordered/unordered comparisons map to the same base NIR opcode;
    * the caller must additionally check whether the operands are ordered
    * (i.e. neither is NaN) and combine that with the comparison result.
    */
   case SpvOpFOrdEqual:                            return nir_op_feq;
   case SpvOpFUnordEqual:                          return nir_op_feq;
   case SpvOpINotEqual:                            return nir_op_ine;
   case SpvOpFOrdNotEqual:                         return nir_op_fne;
   case SpvOpFUnordNotEqual:                       return nir_op_fne;
   case SpvOpULessThan:                            return nir_op_ult;
   case SpvOpSLessThan:                            return nir_op_ilt;
   case SpvOpFOrdLessThan:                         return nir_op_flt;
   case SpvOpFUnordLessThan:                       return nir_op_flt;
   case SpvOpUGreaterThan:          *swap = true;  return nir_op_ult;
   case SpvOpSGreaterThan:          *swap = true;  return nir_op_ilt;
   case SpvOpFOrdGreaterThan:       *swap = true;  return nir_op_flt;
   case SpvOpFUnordGreaterThan:     *swap = true;  return nir_op_flt;
   case SpvOpULessThanEqual:        *swap = true;  return nir_op_uge;
   case SpvOpSLessThanEqual:        *swap = true;  return nir_op_ige;
   case SpvOpFOrdLessThanEqual:     *swap = true;  return nir_op_fge;
   case SpvOpFUnordLessThanEqual:   *swap = true;  return nir_op_fge;
   case SpvOpUGreaterThanEqual:                    return nir_op_uge;
   case SpvOpSGreaterThanEqual:                    return nir_op_ige;
   case SpvOpFOrdGreaterThanEqual:                 return nir_op_fge;
   case SpvOpFUnordGreaterThanEqual:               return nir_op_fge;

   /* Conversions: */
   case SpvOpBitcast:               return nir_op_imov;
   case SpvOpQuantizeToF16:         return nir_op_fquantize2f16;
   case SpvOpUConvert:
   case SpvOpConvertFToU:
   case SpvOpConvertFToS:
   case SpvOpConvertSToF:
   case SpvOpConvertUToF:
   case SpvOpSConvert:
   case SpvOpFConvert:
      return nir_type_conversion_op(src, dst);

   /* Derivatives: */
   case SpvOpDPdx:         return nir_op_fddx;
   case SpvOpDPdy:         return nir_op_fddy;
   case SpvOpDPdxFine:     return nir_op_fddx_fine;
   case SpvOpDPdyFine:     return nir_op_fddy_fine;
   case SpvOpDPdxCoarse:   return nir_op_fddx_coarse;
   case SpvOpDPdyCoarse:   return nir_op_fddy_coarse;

   default:
      unreachable("No NIR equivalent");
   }
}

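/* Decoration walker callback: if the result is decorated NoContraction,
 * mark the builder exact so NIR will not fuse or reassociate the emitted
 * floating-point operations.
 */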
static void
handle_no_contraction(struct vtn_builder *b, struct vtn_value *val, int member,
                      const struct vtn_decoration *dec, void *_void)
{
   assert(dec->scope == VTN_DEC_DECORATION);
   if (dec->decoration != SpvDecorationNoContraction)
      return;

   b->nb.exact = true;
}

void
vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
               const uint32_t *w, unsigned count)
{
   struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
   const struct glsl_type *type =
      vtn_value(b, w[1], vtn_value_type_type)->type->type;

   vtn_foreach_decoration(b, val, handle_no_contraction, NULL);

   /* Collect the various SSA sources */
   const unsigned num_inputs = count - 3;
   struct vtn_ssa_value *vtn_src[4] = { NULL, };
   for (unsigned i = 0; i < num_inputs; i++)
      vtn_src[i] = vtn_ssa_value(b, w[i + 3]);

   if (glsl_type_is_matrix(vtn_src[0]->type) ||
       (num_inputs >= 2 && glsl_type_is_matrix(vtn_src[1]->type))) {
      vtn_handle_matrix_alu(b, opcode, val, vtn_src[0], vtn_src[1]);
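      /* Matrix ops don't go through the ALU path below, so clear the
       * NoContraction-driven exact flag before returning.
       */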
      b->nb.exact = false;
      return;
   }

   val->ssa = vtn_create_ssa_value(b, type);
   nir_ssa_def *src[4] = { NULL, };
   for (unsigned i = 0; i < num_inputs; i++) {
      assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
      src[i] = vtn_src[i]->def;
   }

   switch (opcode) {
   case SpvOpAny:
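      /* OpAny reduces a boolean vector with "any component != false";
       * booleans are 32-bit 0 / ~0 values at this point in NIR.
       */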
      if (src[0]->num_components == 1) {
         val->ssa->def = nir_imov(&b->nb, src[0]);
      } else {
         nir_op op;
         switch (src[0]->num_components) {
         case 2:  op = nir_op_bany_inequal2; break;
         case 3:  op = nir_op_bany_inequal3; break;
         case 4:  op = nir_op_bany_inequal4; break;
         default: unreachable("invalid number of components");
         }
         val->ssa->def = nir_build_alu(&b->nb, op, src[0],
                                       nir_imm_int(&b->nb, NIR_FALSE),
                                       NULL, NULL);
      }
      break;

   case SpvOpAll:
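      /* OpAll is the dual reduction: every component must equal true. */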
      if (src[0]->num_components == 1) {
         val->ssa->def = nir_imov(&b->nb, src[0]);
      } else {
         nir_op op;
         switch (src[0]->num_components) {
         case 2:  op = nir_op_ball_iequal2;  break;
         case 3:  op = nir_op_ball_iequal3;  break;
         case 4:  op = nir_op_ball_iequal4;  break;
         default: unreachable("invalid number of components");
         }
         val->ssa->def = nir_build_alu(&b->nb, op, src[0],
                                       nir_imm_int(&b->nb, NIR_TRUE),
                                       NULL, NULL);
      }
      break;

   case SpvOpOuterProduct: {
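      /* Column i of the outer product is src[0] scaled by the i-th
       * component of src[1]: outer(u, v)[i] = u * v[i].
       */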
      for (unsigned i = 0; i < src[1]->num_components; i++) {
         val->ssa->elems[i]->def =
            nir_fmul(&b->nb, src[0], nir_channel(&b->nb, src[1], i));
      }
      break;
   }

   case SpvOpDot:
      val->ssa->def = nir_fdot(&b->nb, src[0], src[1]);
      break;

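   /* The extended-arithmetic opcodes return a two-member struct:
    * elems[0] holds the low bits of the result and elems[1] the carry,
    * borrow, or high bits respectively.
    */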
   case SpvOpIAddCarry:
      assert(glsl_type_is_struct(val->ssa->type));
      val->ssa->elems[0]->def = nir_iadd(&b->nb, src[0], src[1]);
      val->ssa->elems[1]->def = nir_uadd_carry(&b->nb, src[0], src[1]);
      break;

   case SpvOpISubBorrow:
      assert(glsl_type_is_struct(val->ssa->type));
      val->ssa->elems[0]->def = nir_isub(&b->nb, src[0], src[1]);
      val->ssa->elems[1]->def = nir_usub_borrow(&b->nb, src[0], src[1]);
      break;

   case SpvOpUMulExtended:
      assert(glsl_type_is_struct(val->ssa->type));
      val->ssa->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
      val->ssa->elems[1]->def = nir_umul_high(&b->nb, src[0], src[1]);
      break;

   case SpvOpSMulExtended:
      assert(glsl_type_is_struct(val->ssa->type));
      val->ssa->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
      val->ssa->elems[1]->def = nir_imul_high(&b->nb, src[0], src[1]);
      break;

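   /* fwidth(p) is defined as abs(dFdx(p)) + abs(dFdy(p)); the fine and
    * coarse variants use the matching derivative opcodes.
    */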
   case SpvOpFwidth:
      val->ssa->def = nir_fadd(&b->nb,
                               nir_fabs(&b->nb, nir_fddx(&b->nb, src[0])),
                               nir_fabs(&b->nb, nir_fddy(&b->nb, src[0])));
      break;
   case SpvOpFwidthFine:
      val->ssa->def = nir_fadd(&b->nb,
                               nir_fabs(&b->nb, nir_fddx_fine(&b->nb, src[0])),
                               nir_fabs(&b->nb, nir_fddy_fine(&b->nb, src[0])));
      break;
   case SpvOpFwidthCoarse:
      val->ssa->def = nir_fadd(&b->nb,
                               nir_fabs(&b->nb, nir_fddx_coarse(&b->nb, src[0])),
                               nir_fabs(&b->nb, nir_fddy_coarse(&b->nb, src[0])));
      break;

   case SpvOpVectorTimesScalar:
      /* The builder will take care of splatting for us. */
      val->ssa->def = nir_fmul(&b->nb, src[0], src[1]);
      break;

   case SpvOpIsNan:
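      /* NaN is the only value that compares unequal to itself. */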
      val->ssa->def = nir_fne(&b->nb, src[0], src[0]);
      break;

   case SpvOpIsInf:
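      /* |x| == +Inf exactly when x is +/-Inf (NaN fails the feq). */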
      val->ssa->def = nir_feq(&b->nb, nir_fabs(&b->nb, src[0]),
                                      nir_imm_float(&b->nb, INFINITY));
      break;

   case SpvOpFUnordEqual:
   case SpvOpFUnordNotEqual:
   case SpvOpFUnordLessThan:
   case SpvOpFUnordGreaterThan:
   case SpvOpFUnordLessThanEqual:
   case SpvOpFUnordGreaterThanEqual: {
      bool swap;
      nir_alu_type src_alu_type = nir_get_nir_type_for_glsl_type(vtn_src[0]->type);
      nir_alu_type dst_alu_type = nir_get_nir_type_for_glsl_type(type);
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(opcode, &swap, src_alu_type, dst_alu_type);

      if (swap) {
         nir_ssa_def *tmp = src[0];
         src[0] = src[1];
         src[1] = tmp;
      }

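      /* An unordered comparison is true if either operand is NaN:
       *
       *    unord_cmp(a, b) = cmp(a, b) || isnan(a) || isnan(b)
       */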
      val->ssa->def =
         nir_ior(&b->nb,
                 nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL),
                 nir_ior(&b->nb,
                         nir_fne(&b->nb, src[0], src[0]),
                         nir_fne(&b->nb, src[1], src[1])));
      break;
   }

   case SpvOpFOrdEqual:
   case SpvOpFOrdNotEqual:
   case SpvOpFOrdLessThan:
   case SpvOpFOrdGreaterThan:
   case SpvOpFOrdLessThanEqual:
   case SpvOpFOrdGreaterThanEqual: {
      bool swap;
      nir_alu_type src_alu_type = nir_get_nir_type_for_glsl_type(vtn_src[0]->type);
      nir_alu_type dst_alu_type = nir_get_nir_type_for_glsl_type(type);
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(opcode, &swap, src_alu_type, dst_alu_type);

      if (swap) {
         nir_ssa_def *tmp = src[0];
         src[0] = src[1];
         src[1] = tmp;
      }

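      /* An ordered comparison additionally requires both operands to be
       * non-NaN:
       *
       *    ord_cmp(a, b) = cmp(a, b) && !isnan(a) && !isnan(b)
       */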
      val->ssa->def =
         nir_iand(&b->nb,
                  nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL),
                  nir_iand(&b->nb,
                           nir_feq(&b->nb, src[0], src[0]),
                           nir_feq(&b->nb, src[1], src[1])));
      break;
   }

   default: {
      bool swap;
      nir_alu_type src_alu_type = nir_get_nir_type_for_glsl_type(vtn_src[0]->type);
      nir_alu_type dst_alu_type = nir_get_nir_type_for_glsl_type(type);
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(opcode, &swap, src_alu_type, dst_alu_type);

      if (swap) {
         nir_ssa_def *tmp = src[0];
         src[0] = src[1];
         src[1] = tmp;
      }

      val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
      break;
   } /* default */
   }

   b->nb.exact = false;
}
    525