Home | History | Annotate | Download | only in generators
      1 """Multiply primitive optimized for the gemv operation."""
      2 
      3 import neon_emitter
      4 
      5 
      6 class Error(Exception):
      7   """Module level error."""
      8 
      9 
     10 class ConfigurationError(Error):
     11   """Unsupported configuration."""
     12 
     13 
     14 def GenerateLoadMultiplyAggregate(emitter, registers, lanes_count, aggregators,
     15                                   count, lhs, rhs_1, rhs_2):
     16   """Emit inner loop for 1 row x M cols multiplication."""
     17   emitter.EmitComment('General 1xM lanes loop.')
     18   emitter.EmitNumericalLabel(1)
     19   emitter.EmitNewline()
     20   emitter.EmitComment('Subtract counter.')
     21   emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))
     22   emitter.EmitNewline()
     23 
     24   right_load = [registers.DoubleRegister() for unused_i in range(4)]
     25   left_load = registers.DoubleRegister()
     26 
     27   emitter.EmitVLoad('1.8', left_load, emitter.DereferenceIncrement(lhs, 64))
     28   emitter.EmitVLoadA('1.8', right_load, emitter.DereferenceIncrement(rhs_1, 64))
     29 
     30   emitter.EmitPldOffset(lhs, emitter.ImmediateConstant(64))
     31   emitter.EmitPldOffset(rhs_1, emitter.ImmediateConstant(128))
     32 
     33   multiply_results = [registers.QuadRegister() for unused_i in range(4)]
     34 
     35   for i in range(4):
     36     emitter.EmitVMull('u8', multiply_results[i], right_load[i], left_load)
     37 
     38   emitter.EmitVLoadA('1.8', right_load[:lanes_count],
     39                      emitter.DereferenceIncrement(rhs_2, 64))
     40   emitter.EmitPldOffset(rhs_2, emitter.ImmediateConstant(lanes_count * 32))
     41 
     42   for i in range(4):
     43     emitter.EmitVPadal('u16', aggregators[i], multiply_results[i])
     44 
     45   for i in range(lanes_count):
     46     emitter.EmitVMull('u8', multiply_results[i], right_load[i], left_load)
     47 
     48   for i in range(lanes_count):
     49     emitter.EmitVPadal('u16', aggregators[i + 4], multiply_results[i])
     50 
     51   emitter.EmitNewline()
     52   emitter.EmitComment('Loop break.')
     53   emitter.EmitBneBack(1)
     54   emitter.EmitNewline()
     55 
     56   registers.FreeRegister(left_load)
     57   registers.FreeRegisters(right_load)
     58   registers.FreeRegisters(multiply_results)
     59 
     60 
     61 def ReadLeft(emitter, registers, lhs):
     62   register = registers.QuadRegister()
     63   emitter.EmitVLoadA('1.32', [emitter.AllLanes(registers.Low(register)),
     64                               emitter.AllLanes(registers.High(register))],
     65                      emitter.Dereference(lhs, None))
     66   return register
     67 
     68 
     69 def ReadRight(emitter, registers, rhs, count):
     70   if count == 1 or count == 2:
     71     register = registers.DoubleRegister()
     72   elif count == 3 or count == 4:
     73     register = registers.QuadRegister()
     74   else:
     75     raise ConfigurationError('Unsupported elements no: %d' % count)
     76   emitter.EmitVLoad('1.32', register, emitter.Dereference(rhs, 64))
     77   return register
     78 
     79 
     80 def DuplicateGeneralRegister(emitter, registers, general_register,
     81                              min_register):
     82   duplicated = registers.QuadRegister(min_register)
     83   emitter.EmitVDup('32', duplicated, general_register)
     84   return duplicated
     85 
     86 
     87 def GenerateAggregatorReduceStore(emitter, registers, lanes_count, aggregators,
     88                                   result_type, lhs_add, rhs_add, lhs, rhs_1,
     89                                   rhs_2, results):
     90   """Generates assembly responsible for reducing the 4 way aggregators."""
     91   if lhs_add:
     92     left_offset = ReadLeft(emitter, registers, lhs)
     93   else:
     94     left_offset = None
     95 
     96   if rhs_add:
     97     right_offset_1 = ReadRight(emitter, registers, rhs_1, 4)
     98     right_offset_2 = ReadRight(emitter, registers, rhs_2, lanes_count)
     99   else:
    100     right_offset_1 = None
    101     right_offset_2 = None
    102 
    103   if result_type is 'float':
    104     result_scale = DuplicateGeneralRegister(
    105         emitter, registers, registers.MapParameter('result_scale'), 4)
    106   else:
    107     result_scale = None
    108 
    109   emitter.EmitNewline()
    110   emitter.EmitComment('Horizontal reduce aggregators.')
    111   for aggregator in aggregators:
    112     emitter.EmitVPadd('u32', registers.Low(aggregator),
    113                       registers.Low(aggregator), registers.High(aggregator))
    114 
    115   temp = aggregators[0]
    116   emitter.EmitVPadd('u32', registers.Low(temp), registers.Low(aggregators[0]),
    117                     registers.Low(aggregators[1]))
    118   emitter.EmitVPadd('u32', registers.High(temp), registers.Low(aggregators[2]),
    119                     registers.Low(aggregators[3]))
    120 
    121   if lanes_count == 1:
    122     temp_2 = registers.Low(aggregators[1])
    123     emitter.EmitVPadd('u32', temp_2, registers.Low(aggregators[4]),
    124                       registers.Low(aggregators[4]))
    125   elif lanes_count == 2:
    126     temp_2 = registers.Low(aggregators[1])
    127     emitter.EmitVPadd('u32', temp_2, registers.Low(aggregators[4]),
    128                       registers.Low(aggregators[5]))
    129   elif lanes_count == 3:
    130     temp_2 = aggregators[1]
    131     emitter.EmitVPadd('u32', registers.Low(temp_2),
    132                       registers.Low(aggregators[4]),
    133                       registers.Low(aggregators[5]))
    134     emitter.EmitVPadd('u32', registers.High(temp_2),
    135                       registers.Low(aggregators[6]),
    136                       registers.Low(aggregators[6]))
    137   elif lanes_count == 4:
    138     temp_2 = aggregators[1]
    139     emitter.EmitVPadd('u32', registers.Low(temp_2),
    140                       registers.Low(aggregators[4]),
    141                       registers.Low(aggregators[5]))
    142     emitter.EmitVPadd('u32', registers.High(temp_2),
    143                       registers.Low(aggregators[6]),
    144                       registers.Low(aggregators[7]))
    145   else:
    146     temp_2 = None
    147 
    148   if lhs_add:
    149     emitter.EmitNewline()
    150     emitter.EmitComment('Add lhs offsets to aggregated rows.')
    151     emitter.EmitVAdd('s32', temp, temp, left_offset)
    152     if lanes_count == 1 or lanes_count == 2:
    153       emitter.EmitVAdd('s32', temp_2, temp_2, registers.Low(left_offset))
    154     elif lanes_count == 3 or lanes_count == 4:
    155       emitter.EmitVAdd('s32', temp_2, temp_2, left_offset)
    156 
    157   if rhs_add:
    158     emitter.EmitNewline()
    159     emitter.EmitComment('Add rhs offset to aggregated rows.')
    160     emitter.EmitVAdd('s32', temp, temp, right_offset_1)
    161     emitter.EmitVAdd('s32', temp_2, temp_2, right_offset_2)
    162 
    163   if result_type is 'float':
    164     emitter.EmitNewline()
    165     emitter.EmitComment('Convert to float and scale.')
    166     emitter.EmitVCvt('f32', 's32', temp, temp)
    167     emitter.EmitVCvt('f32', 's32', temp_2, temp_2)
    168     emitter.EmitVMul('f32', temp, temp, result_scale)
    169     if lanes_count == 1 or lanes_count == 2:
    170       emitter.EmitVMul('f32', temp_2, temp_2, registers.Low(result_scale))
    171     elif lanes_count == 3 or lanes_count == 4:
    172       emitter.EmitVMul('f32', temp_2, temp_2, result_scale)
    173 
    174   emitter.EmitNewline()
    175   emitter.EmitComment('Store results.')
    176   if lanes_count == 1:
    177     emitter.EmitVStoreA('1.32', [registers.Low(temp), registers.High(temp)],
    178                         emitter.DereferenceIncrement(results, None))
    179     emitter.EmitVStore('1.32', emitter.Lane(temp_2, 0),
    180                        emitter.Dereference(results, None))
    181   elif lanes_count == 2:
    182     emitter.EmitVStoreA('1.32', [registers.Low(temp), registers.High(temp),
    183                                  temp_2], emitter.Dereference(results, None))
    184   elif lanes_count == 3:
    185     emitter.EmitVStoreA(
    186         '1.32',
    187         [registers.Low(temp), registers.High(temp), registers.Low(temp_2)],
    188         emitter.DereferenceIncrement(results, None))
    189     emitter.EmitVStore('1.32', emitter.Lane(
    190         registers.High(temp_2), 0), emitter.Dereference(results, None))
    191   elif lanes_count == 4:
    192     emitter.EmitVStoreA('1.32', [registers.Low(temp), registers.High(temp),
    193                                  registers.Low(temp_2), registers.High(temp_2)],
    194                         emitter.Dereference(results, None))
    195 
    196 
    197 def BuildName(result_type, lhs_add, rhs_add, lanes):
    198   name = 'mul_1x8_%dx8_%s' % (lanes, result_type)
    199   if lhs_add:
    200     name += '_lhsadd'
    201   if rhs_add:
    202     name += '_rhsadd'
    203   return name
    204 
    205 
    206 def CppResultType(result_type):
    207   if result_type is 'int32':
    208     return 'std::int32_t*'
    209   elif result_type is 'float':
    210     return 'float*'
    211   else:
    212     raise ConfigurationError('Unsupported result type: %s' % result_type)
    213 
    214 
    215 def GetParameters(result_type):
    216   params = [['const std::uint8_t*', 'lhs'], ['const std::uint8_t*', 'rhs_1'],
    217             ['const std::uint8_t*', 'rhs_2'], ['std::int32_t', 'count'],
    218             [CppResultType(result_type), 'result']]
    219   if result_type is 'float':
    220     params.append(['float', 'result_scale'])
    221   return params
    222 
    223 
    224 def GenerateAndClearAggregators(emitter, registers, aggregator_count):
    225   """Prepare aggregators and emit aggregator clear code."""
    226   emitter.EmitNewline()
    227   emitter.EmitComment('Clear aggregators.')
    228   aggregators = []
    229   for i in range(aggregator_count):
    230     aggregator = registers.QuadRegister()
    231     aggregators.append(aggregator)
    232     if i < 3:
    233       emitter.EmitVMov('i32', aggregator, emitter.ImmediateConstant(0))
    234     else:
    235       emitter.EmitVMov('i32', aggregator, aggregators[i - 3])
    236   emitter.EmitNewline()
    237   return aggregators
    238 
    239 
    240 def GenerateMul1x8Mx8(emitter, result_type, lhs_add, rhs_add, lanes_count):
    241   """Generates the 1xN multiplication primitive."""
    242   if lanes_count < 1 or lanes_count > 4:
    243     raise ConfigurationError('Lanes should be: 1, 2, 3 or 4.')
    244 
    245   emitter.EmitFunctionBeginA(
    246       BuildName(result_type, lhs_add, rhs_add, lanes_count + 4),
    247       GetParameters(result_type), 'inline void')
    248 
    249   emitter.EmitAssert('count % 8 == 0')
    250   emitter.EmitAssert('count >= 8')
    251   emitter.EmitAsmBegin()
    252 
    253   registers = neon_emitter.NeonRegisters()
    254 
    255   count = registers.MapParameter('count')
    256 
    257   lhs = registers.MapParameter('lhs')
    258   rhs_1 = registers.MapParameter('rhs_1')
    259   rhs_2 = registers.MapParameter('rhs_2')
    260 
    261   emitter.EmitPld(lhs)
    262   emitter.EmitPld(rhs_1)
    263   emitter.EmitPld(rhs_2)
    264 
    265   aggregators = GenerateAndClearAggregators(emitter, registers, lanes_count + 4)
    266 
    267   GenerateLoadMultiplyAggregate(emitter, registers, lanes_count, aggregators,
    268                                 count, lhs, rhs_1, rhs_2)
    269   GenerateAggregatorReduceStore(emitter, registers, lanes_count, aggregators,
    270                                 result_type, lhs_add, rhs_add, lhs, rhs_1,
    271                                 rhs_2, registers.MapParameter('result'))
    272 
    273   emitter.EmitAsmEnd(registers.MappedParameters(), [],
    274                      registers.Clobbers() + ['cc', 'memory'])
    275   emitter.EmitFunctionEnd()
    276 
    277 
    278 def GenerateFunctions(emitter, result_type, lhs_add, rhs_add):
    279   for lanes in range(1, 5):
    280     GenerateMul1x8Mx8(emitter, result_type, lhs_add, rhs_add, lanes)
    281     emitter.EmitNewline()
    282 
    283 
    284 if __name__ == '__main__':
    285   GenerateFunctions(neon_emitter.NeonEmitter(), 'int32', True, True)
    286