1 /* 2 * Copyright 2014-2015 Broadcom 3 * 4 * Permission is hereby granted, free of charge, to any person obtaining a 5 * copy of this software and associated documentation files (the "Software"), 6 * to deal in the Software without restriction, including without limitation 7 * the rights to use, copy, modify, merge, publish, distribute, sublicense, 8 * and/or sell copies of the Software, and to permit persons to whom the 9 * Software is furnished to do so, subject to the following conditions: 10 * 11 * The above copyright notice and this permission notice (including the next 12 * paragraph) shall be included in all copies or substantial portions of the 13 * Software. 14 * 15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS 21 * IN THE SOFTWARE. 22 */ 23 24 #include "nir.h" 25 #include "nir_builder.h" 26 27 /** @file nir_lower_alu_to_scalar.c 28 * 29 * Replaces nir_alu_instr operations with more than one channel used in the 30 * arguments with individual per-channel operations. 31 */ 32 33 static void 34 nir_alu_ssa_dest_init(nir_alu_instr *instr, unsigned num_components, 35 unsigned bit_size) 36 { 37 nir_ssa_dest_init(&instr->instr, &instr->dest.dest, num_components, 38 bit_size, NULL); 39 instr->dest.write_mask = (1 << num_components) - 1; 40 } 41 42 static void 43 lower_reduction(nir_alu_instr *instr, nir_op chan_op, nir_op merge_op, 44 nir_builder *builder) 45 { 46 unsigned num_components = nir_op_infos[instr->op].input_sizes[0]; 47 48 nir_ssa_def *last = NULL; 49 for (unsigned i = 0; i < num_components; i++) { 50 nir_alu_instr *chan = nir_alu_instr_create(builder->shader, chan_op); 51 nir_alu_ssa_dest_init(chan, 1, instr->dest.dest.ssa.bit_size); 52 nir_alu_src_copy(&chan->src[0], &instr->src[0], chan); 53 chan->src[0].swizzle[0] = chan->src[0].swizzle[i]; 54 if (nir_op_infos[chan_op].num_inputs > 1) { 55 assert(nir_op_infos[chan_op].num_inputs == 2); 56 nir_alu_src_copy(&chan->src[1], &instr->src[1], chan); 57 chan->src[1].swizzle[0] = chan->src[1].swizzle[i]; 58 } 59 chan->exact = instr->exact; 60 61 nir_builder_instr_insert(builder, &chan->instr); 62 63 if (i == 0) { 64 last = &chan->dest.dest.ssa; 65 } else { 66 last = nir_build_alu(builder, merge_op, 67 last, &chan->dest.dest.ssa, NULL, NULL); 68 } 69 } 70 71 assert(instr->dest.write_mask == 1); 72 nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa, nir_src_for_ssa(last)); 73 nir_instr_remove(&instr->instr); 74 } 75 76 static bool 77 lower_alu_instr_scalar(nir_alu_instr *instr, nir_builder *b) 78 { 79 unsigned num_src = nir_op_infos[instr->op].num_inputs; 80 unsigned i, chan; 81 82 assert(instr->dest.dest.is_ssa); 83 assert(instr->dest.write_mask != 0); 84 85 b->cursor = nir_before_instr(&instr->instr); 86 b->exact = instr->exact; 87 88 #define LOWER_REDUCTION(name, chan, merge) \ 89 case name##2: \ 90 case name##3: \ 91 case name##4: \ 92 lower_reduction(instr, chan, merge, b); \ 93 return true; 94 95 switch (instr->op) { 96 case nir_op_vec4: 97 case nir_op_vec3: 98 case nir_op_vec2: 99 /* We don't need to scalarize these ops, they're the ones generated to 100 * group up outputs into a value that can be SSAed. 101 */ 102 return false; 103 104 case nir_op_pack_half_2x16: 105 if (!b->shader->options->lower_pack_half_2x16) 106 return false; 107 108 nir_ssa_def *val = 109 nir_pack_half_2x16_split(b, nir_channel(b, instr->src[0].src.ssa, 110 instr->src[0].swizzle[0]), 111 nir_channel(b, instr->src[0].src.ssa, 112 instr->src[0].swizzle[1])); 113 114 nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa, nir_src_for_ssa(val)); 115 nir_instr_remove(&instr->instr); 116 return true; 117 118 case nir_op_unpack_unorm_4x8: 119 case nir_op_unpack_snorm_4x8: 120 case nir_op_unpack_unorm_2x16: 121 case nir_op_unpack_snorm_2x16: 122 /* There is no scalar version of these ops, unless we were to break it 123 * down to bitshifts and math (which is definitely not intended). 124 */ 125 return false; 126 127 case nir_op_unpack_half_2x16: { 128 if (!b->shader->options->lower_unpack_half_2x16) 129 return false; 130 131 nir_ssa_def *comps[2]; 132 comps[0] = nir_unpack_half_2x16_split_x(b, instr->src[0].src.ssa); 133 comps[1] = nir_unpack_half_2x16_split_y(b, instr->src[0].src.ssa); 134 nir_ssa_def *vec = nir_vec(b, comps, 2); 135 136 nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa, nir_src_for_ssa(vec)); 137 nir_instr_remove(&instr->instr); 138 return true; 139 } 140 141 case nir_op_pack_uvec2_to_uint: { 142 assert(b->shader->options->lower_pack_snorm_2x16 || 143 b->shader->options->lower_pack_unorm_2x16); 144 145 nir_ssa_def *word = 146 nir_extract_u16(b, instr->src[0].src.ssa, nir_imm_int(b, 0)); 147 nir_ssa_def *val = 148 nir_ior(b, nir_ishl(b, nir_channel(b, word, 1), nir_imm_int(b, 16)), 149 nir_channel(b, word, 0)); 150 151 nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa, nir_src_for_ssa(val)); 152 nir_instr_remove(&instr->instr); 153 break; 154 } 155 156 case nir_op_pack_uvec4_to_uint: { 157 assert(b->shader->options->lower_pack_snorm_4x8 || 158 b->shader->options->lower_pack_unorm_4x8); 159 160 nir_ssa_def *byte = 161 nir_extract_u8(b, instr->src[0].src.ssa, nir_imm_int(b, 0)); 162 nir_ssa_def *val = 163 nir_ior(b, nir_ior(b, nir_ishl(b, nir_channel(b, byte, 3), nir_imm_int(b, 24)), 164 nir_ishl(b, nir_channel(b, byte, 2), nir_imm_int(b, 16))), 165 nir_ior(b, nir_ishl(b, nir_channel(b, byte, 1), nir_imm_int(b, 8)), 166 nir_channel(b, byte, 0))); 167 168 nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa, nir_src_for_ssa(val)); 169 nir_instr_remove(&instr->instr); 170 break; 171 } 172 173 case nir_op_fdph: { 174 nir_ssa_def *sum[4]; 175 for (unsigned i = 0; i < 3; i++) { 176 sum[i] = nir_fmul(b, nir_channel(b, instr->src[0].src.ssa, 177 instr->src[0].swizzle[i]), 178 nir_channel(b, instr->src[1].src.ssa, 179 instr->src[1].swizzle[i])); 180 } 181 sum[3] = nir_channel(b, instr->src[1].src.ssa, instr->src[1].swizzle[3]); 182 183 nir_ssa_def *val = nir_fadd(b, nir_fadd(b, sum[0], sum[1]), 184 nir_fadd(b, sum[2], sum[3])); 185 186 nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa, nir_src_for_ssa(val)); 187 nir_instr_remove(&instr->instr); 188 return true; 189 } 190 191 case nir_op_unpack_double_2x32: 192 return false; 193 194 LOWER_REDUCTION(nir_op_fdot, nir_op_fmul, nir_op_fadd); 195 LOWER_REDUCTION(nir_op_ball_fequal, nir_op_feq, nir_op_iand); 196 LOWER_REDUCTION(nir_op_ball_iequal, nir_op_ieq, nir_op_iand); 197 LOWER_REDUCTION(nir_op_bany_fnequal, nir_op_fne, nir_op_ior); 198 LOWER_REDUCTION(nir_op_bany_inequal, nir_op_ine, nir_op_ior); 199 LOWER_REDUCTION(nir_op_fall_equal, nir_op_seq, nir_op_fand); 200 LOWER_REDUCTION(nir_op_fany_nequal, nir_op_sne, nir_op_for); 201 202 default: 203 break; 204 } 205 206 if (instr->dest.dest.ssa.num_components == 1) 207 return false; 208 209 unsigned num_components = instr->dest.dest.ssa.num_components; 210 nir_ssa_def *comps[] = { NULL, NULL, NULL, NULL }; 211 212 for (chan = 0; chan < 4; chan++) { 213 if (!(instr->dest.write_mask & (1 << chan))) 214 continue; 215 216 nir_alu_instr *lower = nir_alu_instr_create(b->shader, instr->op); 217 for (i = 0; i < num_src; i++) { 218 /* We only handle same-size-as-dest (input_sizes[] == 0) or scalar 219 * args (input_sizes[] == 1). 220 */ 221 assert(nir_op_infos[instr->op].input_sizes[i] < 2); 222 unsigned src_chan = (nir_op_infos[instr->op].input_sizes[i] == 1 ? 223 0 : chan); 224 225 nir_alu_src_copy(&lower->src[i], &instr->src[i], lower); 226 for (int j = 0; j < 4; j++) 227 lower->src[i].swizzle[j] = instr->src[i].swizzle[src_chan]; 228 } 229 230 nir_alu_ssa_dest_init(lower, 1, instr->dest.dest.ssa.bit_size); 231 lower->dest.saturate = instr->dest.saturate; 232 comps[chan] = &lower->dest.dest.ssa; 233 lower->exact = instr->exact; 234 235 nir_builder_instr_insert(b, &lower->instr); 236 } 237 238 nir_ssa_def *vec = nir_vec(b, comps, num_components); 239 240 nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa, nir_src_for_ssa(vec)); 241 242 nir_instr_remove(&instr->instr); 243 return true; 244 } 245 246 static bool 247 nir_lower_alu_to_scalar_impl(nir_function_impl *impl) 248 { 249 nir_builder builder; 250 nir_builder_init(&builder, impl); 251 bool progress = false; 252 253 nir_foreach_block(block, impl) { 254 nir_foreach_instr_safe(instr, block) { 255 if (instr->type == nir_instr_type_alu) { 256 progress = lower_alu_instr_scalar(nir_instr_as_alu(instr), 257 &builder) || progress; 258 } 259 } 260 } 261 262 nir_metadata_preserve(impl, nir_metadata_block_index | 263 nir_metadata_dominance); 264 265 return progress; 266 } 267 268 bool 269 nir_lower_alu_to_scalar(nir_shader *shader) 270 { 271 bool progress = false; 272 273 nir_foreach_function(function, shader) { 274 if (function->impl) 275 progress = nir_lower_alu_to_scalar_impl(function->impl) || progress; 276 } 277 278 return progress; 279 } 280