Home | History | Annotate | Download | only in generators
      1 """Zip primitive used by the GEMM function.
      2 
      3 Takes 1 to 3 rows of data and interleaves them in 8 byte chunks. Pads to
      4 multiply of 8 length with zeros. Calculates row sums and appends those at the
      5 end.
      6 """
      7 
      8 
      9 import neon_emitter
     10 
     11 
     12 class Error(Exception):
     13   """Module level error."""
     14 
     15 
     16 class ConfigurationError(Error):
     17   """Unsupported configuration."""
     18 
     19 
     20 class ZipLane(object):
     21 
     22   def __init__(self, input_address, load, aggregator):
     23     self.input_address = input_address
     24     self.load = load
     25     self.aggregator = aggregator
     26 
     27 
     28 def GenerateZipLanes(emitter, registers, zip_lanes, input_address, stride):
     29   """Prepares read lanes for the zip operation.
     30 
     31   Args:
     32     emitter: ARM/NEON emitter.
     33     registers: ARM/NEON registers state.
     34     zip_lanes: number of lanes to prepare.
     35     input_address: register that contains the input address for the first lane.
     36     stride: memory stride for lane inputs.
     37 
     38   Returns:
     39     Array of ZipLane objects.
     40   """
     41   lanes = []
     42   last_address_register = input_address
     43   for i in range(0, zip_lanes):
     44     if not i:
     45       lanes.append(ZipLane(input_address,
     46                            registers.DoubleRegister(),
     47                            registers.QuadRegister(2)))
     48     else:
     49       address_register = registers.GeneralRegister()
     50       lanes.append(ZipLane(address_register,
     51                            registers.DoubleRegister(),
     52                            registers.QuadRegister(2)))
     53       emitter.EmitAdd(address_register, last_address_register, stride)
     54       last_address_register = address_register
     55   return lanes
     56 
     57 
     58 def BuildName(zip_lanes, leftovers, aligned):
     59   name = 'zip_%dx8' % zip_lanes
     60   if leftovers:
     61     name += '_%d' % leftovers
     62   if aligned:
     63     name += '_aligned'
     64   return name
     65 
     66 
     67 def GenerateClearAggregators(emitter, lanes):
     68   for lane in lanes:
     69     emitter.EmitVMov('i16', lane.aggregator, emitter.ImmediateConstant(0))
     70 
     71 
     72 def GenerateLoadAggregateStore(emitter, lanes, output_address, alignment):
     73   """Emit inner loop code for reading N lanes and interweaving them."""
     74   emitter.EmitNewline()
     75   emitter.EmitComment('Load Aggregate Store.')
     76 
     77   for lane in lanes:
     78     emitter.EmitVLoad(
     79         '1.8', lane.load,
     80         emitter.DereferenceIncrement(lane.input_address, alignment))
     81 
     82   store_registers = []
     83   for lane in lanes:
     84     emitter.EmitVAddw('u8', lane.aggregator, lane.aggregator, lane.load)
     85     store_registers.append(lane.load)
     86 
     87   emitter.EmitVStoreA('1.8', store_registers,
     88                       emitter.DereferenceIncrement(output_address, 64))
     89 
     90 
     91 def GenerateLeftoverLoadAggregateStore(
     92     emitter, leftovers, lanes, output_address):
     93   """Handle leftovers when count is not a multiply of 8."""
     94   emitter.EmitNewline()
     95   emitter.EmitComment('Leftover Load Aggregate Store.')
     96 
     97   # Clear load registers.
     98   for lane in lanes:
     99     emitter.EmitVMov('i8', lane.load, emitter.ImmediateConstant(0))
    100 
    101   if leftovers == 1:
    102     # Load 8 bits.
    103     for lane in lanes:
    104       emitter.EmitVLoad('1.8', emitter.Lane(lane.load, 0),
    105                         emitter.Dereference(lane.input_address, None))
    106   elif leftovers == 2:
    107     # Load 16 bits.
    108     for lane in lanes:
    109       emitter.EmitVLoad('1.16', emitter.Lane(lane.load, 0),
    110                         emitter.Dereference(lane.input_address, None))
    111   elif leftovers == 3:
    112     # Load 16 bits.
    113     for lane in lanes:
    114       emitter.EmitVLoad(
    115           '1.16', emitter.Lane(lane.load, 0),
    116           emitter.DereferenceIncrement(lane.input_address, None))
    117     # Load 8 bits.
    118     for lane in lanes:
    119       emitter.EmitVLoad('1.8', emitter.Lane(lane.load, 2),
    120                         emitter.Dereference(lane.input_address, None))
    121   elif leftovers == 4:
    122     # Load 32 bits.
    123     for lane in lanes:
    124       emitter.EmitVLoad('1.32', emitter.Lane(lane.load, 0),
    125                         emitter.Dereference(lane.input_address, None))
    126   elif leftovers == 5:
    127     # Load 32 bits..
    128     for lane in lanes:
    129       emitter.EmitVLoad(
    130           '1.32', emitter.Lane(lane.load, 0),
    131           emitter.DereferenceIncrement(lane.input_address, None))
    132     # Load 8 bits.
    133     for lane in lanes:
    134       emitter.EmitVLoad('1.8', emitter.Lane(lane.load, 4),
    135                         emitter.Dereference(lane.input_address, None))
    136   elif leftovers == 6:
    137     # Load 32 bits..
    138     for lane in lanes:
    139       emitter.EmitVLoad(
    140           '1.32', emitter.Lane(lane.load, 0),
    141           emitter.DereferenceIncrement(lane.input_address, None))
    142     # Load 16 bits.
    143     for lane in lanes:
    144       emitter.EmitVLoad('1.16', emitter.Lane(lane.load, 2),
    145                         emitter.Dereference(lane.input_address, None))
    146   elif leftovers == 7:
    147     # Load 32 bits..
    148     for lane in lanes:
    149       emitter.EmitVLoad(
    150           '1.32', emitter.Lane(lane.load, 0),
    151           emitter.DereferenceIncrement(lane.input_address, None))
    152     # Load 16 bits.
    153     for lane in lanes:
    154       emitter.EmitVLoad(
    155           '1.16', emitter.Lane(lane.load, 2),
    156           emitter.DereferenceIncrement(lane.input_address, None))
    157     # Load 8 bits.
    158     for lane in lanes:
    159       emitter.EmitVLoad('1.8', emitter.Lane(lane.load, 6),
    160                         emitter.Dereference(lane.input_address, None))
    161   else:
    162     raise ConfigurationError('Unsupported leftover num: %d' % leftovers)
    163 
    164   # Aggregate.
    165   store_registers = []
    166   for lane in lanes:
    167     emitter.EmitVAddw('u8', lane.aggregator, lane.aggregator, lane.load)
    168     store_registers.append(lane.load)
    169 
    170   # Store.
    171   emitter.EmitVStoreA('1.8', store_registers,
    172                       emitter.DereferenceIncrement(output_address, 64))
    173 
    174 
    175 def GenerateAggregatorReduction(emitter,
    176                                 registers,
    177                                 lanes,
    178                                 output_address,
    179                                 multiplicative_offset,
    180                                 additive_offset):
    181   """Reduce 4 lane sum aggregators to 1 value and store the sums."""
    182   emitter.EmitNewline()
    183   emitter.EmitComment('Aggregator Reduction.')
    184 
    185   multiplier = registers.DoubleRegister()
    186   emitter.EmitVMov('32', emitter.Lane(multiplier, 0), multiplicative_offset)
    187   offset = registers.QuadRegister()
    188   emitter.EmitVDup('32', offset, additive_offset)
    189 
    190   lane_temps = []
    191   for lane in lanes:
    192     emitter.EmitVPaddl('u16', lane.aggregator, lane.aggregator)
    193 
    194   for lane in lanes:
    195     lane_temp = registers.DoubleRegister()
    196     lane_temps.append(lane_temp)
    197     emitter.EmitVPadd('u32',
    198                       lane_temp,
    199                       registers.Low(lane.aggregator),
    200                       registers.High(lane.aggregator))
    201 
    202   temp = registers.QuadRegister()
    203   low = registers.Low(temp)
    204   high = registers.High(temp)
    205 
    206   if len(lanes) == 1:
    207     emitter.EmitVPadd('u32', low, lane_temps[0], lane_temps[0])
    208   elif len(lanes) == 2:
    209     emitter.EmitVPadd('u32', low, lane_temps[0], lane_temps[1])
    210   elif len(lanes) == 3:
    211     emitter.EmitVPadd('u32', low, lane_temps[0], lane_temps[1])
    212     emitter.EmitVPadd('u32', high, lane_temps[2], lane_temps[2])
    213   elif len(lanes) == 4:
    214     emitter.EmitVPadd('u32', low, lane_temps[0], lane_temps[1])
    215     emitter.EmitVPadd('u32', high, lane_temps[2], lane_temps[3])
    216   else:
    217     raise ConfigurationError(
    218         'Unexpected number of aggregators to reduce: %d' % len(lanes))
    219 
    220   emitter.EmitVMul('i32', temp, temp, emitter.Lane(multiplier, 0))
    221   emitter.EmitVAdd('i32', temp, temp, offset)
    222 
    223   if len(lanes) == 1:
    224     emitter.EmitVStore(
    225         '1.32', emitter.Lane(low, 0), emitter.Dereference(output_address, None))
    226   elif len(lanes) == 2:
    227     emitter.EmitVStore('1.32', low, emitter.Dereference(output_address, 64))
    228   elif len(lanes) == 3:
    229     emitter.EmitVStore(
    230         '1.32', low, emitter.DereferenceIncrement(output_address, 64))
    231     emitter.EmitVStore(
    232         '1.32', emitter.Lane(high, 0),
    233         emitter.Dereference(output_address, None))
    234   elif len(lanes) == 4:
    235     emitter.EmitVStore(
    236         '1.32', low, emitter.DereferenceIncrement(output_address, 64))
    237     emitter.EmitVStore('1.32', high, emitter.Dereference(output_address, 64))
    238 
    239 
    240 def GenerateZipNx8(emitter, zip_lanes, leftovers, aligned):
    241   """Emit the zip function for a given number of rows and row size leftovers."""
    242   if leftovers < 0 or leftovers > 7:
    243     raise ConfigurationError('Leftovers should be between 0 and 7 inclusive.')
    244   if zip_lanes < 1 or zip_lanes > 3:
    245     raise ConfigurationError('Zip_lanes should should be 1, 2 or 3.')
    246 
    247   name = BuildName(zip_lanes, leftovers, aligned)
    248 
    249   emitter.EmitFunctionBeginA(name,
    250                              [['const std::uint8_t*', 'source'],
    251                               ['std::int32_t', 'count'],
    252                               ['std::int32_t', 'stride'],
    253                               ['std::uint8_t*', 'destination'],
    254                               ['std::int32_t', 'multiplicative_offset'],
    255                               ['std::int32_t', 'additive_offset']],
    256                              'void')
    257   emitter.EmitAssert('count %% 8 == %d' % leftovers)
    258   emitter.EmitAssert('count <= 2048')
    259   emitter.EmitAssert('count >= 8')
    260   emitter.EmitAssert('reinterpret_cast<std::uintptr_t>(destination) % 8 == 0')
    261   if aligned:
    262     emitter.EmitAssert('reinterpret_cast<std::uintptr_t>(source) % 8 == 0')
    263     if zip_lanes > 1:
    264       emitter.EmitAssert('stride % 8 == 0')
    265   emitter.EmitAsmBegin()
    266 
    267   registers = neon_emitter.NeonRegisters()
    268 
    269   count = registers.MapParameter('count')
    270   output_address = registers.MapParameter('destination')
    271 
    272   lanes = GenerateZipLanes(emitter,
    273                            registers,
    274                            zip_lanes,
    275                            registers.MapParameter('source'),
    276                            registers.MapParameter('stride'))
    277 
    278   if leftovers:
    279     emitter.EmitSub(count, count, emitter.ImmediateConstant(leftovers))
    280 
    281   GenerateClearAggregators(emitter, lanes)
    282 
    283   emitter.EmitNewline()
    284   emitter.EmitNumericalLabel(1)
    285   emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))
    286 
    287   GenerateLoadAggregateStore(
    288       emitter, lanes, output_address, 64 if aligned else None)
    289 
    290   emitter.EmitNewline()
    291   emitter.EmitBneBack(1)
    292 
    293   if leftovers:
    294     GenerateLeftoverLoadAggregateStore(
    295         emitter, leftovers, lanes, output_address)
    296 
    297   GenerateAggregatorReduction(emitter,
    298                               registers,
    299                               lanes,
    300                               output_address,
    301                               registers.MapParameter('multiplicative_offset'),
    302                               registers.MapParameter('additive_offset'))
    303 
    304   emitter.EmitAsmEnd(registers.MappedParameters(),
    305                      [],
    306                      registers.Clobbers() + ['cc', 'memory'])
    307   emitter.EmitFunctionEnd()
    308 
    309 
    310 def GenerateFunctions(emitter):
    311   for aligned in [True, False]:
    312     for lanes in range(1, 4):
    313       for leftovers in range(0, 8):
    314         GenerateZipNx8(emitter, lanes, leftovers, aligned)
    315         emitter.EmitNewline()
    316