Home | History | Annotate | Download | only in generators
      1 """Mul primitive used by the GEMM function.
      2 
The Mul primitive takes 1-3 zipped rows and 1-3 zipped columns (or 1 zipped row
and 4 zipped columns) and performs matrix multiplication on those, resulting in
a small 1x1 to 3x3 (or 1x4) block of results.
      5 """
      6 
      7 import neon_emitter
      8 
      9 
class Error(Exception):
  """Module level error; parent class of the exceptions defined in this file."""
     12 
     13 
class ConfigurationError(Error):
  """Unsupported configuration.

  Raised by the generators in this file when asked for a lane, element or
  column count, or a result type, that they do not handle.
  """
     16 
     17 
     18 class MulLanes(object):
     19 
     20   def __init__(self, input_address):
     21     self.input_address = input_address
     22     self.lanes = []
     23 
     24   def AddLane(self, lane):
     25     self.lanes.append(lane)
     26 
     27   def FreeRegisters(self, registers):
     28     for i in range(0, len(self.lanes)):
     29       registers.FreeRegister(self.lanes[i])
     30       self.lanes[i] = None
     31 
     32 
     33 def GenerateMulLanes(registers, lane_count, address):
     34   lanes = MulLanes(address)
     35   for unused_i in range(0, lane_count):
     36     lanes.AddLane(registers.DoubleRegister())
     37   return lanes
     38 
     39 
     40 def Generate3MulLanes(quad_register, registers, address):
     41   lanes = MulLanes(address)
     42   lanes.AddLane(registers.Low(quad_register))
     43   lanes.AddLane(registers.High(quad_register))
     44   lanes.AddLane(registers.DoubleRegister())
     45   return lanes
     46 
     47 
     48 def GenerateAndClearAggregators(emitter, registers, aggregator_count):
     49   """Prepare aggregators and emit aggregator clear code."""
     50   emitter.EmitComment('Clear aggregators.')
     51   aggregators = []
     52   for i in range(0, aggregator_count):
     53     aggregator = registers.QuadRegister()
     54     aggregators.append(aggregator)
     55     if i < 3:
     56       emitter.EmitVMov('i32', aggregator, emitter.ImmediateConstant(0))
     57     else:
     58       emitter.EmitVMov('i32', aggregator, aggregators[i - 3])
     59   emitter.EmitNewline()
     60   return aggregators
     61 
     62 
     63 def GenerateNxMLoadMultiplyAggregate(emitter, registers, left_lanes,
     64                                      right_lanes, aggregators, count):
     65   """Emit inner loop for N rows x M cols multiplication."""
     66   emitter.EmitComment('General NxM lanes loop.')
     67   emitter.EmitNumericalLabel(1)
     68   emitter.EmitNewline()
     69   emitter.EmitComment('Subtract counter.')
     70   emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))
     71   emitter.EmitNewline()
     72 
     73   emitter.EmitVLoadA('1.8', left_lanes.lanes,
     74                      emitter.DereferenceIncrement(left_lanes.input_address, 64))
     75   emitter.EmitVLoadA(
     76       '1.8', right_lanes.lanes,
     77       emitter.DereferenceIncrement(right_lanes.input_address, 64))
     78 
     79   emitter.EmitPldOffset(left_lanes.input_address, emitter.ImmediateConstant(64))
     80   emitter.EmitPldOffset(right_lanes.input_address,
     81                         emitter.ImmediateConstant(64))
     82 
     83   rows = len(left_lanes.lanes)
     84   cols = len(right_lanes.lanes)
     85 
     86   multiply_results = []
     87   for i in range(0, rows * cols):
     88     multiply_results.append(registers.QuadRegister())
     89 
     90   for row in range(0, rows):
     91     for col in range(0, cols):
     92       index = row * cols + col
     93       emitter.EmitVMull('u8', multiply_results[index], right_lanes.lanes[col],
     94                         left_lanes.lanes[row])
     95 
     96   for i in range(0, rows * cols):
     97     emitter.EmitVPadal('u16', aggregators[i], multiply_results[i])
     98 
     99   emitter.EmitNewline()
    100   emitter.EmitComment('Loop break.')
    101   emitter.EmitBneBack(1)
    102   emitter.EmitNewline()
    103 
    104   for register in multiply_results:
    105     registers.FreeRegister(register)
    106 
    107 
    108 def Generate3x3LoadMultiplyAggregate(emitter, registers, left_lanes,
    109                                      right_lanes, aggregators, count,
    110                                      backup_register):
    111   """Emit inner loop for 3 rows x 3 cols multiplication (register trick)."""
    112   emitter.EmitComment('3x3 lanes loop.')
    113   emitter.EmitNumericalLabel(1)
    114   emitter.EmitNewline()
    115   emitter.EmitComment('Subtract counter.')
    116   emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))
    117   emitter.EmitNewline()
    118 
    119   emitter.EmitVLoadA('1.8', left_lanes.lanes,
    120                      emitter.DereferenceIncrement(left_lanes.input_address, 64))
    121   emitter.EmitVLoadA(
    122       '1.8', right_lanes.lanes,
    123       emitter.DereferenceIncrement(right_lanes.input_address, 64))
    124 
    125   emitter.EmitPldOffset(left_lanes.input_address, emitter.ImmediateConstant(64))
    126   emitter.EmitPldOffset(right_lanes.input_address,
    127                         emitter.ImmediateConstant(64))
    128 
    129   temp = []
    130   for unused_i in range(0, 4):
    131     temp.append(registers.QuadRegister())
    132 
    133   emitter.EmitVMull('u8', temp[0], left_lanes.lanes[0], right_lanes.lanes[0])
    134   emitter.EmitVMull('u8', temp[1], left_lanes.lanes[0], right_lanes.lanes[1])
    135   emitter.EmitVMull('u8', temp[2], left_lanes.lanes[0], right_lanes.lanes[2])
    136   emitter.EmitVMull('u8', temp[3], left_lanes.lanes[1], right_lanes.lanes[0])
    137 
    138   emitter.EmitVPadal('u16', aggregators[0], temp[0])
    139   emitter.EmitVPadal('u16', aggregators[1], temp[1])
    140   emitter.EmitVPadal('u16', aggregators[2], temp[2])
    141   emitter.EmitVPadal('u16', aggregators[3], temp[3])
    142 
    143   emitter.EmitVMull('u8', temp[0], left_lanes.lanes[1], right_lanes.lanes[1])
    144   emitter.EmitVMull('u8', temp[1], left_lanes.lanes[1], right_lanes.lanes[2])
    145   emitter.EmitVMull('u8', temp[2], left_lanes.lanes[2], right_lanes.lanes[0])
    146   emitter.EmitVMull('u8', temp[3], left_lanes.lanes[2], right_lanes.lanes[1])
    147   emitter.EmitVMull('u8', backup_register, left_lanes.lanes[2],
    148                     right_lanes.lanes[2])
    149 
    150   emitter.EmitVPadal('u16', aggregators[4], temp[0])
    151   emitter.EmitVPadal('u16', aggregators[5], temp[1])
    152   emitter.EmitVPadal('u16', aggregators[6], temp[2])
    153   emitter.EmitVPadal('u16', aggregators[7], temp[3])
    154   emitter.EmitVPadal('u16', aggregators[8], backup_register)
    155 
    156   emitter.EmitNewline()
    157   emitter.EmitComment('Loop break.')
    158   emitter.EmitBneBack(1)
    159   emitter.EmitNewline()
    160 
    161   for register in temp:
    162     registers.FreeRegister(register)
    163 
    164 
    165 def ReadParams(emitter, registers, input_address, elements, min_reg):
    166   if elements == 1 or elements == 2:
    167     register = registers.DoubleRegister(min_reg * 2)
    168     emitter.EmitVLoad('1.32', register, emitter.Dereference(input_address, 64))
    169     return register
    170   elif elements == 3 or elements == 4:
    171     register = registers.QuadRegister(min_reg)
    172     emitter.EmitVLoad('1.32', register, emitter.Dereference(input_address, 64))
    173     return register
    174   else:
    175     raise ConfigurationError('Unsupported elements no: %d' % elements)
    176 
    177 
    178 def Duplicate(emitter, registers, rows, cols, min_register, values):
    179   """Populate a grid of registers duplicating provided values."""
    180   duplicated = []
    181   if cols == 1 or cols == 2:
    182     for unused_i in range(0, rows):
    183       duplicated.append(registers.DoubleRegister(min_register))
    184   elif cols == 3 or cols == 4:
    185     for unused_i in range(0, rows):
    186       duplicated.append(registers.QuadRegister(min_register))
    187   else:
    188     raise ConfigurationError('Unsupported duplicate amount: %d' % cols)
    189 
    190   if rows == 1:
    191     emitter.EmitVDup('32', duplicated[0], emitter.Lane(values, 0))
    192   elif rows == 2:
    193     emitter.EmitVDup('32', duplicated[0], emitter.Lane(values, 0))
    194     emitter.EmitVDup('32', duplicated[1], emitter.Lane(values, 1))
    195   elif rows == 3:
    196     emitter.EmitVDup('32', duplicated[0], emitter.Lane(
    197         registers.Low(values), 0))
    198     emitter.EmitVDup('32', duplicated[1], emitter.Lane(
    199         registers.Low(values), 1))
    200     emitter.EmitVDup('32', duplicated[2], emitter.Lane(
    201         registers.High(values), 0))
    202   elif rows == 4:
    203     emitter.EmitVDup('32', duplicated[0], emitter.Lane(
    204         registers.Low(values), 0))
    205     emitter.EmitVDup('32', duplicated[1], emitter.Lane(
    206         registers.Low(values), 1))
    207     emitter.EmitVDup('32', duplicated[2], emitter.Lane(
    208         registers.High(values), 0))
    209     emitter.EmitVDup('32', duplicated[3], emitter.Lane(
    210         registers.High(values), 1))
    211 
    212   return duplicated
    213 
    214 
    215 def DuplicateGeneralRegister(emitter, registers, cols, general_register,
    216                              min_register):
    217   if cols == 1 or cols == 2:
    218     duplicated = registers.DoubleRegister(min_register)
    219   elif cols == 3 or cols == 4:
    220     duplicated = registers.QuadRegister(min_register)
    221   else:
    222     raise ConfigurationError('Unsupported duplicate amount: %d' % cols)
    223 
    224   emitter.EmitVDup('32', duplicated, general_register)
    225   return duplicated
    226 
    227 
    228 def ReduceAggregator(emitter, registers, aggregators, row, cols):
    229   if cols == 1:
    230     register = registers.Low(aggregators[row])
    231     emitter.EmitVPadd('u32', register, register, register)
    232     return register
    233   elif cols == 2:
    234     register = registers.Low(aggregators[row * 2])
    235     emitter.EmitVPadd('u32', register, register,
    236                       registers.Low(aggregators[row * 2 + 1]))
    237     return register
    238   elif cols == 3:
    239     register = aggregators[row * 3]
    240     emitter.EmitVPadd('u32', registers.Low(register), registers.Low(register),
    241                       registers.Low(aggregators[row * 3 + 1]))
    242     emitter.EmitVPadd('u32', registers.High(register),
    243                       registers.Low(aggregators[row * 3 + 2]),
    244                       registers.Low(aggregators[row * 3 + 2]))
    245     return register
    246   elif cols == 4:
    247     register = aggregators[row * 3]
    248     emitter.EmitVPadd('u32', registers.Low(register), registers.Low(register),
    249                       registers.Low(aggregators[row * 3 + 1]))
    250     emitter.EmitVPadd('u32', registers.High(register),
    251                       registers.Low(aggregators[row * 3 + 2]),
    252                       registers.Low(aggregators[row * 3 + 3]))
    253     return register
    254   else:
    255     raise ConfigurationError('Unsupported columns no: %d' % cols)
    256 
    257 
    258 def StoreAggregator(emitter, registers, aggregator, cols, result_address,
    259                     result_stride):
    260   if cols == 1:
    261     emitter.EmitVStoreOffset('1.32', emitter.Lane(aggregator, 0),
    262                              emitter.Dereference(result_address, None),
    263                              result_stride)
    264   elif cols == 2:
    265     emitter.EmitVStoreOffset('1.32', aggregator,
    266                              emitter.Dereference(result_address, None),
    267                              result_stride)
    268   elif cols == 3:
    269     emitter.EmitVStore('1.32', registers.Low(aggregator),
    270                        emitter.DereferenceIncrement(result_address, None))
    271     emitter.EmitVStoreOffset('1.32', emitter.Lane(
    272         registers.High(aggregator),
    273         0), emitter.Dereference(result_address, None), result_stride)
    274     emitter.EmitNewline()
    275   elif cols == 4:
    276     emitter.EmitVStoreOffsetA(
    277         '1.32', [registers.Low(aggregator), registers.High(aggregator)],
    278         emitter.Dereference(result_address, None), result_stride)
    279   else:
    280     raise ConfigurationError('Unsupported columns no: %d' % cols)
    281 
    282 
    283 def GenerateAggregatorReduceStore(emitter, registers, aggregators, result_type,
    284                                   lhs_add, rhs_add, left_lanes, right_lanes,
    285                                   results, results_stride):
    286   """Emit code that reduces 4 lane aggregators to 1 value, and stores them."""
    287   rows = len(left_lanes.lanes)
    288   cols = len(right_lanes.lanes)
    289 
    290   if lhs_add:
    291     left_offset = ReadParams(emitter, registers, left_lanes.input_address, rows,
    292                              4)
    293     left_offsets = Duplicate(emitter, registers, rows, cols, 4, left_offset)
    294   else:
    295     left_offsets = None
    296 
    297   if rhs_add:
    298     right_offset = ReadParams(emitter, registers, right_lanes.input_address,
    299                               cols, 4)
    300   else:
    301     right_offset = None
    302 
    303   if result_type is 'float':
    304     result_scale = DuplicateGeneralRegister(
    305         emitter, registers, cols, registers.MapParameter('result_scale'), 4)
    306   else:
    307     result_scale = None
    308 
    309   if cols == 3:
    310     emitter.EmitNewline()
    311     emitter.EmitComment('Change stride because storing in two ops.')
    312     emitter.EmitSub(results_stride, results_stride,
    313                     emitter.ImmediateConstant(8))
    314 
    315   emitter.EmitNewline()
    316   emitter.EmitComment('Horizontal reduce aggregators.')
    317   for aggregator in aggregators:
    318     emitter.EmitVPadd('u32', registers.Low(aggregator),
    319                       registers.Low(aggregator), registers.High(aggregator))
    320 
    321   emitter.EmitNewline()
    322   emitter.EmitComment('Reduce rows.')
    323   row_temps = []
    324   for i in range(0, rows):
    325     row_temps.append(ReduceAggregator(emitter, registers, aggregators, i, cols))
    326 
    327   if lhs_add:
    328     emitter.EmitNewline()
    329     emitter.EmitComment('Add lhs offsets to aggregated rows.')
    330     for (row_temp, left_offset) in zip(row_temps, left_offsets):
    331       emitter.EmitVAdd('s32', row_temp, row_temp, left_offset)
    332 
    333   if rhs_add:
    334     emitter.EmitNewline()
    335     emitter.EmitComment('Add rhs offset to aggregated rows.')
    336     for row_temp in row_temps:
    337       emitter.EmitVAdd('s32', row_temp, row_temp, right_offset)
    338 
    339   if result_type is 'float':
    340     emitter.EmitNewline()
    341     emitter.EmitComment('Convert to float. Multiply by result scale.')
    342     for row_temp in row_temps:
    343       emitter.EmitVCvt('f32', 's32', row_temp, row_temp)
    344     for row_temp in row_temps:
    345       emitter.EmitVMul('f32', row_temp, row_temp, result_scale)
    346 
    347   emitter.EmitNewline()
    348   emitter.EmitComment('Store reduced rows.')
    349   for row_temp in row_temps:
    350     StoreAggregator(emitter, registers, row_temp, cols, results, results_stride)
    351 
    352 
    353 def BuildName(result_type, lhs_add, rhs_add, left, right):
    354   name = 'mul_%dx8_%dx8_%s' % (left, right, result_type)
    355   if lhs_add:
    356     name += '_lhsadd'
    357   if rhs_add:
    358     name += '_rhsadd'
    359   return name
    360 
    361 
    362 def CppResultType(result_type):
    363   if result_type is 'int32':
    364     return 'std::int32_t*'
    365   elif result_type is 'float':
    366     return 'float*'
    367   else:
    368     raise ConfigurationError('Unsupported result type: %s' % result_type)
    369 
    370 
    371 def GetParameters(result_type):
    372   params = [['const std::uint8_t*', 'lhs'], ['const std::uint8_t*', 'rhs'],
    373             ['std::int32_t', 'count'], [CppResultType(result_type), 'result'],
    374             ['std::int32_t', 'result_stride']]
    375   if result_type is 'float':
    376     params.append(['float', 'result_scale'])
    377   return params
    378 
    379 
    380 def GenerateMulNx8Mx8(emitter, result_type, lhs_add, rhs_add, left_lanes_count,
    381                       right_lanes_count):
    382   """Emit the multiply code for given rows and cols counts."""
    383   if left_lanes_count < 1 or left_lanes_count > 4:
    384     raise ConfigurationError('Left_lanes should be: 1, 2, 3 or 4.')
    385   if right_lanes_count < 1 or right_lanes_count > 4:
    386     raise ConfigurationError('Right_lanes should be: 1, 2, 3 or 4.')
    387 
    388   emitter.EmitFunctionBeginA(
    389       BuildName(result_type, lhs_add, rhs_add, left_lanes_count,
    390                 right_lanes_count), GetParameters(result_type), 'inline void')
    391 
    392   emitter.EmitAssert('count % 8 == 0')
    393   emitter.EmitAssert('count >= 8')
    394   emitter.EmitAsmBegin()
    395 
    396   registers = neon_emitter.NeonRegisters()
    397 
    398   count = registers.MapParameter('count')
    399 
    400   size = left_lanes_count * right_lanes_count
    401 
    402   lhs = registers.MapParameter('lhs')
    403   rhs = registers.MapParameter('rhs')
    404 
    405   emitter.EmitPld(lhs)
    406   emitter.EmitPld(rhs)
    407 
    408   aggregators = GenerateAndClearAggregators(emitter, registers, size)
    409 
    410   if size < 9:
    411     left_lanes = GenerateMulLanes(registers, left_lanes_count, lhs)
    412     right_lanes = GenerateMulLanes(registers, right_lanes_count, rhs)
    413 
    414     GenerateNxMLoadMultiplyAggregate(emitter, registers, left_lanes,
    415                                      right_lanes, aggregators, count)
    416 
    417   else:  # left == 3 and right == 3
    418     backup_register = registers.QuadRegister()
    419     left_lanes = Generate3MulLanes(backup_register, registers, lhs)
    420     right_lanes = GenerateMulLanes(registers, right_lanes_count, rhs)
    421 
    422     Generate3x3LoadMultiplyAggregate(emitter, registers, left_lanes,
    423                                      right_lanes, aggregators, count,
    424                                      backup_register)
    425   left_lanes.FreeRegisters(registers)
    426   right_lanes.FreeRegisters(registers)
    427 
    428   GenerateAggregatorReduceStore(emitter, registers, aggregators, result_type,
    429                                 lhs_add, rhs_add, left_lanes, right_lanes,
    430                                 registers.MapParameter('result'),
    431                                 registers.MapParameter('result_stride'))
    432 
    433   emitter.EmitAsmEnd(registers.MappedParameters(), [],
    434                      registers.Clobbers() + ['cc', 'memory'])
    435   emitter.EmitFunctionEnd()
    436 
    437 
    438 def GenerateFunctions(emitter, result_type, lhs_add, rhs_add):
    439   for left_lanes in range(1, 4):
    440     for right_lanes in range(1, 4):
    441       GenerateMulNx8Mx8(emitter, result_type, lhs_add, rhs_add, left_lanes,
    442                         right_lanes)
    443       emitter.EmitNewline()
    444 
    445   GenerateMulNx8Mx8(emitter, result_type, lhs_add, rhs_add, 1, 4)
    446   emitter.EmitNewline()
    447 
    448 
    449 if __name__ == '__main__':
    450   GenerateFunctions(neon_emitter.NeonEmitter(), 'int32', True, True)
    451