"""Mul primitive used by the GEMM function.

The Mul primitive takes 1-4 zipped rows and 1-4 zipped columns and performs
matrix multiplication on those, resulting in a small block of results.
GenerateFunctions emits every combination of 1-3 rows by 1-3 columns plus the
1x4 variant.
"""

import neon_emitter


class Error(Exception):
  """Module level error."""


class ConfigurationError(Error):
  """Unsupported configuration."""


class MulLanes(object):
  """Double registers holding the lanes loaded from one input stream.

  Attributes:
    input_address: register with the address the lanes are loaded from; the
      generated loop loads through emitter.DereferenceIncrement on it.
    lanes: list of double registers, one per zipped row (lhs) or column (rhs).
  """

  def __init__(self, input_address):
    self.input_address = input_address
    self.lanes = []

  def AddLane(self, lane):
    """Append one register to the lane list."""
    self.lanes.append(lane)

  def FreeRegisters(self, registers):
    """Release every lane register back to `registers`.

    Entries are overwritten with None instead of being removed, so
    len(self.lanes) still reports the row/column count afterwards;
    GenerateAggregatorReduceStore relies on that.
    """
    for index, lane in enumerate(self.lanes):
      registers.FreeRegister(lane)
      self.lanes[index] = None


def GenerateMulLanes(registers, lane_count, address):
  """Allocate `lane_count` fresh double registers loaded from `address`."""
  lanes = MulLanes(address)
  for _ in range(lane_count):
    lanes.AddLane(registers.DoubleRegister())
  return lanes


def Generate3MulLanes(quad_register, registers, address):
  """Build 3 lanes where lanes 0 and 1 alias the halves of `quad_register`.

  The 3x3 kernel reuses `quad_register` as a scratch destination once lanes
  0 and 1 have been consumed (see Generate3x3LoadMultiplyAggregate).
  """
  lanes = MulLanes(address)
  lanes.AddLane(registers.Low(quad_register))
  lanes.AddLane(registers.High(quad_register))
  lanes.AddLane(registers.DoubleRegister())
  return lanes


def GenerateAndClearAggregators(emitter, registers, aggregator_count):
  """Prepare aggregators and emit aggregator clear code.

  Allocates `aggregator_count` quad registers. The first three are zeroed
  from an immediate; every later one is copied from the aggregator three
  positions earlier, which is already zero.
  NOTE(review): the copy pattern is presumably for instruction scheduling;
  the reason is not documented here.

  Returns:
    List of the allocated aggregator quad registers.
  """
  emitter.EmitComment('Clear aggregators.')
  aggregators = []
  for index in range(aggregator_count):
    aggregator = registers.QuadRegister()
    aggregators.append(aggregator)
    if index < 3:
      emitter.EmitVMov('i32', aggregator, emitter.ImmediateConstant(0))
    else:
      emitter.EmitVMov('i32', aggregator, aggregators[index - 3])
  emitter.EmitNewline()
  return aggregators


def _EmitLoopHead(emitter, comment, left_lanes, right_lanes, count):
  """Emit loop label 1, the counter decrement, lane loads and prefetches."""
  emitter.EmitComment(comment)
  emitter.EmitNumericalLabel(1)
  emitter.EmitNewline()
  emitter.EmitComment('Subtract counter.')
  # SUBS sets the flags tested by the BNE emitted in _EmitLoopTail.
  emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))
  emitter.EmitNewline()

  emitter.EmitVLoadA('1.8', left_lanes.lanes,
                     emitter.DereferenceIncrement(left_lanes.input_address, 64))
  emitter.EmitVLoadA(
      '1.8', right_lanes.lanes,
      emitter.DereferenceIncrement(right_lanes.input_address, 64))

  emitter.EmitPldOffset(left_lanes.input_address, emitter.ImmediateConstant(64))
  emitter.EmitPldOffset(right_lanes.input_address,
                        emitter.ImmediateConstant(64))


def _EmitLoopTail(emitter):
  """Emit the branch back to label 1 while the counter is non-zero."""
  emitter.EmitNewline()
  emitter.EmitComment('Loop break.')
  emitter.EmitBneBack(1)
  emitter.EmitNewline()


def GenerateNxMLoadMultiplyAggregate(emitter, registers, left_lanes,
                                     right_lanes, aggregators, count):
  """Emit inner loop for N rows x M cols multiplication.

  Every (row, col) lane pair gets its own scratch quad register that receives
  the u8 VMULL product. The product is then VPADAL'd into
  aggregators[row * cols + col].
  """
  _EmitLoopHead(emitter, 'General NxM lanes loop.', left_lanes, right_lanes,
                count)

  rows = len(left_lanes.lanes)
  cols = len(right_lanes.lanes)

  multiply_results = [registers.QuadRegister() for _ in range(rows * cols)]

  for row, left_lane in enumerate(left_lanes.lanes):
    for col, right_lane in enumerate(right_lanes.lanes):
      emitter.EmitVMull('u8', multiply_results[row * cols + col], right_lane,
                        left_lane)

  for index, product in enumerate(multiply_results):
    emitter.EmitVPadal('u16', aggregators[index], product)

  _EmitLoopTail(emitter)

  for register in multiply_results:
    registers.FreeRegister(register)


def Generate3x3LoadMultiplyAggregate(emitter, registers, left_lanes,
                                     right_lanes, aggregators, count,
                                     backup_register):
  """Emit inner loop for 3 rows x 3 cols multiplication (register trick).

  Only four scratch quad registers are allocated for the nine products. The
  ninth product is written into `backup_register`. As set up by
  GenerateMulNx8Mx8, the halves of that register hold left lanes 0 and 1
  (see Generate3MulLanes). Both lanes have had their last use in the
  iteration by then, and the next iteration reloads them.
  """
  _EmitLoopHead(emitter, '3x3 lanes loop.', left_lanes, right_lanes, count)

  left = left_lanes.lanes
  right = right_lanes.lanes
  temp = [registers.QuadRegister() for _ in range(4)]

  # First batch: products for aggregators 0-3.
  emitter.EmitVMull('u8', temp[0], left[0], right[0])
  emitter.EmitVMull('u8', temp[1], left[0], right[1])
  emitter.EmitVMull('u8', temp[2], left[0], right[2])
  emitter.EmitVMull('u8', temp[3], left[1], right[0])

  for index in range(4):
    emitter.EmitVPadal('u16', aggregators[index], temp[index])

  # Second batch: products for aggregators 4-8. The last product overwrites
  # backup_register (left lanes 0 and 1) after their final use.
  emitter.EmitVMull('u8', temp[0], left[1], right[1])
  emitter.EmitVMull('u8', temp[1], left[1], right[2])
  emitter.EmitVMull('u8', temp[2], left[2], right[0])
  emitter.EmitVMull('u8', temp[3], left[2], right[1])
  emitter.EmitVMull('u8', backup_register, left[2], right[2])

  for index in range(4):
    emitter.EmitVPadal('u16', aggregators[4 + index], temp[index])
  emitter.EmitVPadal('u16', aggregators[8], backup_register)

  _EmitLoopTail(emitter)

  for register in temp:
    registers.FreeRegister(register)


def ReadParams(emitter, registers, input_address, elements, min_reg):
  """Load `elements` 32-bit values from `input_address` into one register.

  Args:
    emitter: code emitter.
    registers: register allocator.
    input_address: register holding the source address; loaded through
      emitter.Dereference (not DereferenceIncrement).
    elements: number of values, 1-4. Values 1-2 use a double register and
      3-4 a quad register.
    min_reg: lowest quad register index to allocate from. It is doubled when
      a double register is allocated (presumably converting a quad index to
      the matching double index; confirm in neon_emitter).

  Returns:
    The loaded register.

  Raises:
    ConfigurationError: if `elements` is not 1-4.
  """
  if elements in (1, 2):
    register = registers.DoubleRegister(min_reg * 2)
  elif elements in (3, 4):
    register = registers.QuadRegister(min_reg)
  else:
    raise ConfigurationError('Unsupported elements no: %d' % elements)
  emitter.EmitVLoad('1.32', register, emitter.Dereference(input_address, 64))
  return register


def Duplicate(emitter, registers, rows, cols, min_register, values):
  """Populate a grid of registers duplicating provided values.

  Allocates one register per row, sized for `cols` 32-bit lanes (double for
  1-2, quad for 3-4). Each register is filled with 32-bit lane `row` of
  `values` via VDUP. For 3-4 rows `values` is addressed through its
  Low/High halves.

  Returns:
    List of the `rows` duplicated registers.

  Raises:
    ConfigurationError: if `cols` or `rows` is not 1-4.
  """
  if cols in (1, 2):
    allocate = registers.DoubleRegister
  elif cols in (3, 4):
    allocate = registers.QuadRegister
  else:
    raise ConfigurationError('Unsupported duplicate amount: %d' % cols)
  if rows not in (1, 2, 3, 4):
    raise ConfigurationError('Unsupported rows no: %d' % rows)

  duplicated = [allocate(min_register) for _ in range(rows)]

  for row, target in enumerate(duplicated):
    if rows <= 2:
      source = emitter.Lane(values, row)
    else:
      half = registers.Low(values) if row < 2 else registers.High(values)
      source = emitter.Lane(half, row % 2)
    emitter.EmitVDup('32', target, source)

  return duplicated


def DuplicateGeneralRegister(emitter, registers, cols, general_register,
                             min_register):
  """Broadcast a general-purpose register into a NEON register.

  The destination is a double register for 1-2 columns or a quad register
  for 3-4 columns.

  Raises:
    ConfigurationError: if `cols` is not 1-4.
  """
  if cols in (1, 2):
    duplicated = registers.DoubleRegister(min_register)
  elif cols in (3, 4):
    duplicated = registers.QuadRegister(min_register)
  else:
    raise ConfigurationError('Unsupported duplicate amount: %d' % cols)

  emitter.EmitVDup('32', duplicated, general_register)
  return duplicated


def ReduceAggregator(emitter, registers, aggregators, row, cols):
  """Pairwise-add one row's aggregators into a single result register.

  Expects each aggregator's Low half to already hold its horizontally
  reduced pair (see GenerateAggregatorReduceStore). The row's aggregators are
  aggregators[row * cols : row * cols + cols].

  Returns:
    The register holding the row: a double register for 1-2 columns, the
    row's first aggregator quad register for 3-4 columns.

  Raises:
    ConfigurationError: if `cols` is not 1-4.
  """
  base = row * cols
  if cols == 1:
    register = registers.Low(aggregators[base])
    emitter.EmitVPadd('u32', register, register, register)
    return register
  elif cols == 2:
    register = registers.Low(aggregators[base])
    emitter.EmitVPadd('u32', register, register,
                      registers.Low(aggregators[base + 1]))
    return register
  elif cols in (3, 4):
    register = aggregators[base]
    emitter.EmitVPadd('u32', registers.Low(register), registers.Low(register),
                      registers.Low(aggregators[base + 1]))
    # With 3 columns the third result is paired with itself; only its first
    # lane is stored (see StoreAggregator).
    last = aggregators[base + 2] if cols == 3 else aggregators[base + 3]
    emitter.EmitVPadd('u32', registers.High(register),
                      registers.Low(aggregators[base + 2]),
                      registers.Low(last))
    return register
  else:
    raise ConfigurationError('Unsupported columns no: %d' % cols)


def StoreAggregator(emitter, registers, aggregator, cols, result_address,
                    result_stride):
  """Store one reduced row of `cols` 32-bit results and advance by stride.

  For 3 columns the row is written in two stores: 2 values with increment,
  then 1 lane at stride offset. The caller therefore pre-adjusts the stride.

  Raises:
    ConfigurationError: if `cols` is not 1-4.
  """
  if cols == 1:
    emitter.EmitVStoreOffset('1.32', emitter.Lane(aggregator, 0),
                             emitter.Dereference(result_address, None),
                             result_stride)
  elif cols == 2:
    emitter.EmitVStoreOffset('1.32', aggregator,
                             emitter.Dereference(result_address, None),
                             result_stride)
  elif cols == 3:
    emitter.EmitVStore('1.32', registers.Low(aggregator),
                       emitter.DereferenceIncrement(result_address, None))
    emitter.EmitVStoreOffset('1.32', emitter.Lane(
        registers.High(aggregator),
        0), emitter.Dereference(result_address, None), result_stride)
    emitter.EmitNewline()
  elif cols == 4:
    emitter.EmitVStoreOffsetA(
        '1.32', [registers.Low(aggregator), registers.High(aggregator)],
        emitter.Dereference(result_address, None), result_stride)
  else:
    raise ConfigurationError('Unsupported columns no: %d' % cols)


def GenerateAggregatorReduceStore(emitter, registers, aggregators, result_type,
                                  lhs_add, rhs_add, left_lanes, right_lanes,
                                  results, results_stride):
  """Emit code that reduces 4 lane aggregators to 1 value, and stores them.

  Steps:
    1. Optionally load the lhs offsets and duplicate them per row.
    2. Optionally load the rhs offsets.
    3. For 'float' results, broadcast result_scale.
    4. Horizontally reduce every aggregator and reduce each row.
    5. Add the offsets as s32.
    6. For 'float' results, convert to f32 and multiply by result_scale.
    7. Store one row per results_stride.

  The offsets are read from the lhs/rhs input addresses as they stand after
  the inner loop (presumably stored right after the packed data; confirm
  with the packing code).
  """
  rows = len(left_lanes.lanes)
  cols = len(right_lanes.lanes)

  if lhs_add:
    left_offset = ReadParams(emitter, registers, left_lanes.input_address, rows,
                             4)
    left_offsets = Duplicate(emitter, registers, rows, cols, 4, left_offset)
  else:
    left_offsets = None

  if rhs_add:
    right_offset = ReadParams(emitter, registers, right_lanes.input_address,
                              cols, 4)
  else:
    right_offset = None

  # Compare by equality: `is` on strings depends on interning.
  is_float = result_type == 'float'

  if is_float:
    result_scale = DuplicateGeneralRegister(
        emitter, registers, cols, registers.MapParameter('result_scale'), 4)
  else:
    result_scale = None

  if cols == 3:
    emitter.EmitNewline()
    emitter.EmitComment('Change stride because storing in two ops.')
    # The first of the two stores post-increments the address by 2 x 32 bits.
    emitter.EmitSub(results_stride, results_stride,
                    emitter.ImmediateConstant(8))

  emitter.EmitNewline()
  emitter.EmitComment('Horizontal reduce aggregators.')
  for aggregator in aggregators:
    emitter.EmitVPadd('u32', registers.Low(aggregator),
                      registers.Low(aggregator), registers.High(aggregator))

  emitter.EmitNewline()
  emitter.EmitComment('Reduce rows.')
  row_temps = [
      ReduceAggregator(emitter, registers, aggregators, row, cols)
      for row in range(rows)
  ]

  if lhs_add:
    emitter.EmitNewline()
    emitter.EmitComment('Add lhs offsets to aggregated rows.')
    for row_temp, offset in zip(row_temps, left_offsets):
      emitter.EmitVAdd('s32', row_temp, row_temp, offset)

  if rhs_add:
    emitter.EmitNewline()
    emitter.EmitComment('Add rhs offset to aggregated rows.')
    for row_temp in row_temps:
      emitter.EmitVAdd('s32', row_temp, row_temp, right_offset)

  if is_float:
    emitter.EmitNewline()
    emitter.EmitComment('Convert to float. Multiply by result scale.')
    for row_temp in row_temps:
      emitter.EmitVCvt('f32', 's32', row_temp, row_temp)
    for row_temp in row_temps:
      emitter.EmitVMul('f32', row_temp, row_temp, result_scale)

  emitter.EmitNewline()
  emitter.EmitComment('Store reduced rows.')
  for row_temp in row_temps:
    StoreAggregator(emitter, registers, row_temp, cols, results, results_stride)


def BuildName(result_type, lhs_add, rhs_add, left, right):
  """Return the generated function name, e.g. 'mul_3x8_3x8_int32_lhsadd'."""
  name = 'mul_%dx8_%dx8_%s' % (left, right, result_type)
  if lhs_add:
    name += '_lhsadd'
  if rhs_add:
    name += '_rhsadd'
  return name


def CppResultType(result_type):
  """Map a result type name to the C++ result pointer type.

  Raises:
    ConfigurationError: if `result_type` is not 'int32' or 'float'.
  """
  if result_type == 'int32':
    return 'std::int32_t*'
  elif result_type == 'float':
    return 'float*'
  else:
    raise ConfigurationError('Unsupported result type: %s' % result_type)


def GetParameters(result_type):
  """Return the [C++ type, name] parameter list of the generated function."""
  params = [['const std::uint8_t*', 'lhs'], ['const std::uint8_t*', 'rhs'],
            ['std::int32_t', 'count'], [CppResultType(result_type), 'result'],
            ['std::int32_t', 'result_stride']]
  if result_type == 'float':
    params.append(['float', 'result_scale'])
  return params


def GenerateMulNx8Mx8(emitter, result_type, lhs_add, rhs_add, left_lanes_count,
                      right_lanes_count):
  """Emit the multiply code for given rows and cols counts.

  Blocks with fewer than 9 results use the general NxM kernel. The 3x3 block
  uses the dedicated register-saving kernel.

  Raises:
    ConfigurationError: if either count is outside 1-4, if the block has 9 or
      more results and is not 3x3, or if `result_type` is unsupported.
  """
  if left_lanes_count < 1 or left_lanes_count > 4:
    raise ConfigurationError('Left_lanes should be: 1, 2, 3 or 4.')
  if right_lanes_count < 1 or right_lanes_count > 4:
    raise ConfigurationError('Right_lanes should be: 1, 2, 3 or 4.')

  size = left_lanes_count * right_lanes_count
  use_3x3_kernel = left_lanes_count == 3 and right_lanes_count == 3
  if size >= 9 and not use_3x3_kernel:
    # The 3x3 kernel hard-codes three lanes per side and nine aggregators, so
    # it cannot serve e.g. 3x4 or 4x4. Reject these before emitting anything.
    raise ConfigurationError('Unsupported lanes configuration: %dx%d.' %
                             (left_lanes_count, right_lanes_count))

  emitter.EmitFunctionBeginA(
      BuildName(result_type, lhs_add, rhs_add, left_lanes_count,
                right_lanes_count), GetParameters(result_type), 'inline void')

  emitter.EmitAssert('count % 8 == 0')
  emitter.EmitAssert('count >= 8')
  emitter.EmitAsmBegin()

  registers = neon_emitter.NeonRegisters()

  count = registers.MapParameter('count')

  lhs = registers.MapParameter('lhs')
  rhs = registers.MapParameter('rhs')

  emitter.EmitPld(lhs)
  emitter.EmitPld(rhs)

  aggregators = GenerateAndClearAggregators(emitter, registers, size)

  if use_3x3_kernel:
    backup_register = registers.QuadRegister()
    left_lanes = Generate3MulLanes(backup_register, registers, lhs)
    right_lanes = GenerateMulLanes(registers, right_lanes_count, rhs)

    Generate3x3LoadMultiplyAggregate(emitter, registers, left_lanes,
                                     right_lanes, aggregators, count,
                                     backup_register)
  else:
    left_lanes = GenerateMulLanes(registers, left_lanes_count, lhs)
    right_lanes = GenerateMulLanes(registers, right_lanes_count, rhs)

    GenerateNxMLoadMultiplyAggregate(emitter, registers, left_lanes,
                                     right_lanes, aggregators, count)

  # Lane registers are dead after the loop; free them for the reduce phase.
  left_lanes.FreeRegisters(registers)
  right_lanes.FreeRegisters(registers)

  GenerateAggregatorReduceStore(emitter, registers, aggregators, result_type,
                                lhs_add, rhs_add, left_lanes, right_lanes,
                                registers.MapParameter('result'),
                                registers.MapParameter('result_stride'))

  emitter.EmitAsmEnd(registers.MappedParameters(), [],
                     registers.Clobbers() + ['cc', 'memory'])
  emitter.EmitFunctionEnd()


def GenerateFunctions(emitter, result_type, lhs_add, rhs_add):
  """Emit all 1-3 x 1-3 multiply variants plus the 1x4 variant."""
  for left_lanes in range(1, 4):
    for right_lanes in range(1, 4):
      GenerateMulNx8Mx8(emitter, result_type, lhs_add, rhs_add, left_lanes,
                        right_lanes)
      emitter.EmitNewline()

  GenerateMulNx8Mx8(emitter, result_type, lhs_add, rhs_add, 1, 4)
  emitter.EmitNewline()


if __name__ == '__main__':
  GenerateFunctions(neon_emitter.NeonEmitter(), 'int32', True, True)