"""Mul primitive used by the GEMM function.

The Mul primitive takes 1-3 zipped rows and 1-3 zipped columns and performs
matrix multiplication on those resulting in a small 1x1 to 3x3 block of results.

Nothing here computes the product directly: the functions drive an `emitter`
and a `neon_emitter.NeonRegisters` allocator to generate inline NEON assembly
wrapped in C++ functions, one per supported (rows, cols) combination.
"""

import neon_emitter


class Error(Exception):
  """Module level error."""


class ConfigurationError(Error):
  """Unsupported configuration."""


class MulLanes(object):
  """A group of lane registers that are all loaded from one input address.

  Attributes:
    input_address: the mapped parameter (from registers.MapParameter) holding
      the pointer the lanes are loaded from.
    lanes: list of lane registers, one per row (lhs) or column (rhs). Entries
      are set to None once FreeRegisters has released them.
  """

  def __init__(self, input_address):
    self.input_address = input_address
    self.lanes = []

  def AddLane(self, lane):
    """Append one lane register to the group."""
    self.lanes.append(lane)

  def FreeRegisters(self, registers):
    """Release every lane register back to `registers`.

    Slots are overwritten with None rather than removed, so len(self.lanes)
    still reports the row/column count after the registers are freed (the
    reduce/store code relies on that).
    """
    for index, lane in enumerate(self.lanes):
      registers.FreeRegister(lane)
      self.lanes[index] = None


def GenerateMulLanes(registers, lane_count, address):
  """Allocate `lane_count` fresh double registers to be loaded from `address`."""
  lanes = MulLanes(address)
  for unused_i in range(0, lane_count):
    lanes.AddLane(registers.DoubleRegister())
  return lanes


def Generate3MulLanes(quad_register, registers, address):
  """Build three lanes, the first two aliasing the halves of `quad_register`.

  Lane 0 is the low half and lane 1 the high half of `quad_register`; lane 2
  is a freshly allocated double register. The 3x3 kernel reuses
  `quad_register` as a multiply destination once lanes 0 and 1 are consumed.
  """
  lanes = MulLanes(address)
  lanes.AddLane(registers.Low(quad_register))
  lanes.AddLane(registers.High(quad_register))
  lanes.AddLane(registers.DoubleRegister())
  return lanes


def GenerateAndClearAggregators(emitter, registers, aggregator_count):
  """Prepare aggregators and emit aggregator clear code.

  The first three aggregators are zeroed with an immediate move; every later
  one is copied from the aggregator three positions earlier (already zero).

  Returns:
    List of `aggregator_count` quad registers.
  """
  emitter.EmitComment('Clear aggregators.')
  aggregators = []
  for i in range(0, aggregator_count):
    aggregator = registers.QuadRegister()
    aggregators.append(aggregator)
    if i < 3:
      emitter.EmitVMov('i32', aggregator, emitter.ImmediateConstant(0))
    else:
      emitter.EmitVMov('i32', aggregator, aggregators[i - 3])
  emitter.EmitNewline()
  return aggregators


def GenerateNxMLoadMultiplyAggregate(emitter, registers, left_lanes,
                                     right_lanes, aggregators, count):
  """Emit inner loop for N rows x M cols multiplication.

  Each iteration subtracts 8 from `count`, loads 8 bytes per lane from both
  inputs (advancing the addresses), multiplies every row lane with every
  column lane (u8 x u8 -> u16) and pairwise-accumulates the products into
  `aggregators[row * cols + col]` (u16 -> u32). The loop repeats while the
  subtraction result is non-zero.
  """
  emitter.EmitComment('General NxM lanes loop.')
  emitter.EmitNumericalLabel(1)
  emitter.EmitNewline()
  emitter.EmitComment('Subtract counter.')
  emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))
  emitter.EmitNewline()

  emitter.EmitVLoadA('1.8', left_lanes.lanes,
                     emitter.DereferenceIncrement(left_lanes.input_address, 64))
  emitter.EmitVLoadA(
      '1.8', right_lanes.lanes,
      emitter.DereferenceIncrement(right_lanes.input_address, 64))

  emitter.EmitPldOffset(left_lanes.input_address, emitter.ImmediateConstant(64))
  emitter.EmitPldOffset(right_lanes.input_address,
                        emitter.ImmediateConstant(64))

  rows = len(left_lanes.lanes)
  cols = len(right_lanes.lanes)

  multiply_results = [registers.QuadRegister() for _ in range(0, rows * cols)]

  for row in range(0, rows):
    for col in range(0, cols):
      index = row * cols + col
      emitter.EmitVMull('u8', multiply_results[index], right_lanes.lanes[col],
                        left_lanes.lanes[row])

  for aggregator, product in zip(aggregators, multiply_results):
    emitter.EmitVPadal('u16', aggregator, product)

  emitter.EmitNewline()
  emitter.EmitComment('Loop break.')
  emitter.EmitBneBack(1)
  emitter.EmitNewline()

  for register in multiply_results:
    registers.FreeRegister(register)


def Generate3x3LoadMultiplyAggregate(emitter, registers, left_lanes,
                                     right_lanes, aggregators, count,
                                     backup_register):
  """Emit inner loop for 3 rows x 3 cols multiplication (register trick).

  Only four temporary quad registers are allocated for the nine products.
  The ninth product goes into `backup_register`, whose low/high halves are
  left lanes 0 and 1 (see Generate3MulLanes). This is safe because it is
  written only after the last read of those two lanes in the iteration, and
  the lanes are reloaded at the top of the next iteration.
  """
  emitter.EmitComment('3x3 lanes loop.')
  emitter.EmitNumericalLabel(1)
  emitter.EmitNewline()
  emitter.EmitComment('Subtract counter.')
  emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))
  emitter.EmitNewline()

  emitter.EmitVLoadA('1.8', left_lanes.lanes,
                     emitter.DereferenceIncrement(left_lanes.input_address, 64))
  emitter.EmitVLoadA(
      '1.8', right_lanes.lanes,
      emitter.DereferenceIncrement(right_lanes.input_address, 64))

  emitter.EmitPldOffset(left_lanes.input_address, emitter.ImmediateConstant(64))
  emitter.EmitPldOffset(right_lanes.input_address,
                        emitter.ImmediateConstant(64))

  temp = []
  for unused_i in range(0, 4):
    temp.append(registers.QuadRegister())

  # First batch: products (0,0), (0,1), (0,2), (1,0).
  emitter.EmitVMull('u8', temp[0], left_lanes.lanes[0], right_lanes.lanes[0])
  emitter.EmitVMull('u8', temp[1], left_lanes.lanes[0], right_lanes.lanes[1])
  emitter.EmitVMull('u8', temp[2], left_lanes.lanes[0], right_lanes.lanes[2])
  emitter.EmitVMull('u8', temp[3], left_lanes.lanes[1], right_lanes.lanes[0])

  emitter.EmitVPadal('u16', aggregators[0], temp[0])
  emitter.EmitVPadal('u16', aggregators[1], temp[1])
  emitter.EmitVPadal('u16', aggregators[2], temp[2])
  emitter.EmitVPadal('u16', aggregators[3], temp[3])

  # Second batch: products (1,1), (1,2), (2,0), (2,1), (2,2). The last one
  # overwrites backup_register, i.e. left lanes 0 and 1, which are dead now.
  emitter.EmitVMull('u8', temp[0], left_lanes.lanes[1], right_lanes.lanes[1])
  emitter.EmitVMull('u8', temp[1], left_lanes.lanes[1], right_lanes.lanes[2])
  emitter.EmitVMull('u8', temp[2], left_lanes.lanes[2], right_lanes.lanes[0])
  emitter.EmitVMull('u8', temp[3], left_lanes.lanes[2], right_lanes.lanes[1])
  emitter.EmitVMull('u8', backup_register, left_lanes.lanes[2],
                    right_lanes.lanes[2])

  emitter.EmitVPadal('u16', aggregators[4], temp[0])
  emitter.EmitVPadal('u16', aggregators[5], temp[1])
  emitter.EmitVPadal('u16', aggregators[6], temp[2])
  emitter.EmitVPadal('u16', aggregators[7], temp[3])
  emitter.EmitVPadal('u16', aggregators[8], backup_register)

  emitter.EmitNewline()
  emitter.EmitComment('Loop break.')
  emitter.EmitBneBack(1)
  emitter.EmitNewline()

  for register in temp:
    registers.FreeRegister(register)


def ReadParams(emitter, registers, input_address, elements, min_reg):
  """Emit a load of `elements` 32-bit values from `input_address`.

  Args:
    emitter: code emitter.
    registers: register allocator.
    input_address: mapped parameter holding the pointer to read from.
    elements: 1, 2 (loaded into a double register) or 3 (quad register).
    min_reg: forwarded to QuadRegister (doubled for DoubleRegister);
      presumably a lower bound on the allocated register index -- confirm
      against neon_emitter.NeonRegisters.

  Returns:
    The register holding the loaded values.

  Raises:
    ConfigurationError: if `elements` is not 1, 2 or 3.
  """
  if elements == 1 or elements == 2:
    register = registers.DoubleRegister(min_reg * 2)
    emitter.EmitVLoad('1.32', register, emitter.Dereference(input_address, 64))
    return register
  elif elements == 3:
    register = registers.QuadRegister(min_reg)
    emitter.EmitVLoad('1.32', register, emitter.Dereference(input_address, 64))
    return register
  else:
    raise ConfigurationError('Unsupported elements no: %d' % elements)


def Duplicate(emitter, registers, rows, cols, min_register, values):
  """Populate a grid of registers duplicating provided values.

  Allocates one register per row (double for 1-2 cols, quad for 3 cols) and
  broadcasts lane `row` of `values` into it.

  Raises:
    ConfigurationError: if `cols` is not 1, 2 or 3.
  """
  duplicated = []
  if cols == 1 or cols == 2:
    for unused_i in range(0, rows):
      duplicated.append(registers.DoubleRegister(min_register))
  elif cols == 3:
    for unused_i in range(0, rows):
      duplicated.append(registers.QuadRegister(min_register))
  else:
    raise ConfigurationError('Unsupported duplicate amount: %d' % cols)

  if rows == 1:
    emitter.EmitVDup('32', duplicated[0], emitter.Lane(values, 0))
  elif rows == 2:
    emitter.EmitVDup('32', duplicated[0], emitter.Lane(values, 0))
    emitter.EmitVDup('32', duplicated[1], emitter.Lane(values, 1))
  elif rows == 3:
    # Three values live in a quad register (see ReadParams): lanes 0 and 1 of
    # the low half, lane 0 of the high half.
    emitter.EmitVDup('32', duplicated[0], emitter.Lane(
        registers.Low(values), 0))
    emitter.EmitVDup('32', duplicated[1], emitter.Lane(
        registers.Low(values), 1))
    emitter.EmitVDup('32', duplicated[2], emitter.Lane(
        registers.High(values), 0))

  return duplicated


def DuplicateGeneralRegister(emitter, registers, cols, general_register,
                             min_register):
  """Broadcast a 32-bit general register into a double (1-2 cols) or quad.

  Raises:
    ConfigurationError: if `cols` is not 1, 2 or 3.
  """
  if cols == 1 or cols == 2:
    duplicated = registers.DoubleRegister(min_register)
  elif cols == 3:
    duplicated = registers.QuadRegister(min_register)
  else:
    raise ConfigurationError('Unsupported duplicate amount: %d' % cols)

  emitter.EmitVDup('32', duplicated, general_register)
  return duplicated


def ReduceAggregator(emitter, registers, aggregators, row, cols):
  """Emit the final pairwise reduction of one row's aggregators.

  Expects each aggregator's low half to already hold two partial sums (see
  the horizontal reduce in GenerateAggregatorReduceStore).

  Returns:
    The register holding the row's results: lane 0 for 1 col, lanes 0-1 for
    2 cols, low lanes 0-1 plus high lane 0 of a quad for 3 cols.

  Raises:
    ConfigurationError: if `cols` is not 1, 2 or 3.
  """
  if cols == 1:
    register = registers.Low(aggregators[row])
    emitter.EmitVPadd('u32', register, register, register)
    return register
  elif cols == 2:
    register = registers.Low(aggregators[row * 2])
    emitter.EmitVPadd('u32', register, register,
                      registers.Low(aggregators[row * 2 + 1]))
    return register
  elif cols == 3:
    register = aggregators[row * 3]
    emitter.EmitVPadd('u32', registers.Low(register), registers.Low(register),
                      registers.Low(aggregators[row * 3 + 1]))
    emitter.EmitVPadd('u32', registers.High(register),
                      registers.Low(aggregators[row * 3 + 2]),
                      registers.Low(aggregators[row * 3 + 2]))
    return register
  else:
    raise ConfigurationError('Unsupported columns no: %d' % cols)


def StoreAggregator(emitter, registers, aggregator, cols, result_address,
                    result_stride):
  """Emit the store of one reduced row, then advance by `result_stride`.

  For 3 cols the row is written in two stores (two values, then one), the
  first of which already advances the address; the caller compensates by
  reducing the stride beforehand.

  Raises:
    ConfigurationError: if `cols` is not 1, 2 or 3.
  """
  if cols == 1:
    emitter.EmitVStoreOffset('1.32', emitter.Lane(aggregator, 0),
                             emitter.Dereference(result_address, None),
                             result_stride)
  elif cols == 2:
    emitter.EmitVStoreOffset('1.32', aggregator,
                             emitter.Dereference(result_address, None),
                             result_stride)
  elif cols == 3:
    emitter.EmitVStore('1.32', registers.Low(aggregator),
                       emitter.DereferenceIncrement(result_address, None))
    emitter.EmitVStoreOffset('1.32', emitter.Lane(
        registers.High(aggregator),
        0), emitter.Dereference(result_address, None), result_stride)
    emitter.EmitNewline()
  else:
    raise ConfigurationError('Unsupported columns no: %d' % cols)


def GenerateAggregatorReduceStore(emitter, registers, aggregators, result_type,
                                  lhs_add, rhs_add, left_lanes, right_lanes,
                                  results, results_stride):
  """Emit code that reduces 4 lane aggregators to 1 value, and stores them.

  Steps emitted, in order:
    1. Optionally load lhs/rhs offsets from the lane input addresses (which
       the inner loop has advanced; presumably the offsets follow the packed
       data -- confirm against the packing code).
    2. For float results, broadcast the `result_scale` parameter.
    3. Pairwise-reduce every aggregator, then reduce per row.
    4. Add offsets, convert to float and scale if requested.
    5. Store each row and step by `results_stride`.
  """
  rows = len(left_lanes.lanes)
  cols = len(right_lanes.lanes)
  # Compare by value: `is` on str only works by accident of interning.
  is_float = result_type == 'float'

  if lhs_add:
    left_offset = ReadParams(emitter, registers, left_lanes.input_address, rows,
                             4)
    left_offsets = Duplicate(emitter, registers, rows, cols, 4, left_offset)
  else:
    left_offsets = None

  if rhs_add:
    right_offset = ReadParams(emitter, registers, right_lanes.input_address,
                              cols, 4)
  else:
    right_offset = None

  if is_float:
    result_scale = DuplicateGeneralRegister(
        emitter, registers, cols, registers.MapParameter('result_scale'), 4)
  else:
    result_scale = None

  if cols == 3:
    # StoreAggregator writes 3-wide rows as an 8-byte post-incremented store
    # followed by a strided store, so the stride must account for those bytes.
    emitter.EmitNewline()
    emitter.EmitComment('Change stride because storing in two ops.')
    emitter.EmitSub(results_stride, results_stride,
                    emitter.ImmediateConstant(8))

  emitter.EmitNewline()
  emitter.EmitComment('Horizontal reduce aggregators.')
  for aggregator in aggregators:
    emitter.EmitVPadd('u32', registers.Low(aggregator),
                      registers.Low(aggregator), registers.High(aggregator))

  emitter.EmitNewline()
  emitter.EmitComment('Reduce rows.')
  row_temps = []
  for i in range(0, rows):
    row_temps.append(ReduceAggregator(emitter, registers, aggregators, i, cols))

  if lhs_add:
    emitter.EmitNewline()
    emitter.EmitComment('Add lhs offsets to aggregated rows.')
    for (row_temp, left_offset) in zip(row_temps, left_offsets):
      emitter.EmitVAdd('s32', row_temp, row_temp, left_offset)

  if rhs_add:
    emitter.EmitNewline()
    emitter.EmitComment('Add rhs offset to aggregated rows.')
    for row_temp in row_temps:
      emitter.EmitVAdd('s32', row_temp, row_temp, right_offset)

  if is_float:
    emitter.EmitNewline()
    emitter.EmitComment('Convert to float. Multiply by result scale.')
    for row_temp in row_temps:
      emitter.EmitVCvt('f32', 's32', row_temp, row_temp)
    for row_temp in row_temps:
      emitter.EmitVMul('f32', row_temp, row_temp, result_scale)

  emitter.EmitNewline()
  emitter.EmitComment('Store reduced rows.')
  for row_temp in row_temps:
    StoreAggregator(emitter, registers, row_temp, cols, results, results_stride)


def BuildName(result_type, lhs_add, rhs_add, left, right):
  """Return the generated function name, e.g. 'mul_2x8_3x8_int32_lhsadd'."""
  name = 'mul_%dx8_%dx8_%s' % (left, right, result_type)
  if lhs_add:
    name += '_lhsadd'
  if rhs_add:
    name += '_rhsadd'
  return name


def CppResultType(result_type):
  """Map a result type name to the C++ pointer type of the result buffer.

  Raises:
    ConfigurationError: if `result_type` is not 'int32' or 'float'.
  """
  # Compare by value: `is` on str only works by accident of interning.
  if result_type == 'int32':
    return 'std::int32_t*'
  elif result_type == 'float':
    return 'float*'
  else:
    raise ConfigurationError('Unsupported result type: %s' % result_type)


def GetParameters(result_type):
  """Return the generated function's parameters as [cpp_type, name] pairs.

  Float results take an extra trailing 'result_scale' parameter.
  """
  params = [['const std::uint8_t*', 'lhs'], ['const std::uint8_t*', 'rhs'],
            ['std::int32_t', 'count'], [CppResultType(result_type), 'result'],
            ['std::int32_t', 'result_stride']]
  # Compare by value: `is` on str only works by accident of interning.
  if result_type == 'float':
    params.append(['float', 'result_scale'])
  return params


def GenerateMulNx8Mx8(emitter, result_type, lhs_add, rhs_add, left_lanes_count,
                      right_lanes_count):
  """Emit the multiply code for given rows and cols counts.

  Emits one inline C++ function containing an asm block that multiplies
  `left_lanes_count` lhs rows by `right_lanes_count` rhs columns.

  Raises:
    ConfigurationError: if either lane count is outside 1..3, or
      `result_type` is unsupported.
  """
  if left_lanes_count < 1 or left_lanes_count > 3:
    raise ConfigurationError('Left_lanes should be: 1, 2 or 3.')
  if right_lanes_count < 1 or right_lanes_count > 3:
    raise ConfigurationError('Right_lanes should be: 1, 2 or 3.')

  emitter.EmitFunctionBeginA(
      BuildName(result_type, lhs_add, rhs_add, left_lanes_count,
                right_lanes_count), GetParameters(result_type), 'inline void')

  # The inner loops consume 8 elements per iteration and test for zero only
  # after subtracting, hence these preconditions.
  emitter.EmitAssert('count % 8 == 0')
  emitter.EmitAssert('count >= 8')
  emitter.EmitAsmBegin()

  registers = neon_emitter.NeonRegisters()

  count = registers.MapParameter('count')

  size = left_lanes_count * right_lanes_count

  aggregators = GenerateAndClearAggregators(emitter, registers, size)

  if size < 9:
    left_lanes = GenerateMulLanes(registers, left_lanes_count,
                                  registers.MapParameter('lhs'))
    right_lanes = GenerateMulLanes(registers, right_lanes_count,
                                   registers.MapParameter('rhs'))

    emitter.EmitPld(left_lanes.input_address)
    emitter.EmitPld(right_lanes.input_address)

    GenerateNxMLoadMultiplyAggregate(emitter, registers, left_lanes,
                                     right_lanes, aggregators, count)

  else:  # left == 3 and right == 3 (guaranteed by the checks above).
    # The 3x3 case uses the register-aliasing trick: left lanes 0 and 1 are
    # the halves of backup_register, which later receives the ninth product.
    backup_register = registers.QuadRegister()
    left_lanes = Generate3MulLanes(backup_register, registers,
                                   registers.MapParameter('lhs'))
    right_lanes = GenerateMulLanes(registers, right_lanes_count,
                                   registers.MapParameter('rhs'))

    emitter.EmitPld(left_lanes.input_address)
    emitter.EmitPld(right_lanes.input_address)

    Generate3x3LoadMultiplyAggregate(emitter, registers, left_lanes,
                                     right_lanes, aggregators, count,
                                     backup_register)

  left_lanes.FreeRegisters(registers)
  right_lanes.FreeRegisters(registers)

  GenerateAggregatorReduceStore(emitter, registers, aggregators, result_type,
                                lhs_add, rhs_add, left_lanes, right_lanes,
                                registers.MapParameter('result'),
                                registers.MapParameter('result_stride'))

  emitter.EmitAsmEnd(registers.MappedParameters(), [],
                     registers.Clobbers() + ['cc', 'memory'])
  emitter.EmitFunctionEnd()


def GenerateFunctions(emitter, result_type, lhs_add, rhs_add):
  """Emit the multiply function for every rows x cols combination in 1..3."""
  for left_lanes in range(1, 4):
    for right_lanes in range(1, 4):
      GenerateMulNx8Mx8(emitter, result_type, lhs_add, rhs_add, left_lanes,
                        right_lanes)
      emitter.EmitNewline()