Home | History | Annotate | Download | only in generators
      1 """Mul primitive used by the GEMM function.
      2 
      3 The Mul primitive takes 1-3 zipped rows and 1-3 zipped columns and performs
      4 matrix multiplication on those resulting in a small 1x1 to 3x3 block of results.
      5 """
      6 
      7 import neon_emitter
      8 
      9 
     10 class Error(Exception):
     11   """Module level error."""
     12 
     13 
     14 class ConfigurationError(Error):
     15   """Unsupported configuration."""
     16 
     17 
     18 class MulLanes(object):
     19 
     20   def __init__(self, input_address):
     21     self.input_address = input_address
     22     self.lanes = []
     23 
     24   def AddLane(self, lane):
     25     self.lanes.append(lane)
     26 
     27   def FreeRegisters(self, registers):
     28     for i in range(0, len(self.lanes)):
     29       registers.FreeRegister(self.lanes[i])
     30       self.lanes[i] = None
     31 
     32 
     33 def GenerateMulLanes(registers, lane_count, address):
     34   lanes = MulLanes(address)
     35   for unused_i in range(0, lane_count):
     36     lanes.AddLane(registers.DoubleRegister())
     37   return lanes
     38 
     39 
     40 def Generate3MulLanes(quad_register, registers, address):
     41   lanes = MulLanes(address)
     42   lanes.AddLane(registers.Low(quad_register))
     43   lanes.AddLane(registers.High(quad_register))
     44   lanes.AddLane(registers.DoubleRegister())
     45   return lanes
     46 
     47 
     48 def GenerateAndClearAggregators(emitter, registers, aggregator_count):
     49   """Prepare aggregators and emit aggregator clear code."""
     50   emitter.EmitComment('Clear aggregators.')
     51   aggregators = []
     52   for i in range(0, aggregator_count):
     53     aggregator = registers.QuadRegister()
     54     aggregators.append(aggregator)
     55     if i < 3:
     56       emitter.EmitVMov('i32', aggregator, emitter.ImmediateConstant(0))
     57     else:
     58       emitter.EmitVMov('i32', aggregator, aggregators[i - 3])
     59   emitter.EmitNewline()
     60   return aggregators
     61 
     62 
     63 def GenerateNxMLoadMultiplyAggregate(emitter, registers, left_lanes,
     64                                      right_lanes, aggregators, count):
     65   """Emit inner loop for N rows x M cols multiplication."""
     66   emitter.EmitComment('General NxM lanes loop.')
     67   emitter.EmitNumericalLabel(1)
     68   emitter.EmitNewline()
     69   emitter.EmitComment('Subtract counter.')
     70   emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))
     71   emitter.EmitNewline()
     72 
     73   emitter.EmitVLoadA('1.8', left_lanes.lanes,
     74                      emitter.DereferenceIncrement(left_lanes.input_address, 64))
     75   emitter.EmitVLoadA(
     76       '1.8', right_lanes.lanes,
     77       emitter.DereferenceIncrement(right_lanes.input_address, 64))
     78 
     79   emitter.EmitPldOffset(left_lanes.input_address, emitter.ImmediateConstant(64))
     80   emitter.EmitPldOffset(right_lanes.input_address,
     81                         emitter.ImmediateConstant(64))
     82 
     83   rows = len(left_lanes.lanes)
     84   cols = len(right_lanes.lanes)
     85 
     86   multiply_results = []
     87   for i in range(0, rows * cols):
     88     multiply_results.append(registers.QuadRegister())
     89 
     90   for row in range(0, rows):
     91     for col in range(0, cols):
     92       index = row * cols + col
     93       emitter.EmitVMull('u8', multiply_results[index], right_lanes.lanes[col],
     94                         left_lanes.lanes[row])
     95 
     96   for i in range(0, rows * cols):
     97     emitter.EmitVPadal('u16', aggregators[i], multiply_results[i])
     98 
     99   emitter.EmitNewline()
    100   emitter.EmitComment('Loop break.')
    101   emitter.EmitBneBack(1)
    102   emitter.EmitNewline()
    103 
    104   for register in multiply_results:
    105     registers.FreeRegister(register)
    106 
    107 
    108 def Generate3x3LoadMultiplyAggregate(emitter, registers, left_lanes,
    109                                      right_lanes, aggregators, count,
    110                                      backup_register):
    111   """Emit inner loop for 3 rows x 3 cols multiplication (register trick)."""
    112   emitter.EmitComment('3x3 lanes loop.')
    113   emitter.EmitNumericalLabel(1)
    114   emitter.EmitNewline()
    115   emitter.EmitComment('Subtract counter.')
    116   emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))
    117   emitter.EmitNewline()
    118 
    119   emitter.EmitVLoadA('1.8', left_lanes.lanes,
    120                      emitter.DereferenceIncrement(left_lanes.input_address, 64))
    121   emitter.EmitVLoadA(
    122       '1.8', right_lanes.lanes,
    123       emitter.DereferenceIncrement(right_lanes.input_address, 64))
    124 
    125   emitter.EmitPldOffset(left_lanes.input_address, emitter.ImmediateConstant(64))
    126   emitter.EmitPldOffset(right_lanes.input_address,
    127                         emitter.ImmediateConstant(64))
    128 
    129   temp = []
    130   for unused_i in range(0, 4):
    131     temp.append(registers.QuadRegister())
    132 
    133   emitter.EmitVMull('u8', temp[0], left_lanes.lanes[0], right_lanes.lanes[0])
    134   emitter.EmitVMull('u8', temp[1], left_lanes.lanes[0], right_lanes.lanes[1])
    135   emitter.EmitVMull('u8', temp[2], left_lanes.lanes[0], right_lanes.lanes[2])
    136   emitter.EmitVMull('u8', temp[3], left_lanes.lanes[1], right_lanes.lanes[0])
    137 
    138   emitter.EmitVPadal('u16', aggregators[0], temp[0])
    139   emitter.EmitVPadal('u16', aggregators[1], temp[1])
    140   emitter.EmitVPadal('u16', aggregators[2], temp[2])
    141   emitter.EmitVPadal('u16', aggregators[3], temp[3])
    142 
    143   emitter.EmitVMull('u8', temp[0], left_lanes.lanes[1], right_lanes.lanes[1])
    144   emitter.EmitVMull('u8', temp[1], left_lanes.lanes[1], right_lanes.lanes[2])
    145   emitter.EmitVMull('u8', temp[2], left_lanes.lanes[2], right_lanes.lanes[0])
    146   emitter.EmitVMull('u8', temp[3], left_lanes.lanes[2], right_lanes.lanes[1])
    147   emitter.EmitVMull('u8', backup_register, left_lanes.lanes[2],
    148                     right_lanes.lanes[2])
    149 
    150   emitter.EmitVPadal('u16', aggregators[4], temp[0])
    151   emitter.EmitVPadal('u16', aggregators[5], temp[1])
    152   emitter.EmitVPadal('u16', aggregators[6], temp[2])
    153   emitter.EmitVPadal('u16', aggregators[7], temp[3])
    154   emitter.EmitVPadal('u16', aggregators[8], backup_register)
    155 
    156   emitter.EmitNewline()
    157   emitter.EmitComment('Loop break.')
    158   emitter.EmitBneBack(1)
    159   emitter.EmitNewline()
    160 
    161   for register in temp:
    162     registers.FreeRegister(register)
    163 
    164 
    165 def ReadParams(emitter, registers, input_address, elements, min_reg):
    166   if elements == 1 or elements == 2:
    167     register = registers.DoubleRegister(min_reg * 2)
    168     emitter.EmitVLoad('1.32', register, emitter.Dereference(input_address, 64))
    169     return register
    170   elif elements == 3:
    171     register = registers.QuadRegister(min_reg)
    172     emitter.EmitVLoad('1.32', register, emitter.Dereference(input_address, 64))
    173     return register
    174   else:
    175     raise ConfigurationError('Unsupported elements no: %d' % elements)
    176 
    177 
    178 def Duplicate(emitter, registers, rows, cols, min_register, values):
    179   """Populate a grid of registers duplicating provided values."""
    180   duplicated = []
    181   if cols == 1 or cols == 2:
    182     for unused_i in range(0, rows):
    183       duplicated.append(registers.DoubleRegister(min_register))
    184   elif cols == 3:
    185     for unused_i in range(0, rows):
    186       duplicated.append(registers.QuadRegister(min_register))
    187   else:
    188     raise ConfigurationError('Unsupported duplicate amount: %d' % cols)
    189 
    190   if rows == 1:
    191     emitter.EmitVDup('32', duplicated[0], emitter.Lane(values, 0))
    192   elif rows == 2:
    193     emitter.EmitVDup('32', duplicated[0], emitter.Lane(values, 0))
    194     emitter.EmitVDup('32', duplicated[1], emitter.Lane(values, 1))
    195   elif rows == 3:
    196     emitter.EmitVDup('32', duplicated[0], emitter.Lane(
    197         registers.Low(values), 0))
    198     emitter.EmitVDup('32', duplicated[1], emitter.Lane(
    199         registers.Low(values), 1))
    200     emitter.EmitVDup('32', duplicated[2], emitter.Lane(
    201         registers.High(values), 0))
    202 
    203   return duplicated
    204 
    205 
    206 def DuplicateGeneralRegister(emitter, registers, cols, general_register,
    207                              min_register):
    208   if cols == 1 or cols == 2:
    209     duplicated = registers.DoubleRegister(min_register)
    210   elif cols == 3:
    211     duplicated = registers.QuadRegister(min_register)
    212   else:
    213     raise ConfigurationError('Unsupported duplicate amount: %d' % cols)
    214 
    215   emitter.EmitVDup('32', duplicated, general_register)
    216   return duplicated
    217 
    218 
    219 def ReduceAggregator(emitter, registers, aggregators, row, cols):
    220   if cols == 1:
    221     register = registers.Low(aggregators[row])
    222     emitter.EmitVPadd('u32', register, register, register)
    223     return register
    224   elif cols == 2:
    225     register = registers.Low(aggregators[row * 2])
    226     emitter.EmitVPadd('u32', register, register,
    227                       registers.Low(aggregators[row * 2 + 1]))
    228     return register
    229   elif cols == 3:
    230     register = aggregators[row * 3]
    231     emitter.EmitVPadd('u32', registers.Low(register), registers.Low(register),
    232                       registers.Low(aggregators[row * 3 + 1]))
    233     emitter.EmitVPadd('u32', registers.High(register),
    234                       registers.Low(aggregators[row * 3 + 2]),
    235                       registers.Low(aggregators[row * 3 + 2]))
    236     return register
    237   else:
    238     raise ConfigurationError('Unsupported columns no: %d' % cols)
    239 
    240 
    241 def StoreAggregator(emitter, registers, aggregator, cols, result_address,
    242                     result_stride):
    243   if cols == 1:
    244     emitter.EmitVStoreOffset('1.32', emitter.Lane(aggregator, 0),
    245                              emitter.Dereference(result_address, None),
    246                              result_stride)
    247   elif cols == 2:
    248     emitter.EmitVStoreOffset('1.32', aggregator,
    249                              emitter.Dereference(result_address, None),
    250                              result_stride)
    251   elif cols == 3:
    252     emitter.EmitVStore('1.32', registers.Low(aggregator),
    253                        emitter.DereferenceIncrement(result_address, None))
    254     emitter.EmitVStoreOffset('1.32', emitter.Lane(
    255         registers.High(aggregator),
    256         0), emitter.Dereference(result_address, None), result_stride)
    257     emitter.EmitNewline()
    258   else:
    259     raise ConfigurationError('Unsupported columns no: %d' % cols)
    260 
    261 
    262 def GenerateAggregatorReduceStore(emitter, registers, aggregators, result_type,
    263                                   lhs_add, rhs_add, left_lanes, right_lanes,
    264                                   results, results_stride):
    265   """Emit code that reduces 4 lane aggregators to 1 value, and stores them."""
    266   rows = len(left_lanes.lanes)
    267   cols = len(right_lanes.lanes)
    268 
    269   if lhs_add:
    270     left_offset = ReadParams(emitter, registers, left_lanes.input_address, rows,
    271                              4)
    272     left_offsets = Duplicate(emitter, registers, rows, cols, 4, left_offset)
    273   else:
    274     left_offsets = None
    275 
    276   if rhs_add:
    277     right_offset = ReadParams(emitter, registers, right_lanes.input_address,
    278                               cols, 4)
    279   else:
    280     right_offset = None
    281 
    282   if result_type is 'float':
    283     result_scale = DuplicateGeneralRegister(
    284         emitter, registers, cols, registers.MapParameter('result_scale'), 4)
    285   else:
    286     result_scale = None
    287 
    288   if cols == 3:
    289     emitter.EmitNewline()
    290     emitter.EmitComment('Change stride because storing in two ops.')
    291     emitter.EmitSub(results_stride, results_stride,
    292                     emitter.ImmediateConstant(8))
    293 
    294   emitter.EmitNewline()
    295   emitter.EmitComment('Horizontal reduce aggregators.')
    296   for aggregator in aggregators:
    297     emitter.EmitVPadd('u32', registers.Low(aggregator),
    298                       registers.Low(aggregator), registers.High(aggregator))
    299 
    300   emitter.EmitNewline()
    301   emitter.EmitComment('Reduce rows.')
    302   row_temps = []
    303   for i in range(0, rows):
    304     row_temps.append(ReduceAggregator(emitter, registers, aggregators, i, cols))
    305 
    306   if lhs_add:
    307     emitter.EmitNewline()
    308     emitter.EmitComment('Add lhs offsets to aggregated rows.')
    309     for (row_temp, left_offset) in zip(row_temps, left_offsets):
    310       emitter.EmitVAdd('s32', row_temp, row_temp, left_offset)
    311 
    312   if rhs_add:
    313     emitter.EmitNewline()
    314     emitter.EmitComment('Add rhs offset to aggregated rows.')
    315     for row_temp in row_temps:
    316       emitter.EmitVAdd('s32', row_temp, row_temp, right_offset)
    317 
    318   if result_type is 'float':
    319     emitter.EmitNewline()
    320     emitter.EmitComment('Convert to float. Multiply by result scale.')
    321     for row_temp in row_temps:
    322       emitter.EmitVCvt('f32', 's32', row_temp, row_temp)
    323     for row_temp in row_temps:
    324       emitter.EmitVMul('f32', row_temp, row_temp, result_scale)
    325 
    326   emitter.EmitNewline()
    327   emitter.EmitComment('Store reduced rows.')
    328   for row_temp in row_temps:
    329     StoreAggregator(emitter, registers, row_temp, cols, results, results_stride)
    330 
    331 
    332 def BuildName(result_type, lhs_add, rhs_add, left, right):
    333   name = 'mul_%dx8_%dx8_%s' % (left, right, result_type)
    334   if lhs_add:
    335     name += '_lhsadd'
    336   if rhs_add:
    337     name += '_rhsadd'
    338   return name
    339 
    340 
    341 def CppResultType(result_type):
    342   if result_type is 'int32':
    343     return 'std::int32_t*'
    344   elif result_type is 'float':
    345     return 'float*'
    346   else:
    347     raise ConfigurationError('Unsupported result type: %s' % result_type)
    348 
    349 
    350 def GetParameters(result_type):
    351   params = [['const std::uint8_t*', 'lhs'], ['const std::uint8_t*', 'rhs'],
    352             ['std::int32_t', 'count'], [CppResultType(result_type), 'result'],
    353             ['std::int32_t', 'result_stride']]
    354   if result_type is 'float':
    355     params.append(['float', 'result_scale'])
    356   return params
    357 
    358 
    359 def GenerateMulNx8Mx8(emitter, result_type, lhs_add, rhs_add, left_lanes_count,
    360                       right_lanes_count):
    361   """Emit the multiply code for given rows and cols counts."""
    362   if left_lanes_count < 1 or left_lanes_count > 3:
    363     raise ConfigurationError('Left_lanes should be: 1, 2 or 3.')
    364   if right_lanes_count < 1 or right_lanes_count > 3:
    365     raise ConfigurationError('Right_lanes should be: 1, 2 or 3.')
    366 
    367   emitter.EmitFunctionBeginA(
    368       BuildName(result_type, lhs_add, rhs_add, left_lanes_count,
    369                 right_lanes_count), GetParameters(result_type), 'inline void')
    370 
    371   emitter.EmitAssert('count % 8 == 0')
    372   emitter.EmitAssert('count >= 8')
    373   emitter.EmitAsmBegin()
    374 
    375   registers = neon_emitter.NeonRegisters()
    376 
    377   count = registers.MapParameter('count')
    378 
    379   size = left_lanes_count * right_lanes_count
    380 
    381   if size < 9:
    382     aggregators = GenerateAndClearAggregators(emitter, registers, size)
    383 
    384     left_lanes = GenerateMulLanes(registers, left_lanes_count,
    385                                   registers.MapParameter('lhs'))
    386     right_lanes = GenerateMulLanes(registers, right_lanes_count,
    387                                    registers.MapParameter('rhs'))
    388 
    389     emitter.EmitPld(left_lanes.input_address)
    390     emitter.EmitPld(right_lanes.input_address)
    391 
    392     GenerateNxMLoadMultiplyAggregate(emitter, registers, left_lanes,
    393                                      right_lanes, aggregators, count)
    394 
    395   else:  # left == 3 and right == 3
    396     aggregators = GenerateAndClearAggregators(emitter, registers, size)
    397     backup_register = registers.QuadRegister()
    398     left_lanes = Generate3MulLanes(backup_register, registers,
    399                                    registers.MapParameter('lhs'))
    400     right_lanes = GenerateMulLanes(registers, right_lanes_count,
    401                                    registers.MapParameter('rhs'))
    402 
    403     emitter.EmitPld(left_lanes.input_address)
    404     emitter.EmitPld(right_lanes.input_address)
    405 
    406     Generate3x3LoadMultiplyAggregate(emitter, registers, left_lanes,
    407                                      right_lanes, aggregators, count,
    408                                      backup_register)
    409 
    410   left_lanes.FreeRegisters(registers)
    411   right_lanes.FreeRegisters(registers)
    412 
    413   GenerateAggregatorReduceStore(emitter, registers, aggregators, result_type,
    414                                 lhs_add, rhs_add, left_lanes, right_lanes,
    415                                 registers.MapParameter('result'),
    416                                 registers.MapParameter('result_stride'))
    417 
    418   emitter.EmitAsmEnd(registers.MappedParameters(), [],
    419                      registers.Clobbers() + ['cc', 'memory'])
    420   emitter.EmitFunctionEnd()
    421 
    422 
    423 def GenerateFunctions(emitter, result_type, lhs_add, rhs_add):
    424   for left_lanes in range(1, 4):
    425     for right_lanes in range(1, 4):
    426       GenerateMulNx8Mx8(emitter, result_type, lhs_add, rhs_add, left_lanes,
    427                         right_lanes)
    428       emitter.EmitNewline()
    429