/*
 * Copyright 2016 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "vtn_private.h"

/*
 * Normally, column vectors in SPIR-V correspond to a single NIR SSA
 * definition. But for matrix multiplies, we want to do one routine for
 * multiplying a matrix by a matrix and then pretend that vectors are matrices
 * with one column. So we "wrap" these things, and unwrap the result before we
 * send it off.
 */

static struct vtn_ssa_value *
wrap_matrix(struct vtn_builder *b, struct vtn_ssa_value *val)
{
   if (val == NULL)
      return NULL;

   if (glsl_type_is_matrix(val->type))
      return val;

   struct vtn_ssa_value *dest = rzalloc(b, struct vtn_ssa_value);
   dest->type = val->type;
   dest->elems = ralloc_array(b, struct vtn_ssa_value *, 1);
   dest->elems[0] = val;

   return dest;
}

static struct vtn_ssa_value *
unwrap_matrix(struct vtn_ssa_value *val)
{
   if (glsl_type_is_matrix(val->type))
      return val;

   return val->elems[0];
}
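
/* Multiplies two wrapped matrices. Matrices are stored column-major:
 * elems[i] is column i, held as a single NIR vector. Column i of the
 * product is therefore a linear combination of the columns of src0:
 *
 *    dest[i] = sum(src0[j] * src1[i][j] for all j)
 *
 * When a transpose of src0 is already lying around, the rows of src0 are
 * directly available and each component of dest[i] can instead be computed
 * as a dot product of a row of src0 with column i of src1, which is what
 * the fast path below does.
 */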
static struct vtn_ssa_value *
matrix_multiply(struct vtn_builder *b,
                struct vtn_ssa_value *_src0, struct vtn_ssa_value *_src1)
{
   struct vtn_ssa_value *src0 = wrap_matrix(b, _src0);
   struct vtn_ssa_value *src1 = wrap_matrix(b, _src1);
   struct vtn_ssa_value *src0_transpose = wrap_matrix(b, _src0->transposed);
   struct vtn_ssa_value *src1_transpose = wrap_matrix(b, _src1->transposed);

   unsigned src0_rows = glsl_get_vector_elements(src0->type);
   unsigned src0_columns = glsl_get_matrix_columns(src0->type);
   unsigned src1_columns = glsl_get_matrix_columns(src1->type);

   const struct glsl_type *dest_type;
   if (src1_columns > 1) {
      dest_type = glsl_matrix_type(glsl_get_base_type(src0->type),
                                   src0_rows, src1_columns);
   } else {
      dest_type = glsl_vector_type(glsl_get_base_type(src0->type), src0_rows);
   }
   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);

   dest = wrap_matrix(b, dest);

   bool transpose_result = false;
   if (src0_transpose && src1_transpose) {
      /* transpose(A) * transpose(B) = transpose(B * A) */
      src1 = src0_transpose;
      src0 = src1_transpose;
      src0_transpose = NULL;
      src1_transpose = NULL;
      transpose_result = true;
   }

   if (src0_transpose && !src1_transpose &&
       glsl_get_base_type(src0->type) == GLSL_TYPE_FLOAT) {
      /* We already have the rows of src0 and the columns of src1 available,
       * so we can just take the dot product of each row with each column to
       * get the result.
       */

      for (unsigned i = 0; i < src1_columns; i++) {
         nir_ssa_def *vec_src[4];
         for (unsigned j = 0; j < src0_rows; j++) {
            vec_src[j] = nir_fdot(&b->nb, src0_transpose->elems[j]->def,
                                          src1->elems[i]->def);
         }
         dest->elems[i]->def = nir_vec(&b->nb, vec_src, src0_rows);
      }
   } else {
      /* We don't handle the case where src1 is transposed but not src0, since
       * the general case only uses individual components of src1 so the
       * optimizer should chew through the transpose we emitted for src1.
       */

      for (unsigned i = 0; i < src1_columns; i++) {
         /* dest[i] = sum(src0[j] * src1[i][j] for all j) */
         dest->elems[i]->def =
            nir_fmul(&b->nb, src0->elems[0]->def,
                     nir_channel(&b->nb, src1->elems[i]->def, 0));
         for (unsigned j = 1; j < src0_columns; j++) {
            dest->elems[i]->def =
               nir_fadd(&b->nb, dest->elems[i]->def,
                        nir_fmul(&b->nb, src0->elems[j]->def,
                                 nir_channel(&b->nb, src1->elems[i]->def, j)));
         }
      }
   }

   dest = unwrap_matrix(dest);

   if (transpose_result)
      dest = vtn_ssa_transpose(b, dest);

   return dest;
}
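
/* Scales every column of `mat` by `scalar`. The multiply opcode is chosen
 * from the matrix's base type: nir_fmul for float matrices, nir_imul
 * otherwise.
 */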
static struct vtn_ssa_value *
mat_times_scalar(struct vtn_builder *b,
                 struct vtn_ssa_value *mat,
                 nir_ssa_def *scalar)
{
   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, mat->type);
   for (unsigned i = 0; i < glsl_get_matrix_columns(mat->type); i++) {
      if (glsl_get_base_type(mat->type) == GLSL_TYPE_FLOAT)
         dest->elems[i]->def = nir_fmul(&b->nb, mat->elems[i]->def, scalar);
      else
         dest->elems[i]->def = nir_imul(&b->nb, mat->elems[i]->def, scalar);
   }

   return dest;
}

static void
vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
                      struct vtn_value *dest,
                      struct vtn_ssa_value *src0, struct vtn_ssa_value *src1)
{
   switch (opcode) {
   case SpvOpFNegate: {
      dest->ssa = vtn_create_ssa_value(b, src0->type);
      unsigned cols = glsl_get_matrix_columns(src0->type);
      for (unsigned i = 0; i < cols; i++)
         dest->ssa->elems[i]->def = nir_fneg(&b->nb, src0->elems[i]->def);
      break;
   }

   case SpvOpFAdd: {
      dest->ssa = vtn_create_ssa_value(b, src0->type);
      unsigned cols = glsl_get_matrix_columns(src0->type);
      for (unsigned i = 0; i < cols; i++)
         dest->ssa->elems[i]->def =
            nir_fadd(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
      break;
   }

   case SpvOpFSub: {
      dest->ssa = vtn_create_ssa_value(b, src0->type);
      unsigned cols = glsl_get_matrix_columns(src0->type);
      for (unsigned i = 0; i < cols; i++)
         dest->ssa->elems[i]->def =
            nir_fsub(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
      break;
   }

   case SpvOpTranspose:
      dest->ssa = vtn_ssa_transpose(b, src0);
      break;

   case SpvOpMatrixTimesScalar:
      if (src0->transposed) {
         dest->ssa = vtn_ssa_transpose(b, mat_times_scalar(b, src0->transposed,
                                                           src1->def));
      } else {
         dest->ssa = mat_times_scalar(b, src0, src1->def);
      }
      break;

   case SpvOpVectorTimesMatrix:
   case SpvOpMatrixTimesVector:
   case SpvOpMatrixTimesMatrix:
      if (opcode == SpvOpVectorTimesMatrix) {
         dest->ssa = matrix_multiply(b, vtn_ssa_transpose(b, src1), src0);
      } else {
         dest->ssa = matrix_multiply(b, src0, src1);
      }
      break;

   default:
      unreachable("unknown matrix opcode");
   }
}
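
/* Maps a SPIR-V ALU opcode to the corresponding nir_op. NIR only has
 * less-than (ult/ilt/flt) and greater-than-or-equal (uge/ige/fge)
 * comparisons, so greater-than and less-than-or-equal are expressed by
 * swapping the sources of the opposite comparison: for example,
 * OpUGreaterThan(a, b) maps to nir_op_ult with *swap = true, so the caller
 * emits ult(b, a).
 */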
nir_op
vtn_nir_alu_op_for_spirv_opcode(SpvOp opcode, bool *swap,
                                nir_alu_type src, nir_alu_type dst)
{
   /* Indicates that the first two arguments should be swapped. This is
    * used for implementing greater-than and less-than-or-equal.
    */
   *swap = false;

   switch (opcode) {
   case SpvOpSNegate:            return nir_op_ineg;
   case SpvOpFNegate:            return nir_op_fneg;
   case SpvOpNot:                return nir_op_inot;
   case SpvOpIAdd:               return nir_op_iadd;
   case SpvOpFAdd:               return nir_op_fadd;
   case SpvOpISub:               return nir_op_isub;
   case SpvOpFSub:               return nir_op_fsub;
   case SpvOpIMul:               return nir_op_imul;
   case SpvOpFMul:               return nir_op_fmul;
   case SpvOpUDiv:               return nir_op_udiv;
   case SpvOpSDiv:               return nir_op_idiv;
   case SpvOpFDiv:               return nir_op_fdiv;
   case SpvOpUMod:               return nir_op_umod;
   case SpvOpSMod:               return nir_op_imod;
   case SpvOpFMod:               return nir_op_fmod;
   case SpvOpSRem:               return nir_op_irem;
   case SpvOpFRem:               return nir_op_frem;

   case SpvOpShiftRightLogical:     return nir_op_ushr;
   case SpvOpShiftRightArithmetic:  return nir_op_ishr;
   case SpvOpShiftLeftLogical:      return nir_op_ishl;
   case SpvOpLogicalOr:             return nir_op_ior;
   case SpvOpLogicalEqual:          return nir_op_ieq;
   case SpvOpLogicalNotEqual:       return nir_op_ine;
   case SpvOpLogicalAnd:            return nir_op_iand;
   case SpvOpLogicalNot:            return nir_op_inot;
   case SpvOpBitwiseOr:             return nir_op_ior;
   case SpvOpBitwiseXor:            return nir_op_ixor;
   case SpvOpBitwiseAnd:            return nir_op_iand;
   case SpvOpSelect:                return nir_op_bcsel;
   case SpvOpIEqual:                return nir_op_ieq;

   case SpvOpBitFieldInsert:        return nir_op_bitfield_insert;
   case SpvOpBitFieldSExtract:      return nir_op_ibitfield_extract;
   case SpvOpBitFieldUExtract:      return nir_op_ubitfield_extract;
   case SpvOpBitReverse:            return nir_op_bitfield_reverse;
   case SpvOpBitCount:              return nir_op_bit_count;

   /* The ordered / unordered operators need more than just the comparison
    * opcode returned here: the caller must also check whether the operands
    * are ordered, i.e. whether either of them is a NaN.
    */
   case SpvOpFOrdEqual:                            return nir_op_feq;
   case SpvOpFUnordEqual:                          return nir_op_feq;
   case SpvOpINotEqual:                            return nir_op_ine;
   case SpvOpFOrdNotEqual:                         return nir_op_fne;
   case SpvOpFUnordNotEqual:                       return nir_op_fne;
   case SpvOpULessThan:                            return nir_op_ult;
   case SpvOpSLessThan:                            return nir_op_ilt;
   case SpvOpFOrdLessThan:                         return nir_op_flt;
   case SpvOpFUnordLessThan:                       return nir_op_flt;
   case SpvOpUGreaterThan:          *swap = true;  return nir_op_ult;
   case SpvOpSGreaterThan:          *swap = true;  return nir_op_ilt;
   case SpvOpFOrdGreaterThan:       *swap = true;  return nir_op_flt;
   case SpvOpFUnordGreaterThan:     *swap = true;  return nir_op_flt;
   case SpvOpULessThanEqual:        *swap = true;  return nir_op_uge;
   case SpvOpSLessThanEqual:        *swap = true;  return nir_op_ige;
   case SpvOpFOrdLessThanEqual:     *swap = true;  return nir_op_fge;
   case SpvOpFUnordLessThanEqual:   *swap = true;  return nir_op_fge;
   case SpvOpUGreaterThanEqual:                    return nir_op_uge;
   case SpvOpSGreaterThanEqual:                    return nir_op_ige;
   case SpvOpFOrdGreaterThanEqual:                 return nir_op_fge;
   case SpvOpFUnordGreaterThanEqual:               return nir_op_fge;

   /* Conversions: */
   case SpvOpBitcast:            return nir_op_imov;
   case SpvOpQuantizeToF16:      return nir_op_fquantize2f16;
   case SpvOpUConvert:
   case SpvOpConvertFToU:
   case SpvOpConvertFToS:
   case SpvOpConvertSToF:
   case SpvOpConvertUToF:
   case SpvOpSConvert:
   case SpvOpFConvert:
      return nir_type_conversion_op(src, dst);

   /* Derivatives: */
   case SpvOpDPdx:         return nir_op_fddx;
   case SpvOpDPdy:         return nir_op_fddy;
   case SpvOpDPdxFine:     return nir_op_fddx_fine;
   case SpvOpDPdyFine:     return nir_op_fddy_fine;
   case SpvOpDPdxCoarse:   return nir_op_fddx_coarse;
   case SpvOpDPdyCoarse:   return nir_op_fddy_coarse;

   default:
      unreachable("No NIR equivalent");
   }
}

static void
handle_no_contraction(struct vtn_builder *b, struct vtn_value *val, int member,
                      const struct vtn_decoration *dec, void *_void)
{
   assert(dec->scope == VTN_DEC_DECORATION);
   if (dec->decoration != SpvDecorationNoContraction)
      return;

   b->nb.exact = true;
}
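
/* Handles a scalar or vector ALU instruction. Per the SPIR-V binary layout,
 * w[1] is the result type id, w[2] is the result id, and w[3..] are the
 * operand ids, hence num_inputs = count - 3. If the instruction carries the
 * NoContraction decoration, handle_no_contraction() above sets b->nb.exact
 * so every NIR instruction built for it is marked exact; the flag is reset
 * before returning.
 */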
void
vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
               const uint32_t *w, unsigned count)
{
   struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
   const struct glsl_type *type =
      vtn_value(b, w[1], vtn_value_type_type)->type->type;

   vtn_foreach_decoration(b, val, handle_no_contraction, NULL);

   /* Collect the various SSA sources */
   const unsigned num_inputs = count - 3;
   struct vtn_ssa_value *vtn_src[4] = { NULL, };
   for (unsigned i = 0; i < num_inputs; i++)
      vtn_src[i] = vtn_ssa_value(b, w[i + 3]);

   if (glsl_type_is_matrix(vtn_src[0]->type) ||
       (num_inputs >= 2 && glsl_type_is_matrix(vtn_src[1]->type))) {
      vtn_handle_matrix_alu(b, opcode, val, vtn_src[0], vtn_src[1]);
      b->nb.exact = false;
      return;
   }

   val->ssa = vtn_create_ssa_value(b, type);
   nir_ssa_def *src[4] = { NULL, };
   for (unsigned i = 0; i < num_inputs; i++) {
      assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
      src[i] = vtn_src[i]->def;
   }

   switch (opcode) {
   case SpvOpAny:
      if (src[0]->num_components == 1) {
         val->ssa->def = nir_imov(&b->nb, src[0]);
      } else {
         nir_op op;
         switch (src[0]->num_components) {
         case 2:  op = nir_op_bany_inequal2; break;
         case 3:  op = nir_op_bany_inequal3; break;
         case 4:  op = nir_op_bany_inequal4; break;
         default: unreachable("invalid number of components");
         }
         val->ssa->def = nir_build_alu(&b->nb, op, src[0],
                                       nir_imm_int(&b->nb, NIR_FALSE),
                                       NULL, NULL);
      }
      break;

   case SpvOpAll:
      if (src[0]->num_components == 1) {
         val->ssa->def = nir_imov(&b->nb, src[0]);
      } else {
         nir_op op;
         switch (src[0]->num_components) {
         case 2:  op = nir_op_ball_iequal2; break;
         case 3:  op = nir_op_ball_iequal3; break;
         case 4:  op = nir_op_ball_iequal4; break;
         default: unreachable("invalid number of components");
         }
         val->ssa->def = nir_build_alu(&b->nb, op, src[0],
                                       nir_imm_int(&b->nb, NIR_TRUE),
                                       NULL, NULL);
      }
      break;

   case SpvOpOuterProduct: {
      for (unsigned i = 0; i < src[1]->num_components; i++) {
         val->ssa->elems[i]->def =
            nir_fmul(&b->nb, src[0], nir_channel(&b->nb, src[1], i));
      }
      break;
   }

   case SpvOpDot:
      val->ssa->def = nir_fdot(&b->nb, src[0], src[1]);
      break;

   case SpvOpIAddCarry:
      assert(glsl_type_is_struct(val->ssa->type));
      val->ssa->elems[0]->def = nir_iadd(&b->nb, src[0], src[1]);
      val->ssa->elems[1]->def = nir_uadd_carry(&b->nb, src[0], src[1]);
      break;

   case SpvOpISubBorrow:
      assert(glsl_type_is_struct(val->ssa->type));
      val->ssa->elems[0]->def = nir_isub(&b->nb, src[0], src[1]);
      val->ssa->elems[1]->def = nir_usub_borrow(&b->nb, src[0], src[1]);
      break;

   case SpvOpUMulExtended:
      assert(glsl_type_is_struct(val->ssa->type));
      val->ssa->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
      val->ssa->elems[1]->def = nir_umul_high(&b->nb, src[0], src[1]);
      break;

   case SpvOpSMulExtended:
      assert(glsl_type_is_struct(val->ssa->type));
      val->ssa->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
      val->ssa->elems[1]->def = nir_imul_high(&b->nb, src[0], src[1]);
      break;

   case SpvOpFwidth:
      val->ssa->def = nir_fadd(&b->nb,
                               nir_fabs(&b->nb, nir_fddx(&b->nb, src[0])),
                               nir_fabs(&b->nb, nir_fddy(&b->nb, src[0])));
      break;
   case SpvOpFwidthFine:
      val->ssa->def = nir_fadd(&b->nb,
                               nir_fabs(&b->nb, nir_fddx_fine(&b->nb, src[0])),
                               nir_fabs(&b->nb, nir_fddy_fine(&b->nb, src[0])));
      break;
   case SpvOpFwidthCoarse:
      val->ssa->def = nir_fadd(&b->nb,
                               nir_fabs(&b->nb, nir_fddx_coarse(&b->nb, src[0])),
                               nir_fabs(&b->nb, nir_fddy_coarse(&b->nb, src[0])));
      break;

   case SpvOpVectorTimesScalar:
      /* The builder will take care of splatting for us. */
      val->ssa->def = nir_fmul(&b->nb, src[0], src[1]);
      break;

   case SpvOpIsNan:
      val->ssa->def = nir_fne(&b->nb, src[0], src[0]);
      break;

   case SpvOpIsInf:
      val->ssa->def = nir_feq(&b->nb, nir_fabs(&b->nb, src[0]),
                              nir_imm_float(&b->nb, INFINITY));
      break;
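
   /* SPIR-V's unordered comparisons are true if either operand is NaN,
    * while the ordered ones additionally require both operands to be
    * ordered (non-NaN). There is no dedicated NIR "is NaN" opcode, so the
    * cases below reuse the same trick as SpvOpIsNan above: NaN is the only
    * value that compares unequal to itself. Schematically:
    *
    *    FUnordOp(a, b) = op(a, b) || (a != a) || (b != b)
    *    FOrdOp(a, b)   = op(a, b) && (a == a) && (b == b)
    */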
   case SpvOpFUnordEqual:
   case SpvOpFUnordNotEqual:
   case SpvOpFUnordLessThan:
   case SpvOpFUnordGreaterThan:
   case SpvOpFUnordLessThanEqual:
   case SpvOpFUnordGreaterThanEqual: {
      bool swap;
      nir_alu_type src_alu_type = nir_get_nir_type_for_glsl_type(vtn_src[0]->type);
      nir_alu_type dst_alu_type = nir_get_nir_type_for_glsl_type(type);
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(opcode, &swap,
                                                  src_alu_type, dst_alu_type);

      if (swap) {
         nir_ssa_def *tmp = src[0];
         src[0] = src[1];
         src[1] = tmp;
      }

      val->ssa->def =
         nir_ior(&b->nb,
                 nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL),
                 nir_ior(&b->nb,
                         nir_fne(&b->nb, src[0], src[0]),
                         nir_fne(&b->nb, src[1], src[1])));
      break;
   }

   case SpvOpFOrdEqual:
   case SpvOpFOrdNotEqual:
   case SpvOpFOrdLessThan:
   case SpvOpFOrdGreaterThan:
   case SpvOpFOrdLessThanEqual:
   case SpvOpFOrdGreaterThanEqual: {
      bool swap;
      nir_alu_type src_alu_type = nir_get_nir_type_for_glsl_type(vtn_src[0]->type);
      nir_alu_type dst_alu_type = nir_get_nir_type_for_glsl_type(type);
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(opcode, &swap,
                                                  src_alu_type, dst_alu_type);

      if (swap) {
         nir_ssa_def *tmp = src[0];
         src[0] = src[1];
         src[1] = tmp;
      }

      val->ssa->def =
         nir_iand(&b->nb,
                  nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL),
                  nir_iand(&b->nb,
                           nir_feq(&b->nb, src[0], src[0]),
                           nir_feq(&b->nb, src[1], src[1])));
      break;
   }

   default: {
      bool swap;
      nir_alu_type src_alu_type = nir_get_nir_type_for_glsl_type(vtn_src[0]->type);
      nir_alu_type dst_alu_type = nir_get_nir_type_for_glsl_type(type);
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(opcode, &swap,
                                                  src_alu_type, dst_alu_type);

      if (swap) {
         nir_ssa_def *tmp = src[0];
         src[0] = src[1];
         src[1] = tmp;
      }

      val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
      break;
   } /* default */
   }

   b->nb.exact = false;
}