Home | History | Annotate | Download | only in generators
      1 """Zip primitive used by the GEMM function.
      2 
Takes 1 to 4 rows of data and interleaves them in 8 byte chunks. Pads to
a multiple of 8 length with zeros. Calculates row sums and appends those at
the end.
      6 """
      7 
      8 import neon_emitter
      9 
     10 
     11 class Error(Exception):
     12   """Module level error."""
     13 
     14 
     15 class ConfigurationError(Error):
     16   """Unsupported configuration."""
     17 
     18 
     19 class ZipLane(object):
     20 
     21   def __init__(self, input_address, load, aggregator):
     22     self.input_address = input_address
     23     self.load = load
     24     self.aggregator = aggregator
     25 
     26 
     27 def GenerateZipLanes(emitter, registers, zip_lanes, input_address, stride):
     28   """Prepares read lanes for the zip operation.
     29 
     30   Args:
     31     emitter: ARM/NEON emitter.
     32     registers: ARM/NEON registers state.
     33     zip_lanes: number of lanes to prepare.
     34     input_address: register that contains the input address for the first lane.
     35     stride: memory stride for lane inputs.
     36 
     37   Returns:
     38     Array of ZipLane objects.
     39   """
     40   lanes = []
     41   last_address_register = input_address
     42   for i in range(0, zip_lanes):
     43     if not i:
     44       lanes.append(ZipLane(input_address, registers.DoubleRegister(),
     45                            registers.QuadRegister(2)))
     46     else:
     47       address_register = registers.GeneralRegister()
     48       lanes.append(ZipLane(address_register, registers.DoubleRegister(),
     49                            registers.QuadRegister(2)))
     50       emitter.EmitAdd(address_register, last_address_register, stride)
     51       last_address_register = address_register
     52   return lanes
     53 
     54 
     55 def BuildName(zip_lanes, leftovers, aligned):
     56   name = 'zip_%dx8' % zip_lanes
     57   if leftovers:
     58     name += '_%d' % leftovers
     59   if aligned:
     60     name += '_aligned'
     61   return name
     62 
     63 
     64 def GenerateClearAggregators(emitter, lanes):
     65   for lane in lanes:
     66     emitter.EmitVMov('i16', lane.aggregator, emitter.ImmediateConstant(0))
     67 
     68 
     69 def GenerateLoadAggregateStore(emitter, lanes, output_address, alignment):
     70   """Emit inner loop code for reading N lanes and interweaving them."""
     71   emitter.EmitNewline()
     72   emitter.EmitComment('Load Aggregate Store.')
     73 
     74   for lane in lanes:
     75     emitter.EmitVLoad(
     76         '1.8', lane.load,
     77         emitter.DereferenceIncrement(lane.input_address, alignment))
     78 
     79   store_registers = []
     80   for lane in lanes:
     81     emitter.EmitVAddw('u8', lane.aggregator, lane.aggregator, lane.load)
     82     store_registers.append(lane.load)
     83 
     84   emitter.EmitVStoreA('1.8', store_registers,
     85                       emitter.DereferenceIncrement(output_address, 64))
     86 
     87 
     88 def GenerateLeftoverLoadAggregateStore(emitter, leftovers, lanes,
     89                                        output_address):
     90   """Handle leftovers when count is not a multiply of 8."""
     91   emitter.EmitNewline()
     92   emitter.EmitComment('Leftover Load Aggregate Store.')
     93 
     94   # Clear load registers.
     95   for lane in lanes:
     96     emitter.EmitVMov('i8', lane.load, emitter.ImmediateConstant(0))
     97 
     98   if leftovers == 1:
     99     # Load 8 bits.
    100     for lane in lanes:
    101       emitter.EmitVLoad('1.8', emitter.Lane(lane.load, 0),
    102                         emitter.Dereference(lane.input_address, None))
    103   elif leftovers == 2:
    104     # Load 16 bits.
    105     for lane in lanes:
    106       emitter.EmitVLoad('1.16', emitter.Lane(lane.load, 0),
    107                         emitter.Dereference(lane.input_address, None))
    108   elif leftovers == 3:
    109     # Load 16 bits.
    110     for lane in lanes:
    111       emitter.EmitVLoad('1.16', emitter.Lane(lane.load, 0),
    112                         emitter.DereferenceIncrement(lane.input_address, None))
    113     # Load 8 bits.
    114     for lane in lanes:
    115       emitter.EmitVLoad('1.8', emitter.Lane(lane.load, 2),
    116                         emitter.Dereference(lane.input_address, None))
    117   elif leftovers == 4:
    118     # Load 32 bits.
    119     for lane in lanes:
    120       emitter.EmitVLoad('1.32', emitter.Lane(lane.load, 0),
    121                         emitter.Dereference(lane.input_address, None))
    122   elif leftovers == 5:
    123     # Load 32 bits..
    124     for lane in lanes:
    125       emitter.EmitVLoad('1.32', emitter.Lane(lane.load, 0),
    126                         emitter.DereferenceIncrement(lane.input_address, None))
    127     # Load 8 bits.
    128     for lane in lanes:
    129       emitter.EmitVLoad('1.8', emitter.Lane(lane.load, 4),
    130                         emitter.Dereference(lane.input_address, None))
    131   elif leftovers == 6:
    132     # Load 32 bits..
    133     for lane in lanes:
    134       emitter.EmitVLoad('1.32', emitter.Lane(lane.load, 0),
    135                         emitter.DereferenceIncrement(lane.input_address, None))
    136     # Load 16 bits.
    137     for lane in lanes:
    138       emitter.EmitVLoad('1.16', emitter.Lane(lane.load, 2),
    139                         emitter.Dereference(lane.input_address, None))
    140   elif leftovers == 7:
    141     # Load 32 bits..
    142     for lane in lanes:
    143       emitter.EmitVLoad('1.32', emitter.Lane(lane.load, 0),
    144                         emitter.DereferenceIncrement(lane.input_address, None))
    145     # Load 16 bits.
    146     for lane in lanes:
    147       emitter.EmitVLoad('1.16', emitter.Lane(lane.load, 2),
    148                         emitter.DereferenceIncrement(lane.input_address, None))
    149     # Load 8 bits.
    150     for lane in lanes:
    151       emitter.EmitVLoad('1.8', emitter.Lane(lane.load, 6),
    152                         emitter.Dereference(lane.input_address, None))
    153   else:
    154     raise ConfigurationError('Unsupported leftover num: %d' % leftovers)
    155 
    156   # Aggregate.
    157   store_registers = []
    158   for lane in lanes:
    159     emitter.EmitVAddw('u8', lane.aggregator, lane.aggregator, lane.load)
    160     store_registers.append(lane.load)
    161 
    162   # Store.
    163   emitter.EmitVStoreA('1.8', store_registers,
    164                       emitter.DereferenceIncrement(output_address, 64))
    165 
    166 
    167 def GenerateAggregatorReduction(emitter, registers, lanes, output_address,
    168                                 multiplicative_offset, additive_offset):
    169   """Reduce 4 lane sum aggregators to 1 value and store the sums."""
    170   emitter.EmitNewline()
    171   emitter.EmitComment('Aggregator Reduction.')
    172 
    173   multiplier = registers.DoubleRegister()
    174   emitter.EmitVMov('32', emitter.Lane(multiplier, 0), multiplicative_offset)
    175   offset = registers.QuadRegister()
    176   emitter.EmitVDup('32', offset, additive_offset)
    177 
    178   lane_temps = []
    179   for lane in lanes:
    180     emitter.EmitVPaddl('u16', lane.aggregator, lane.aggregator)
    181 
    182   for lane in lanes:
    183     lane_temp = registers.DoubleRegister()
    184     lane_temps.append(lane_temp)
    185     emitter.EmitVPadd('u32', lane_temp, registers.Low(lane.aggregator),
    186                       registers.High(lane.aggregator))
    187 
    188   temp = registers.QuadRegister()
    189   low = registers.Low(temp)
    190   high = registers.High(temp)
    191 
    192   if len(lanes) == 1:
    193     emitter.EmitVPadd('u32', low, lane_temps[0], lane_temps[0])
    194   elif len(lanes) == 2:
    195     emitter.EmitVPadd('u32', low, lane_temps[0], lane_temps[1])
    196   elif len(lanes) == 3:
    197     emitter.EmitVPadd('u32', low, lane_temps[0], lane_temps[1])
    198     emitter.EmitVPadd('u32', high, lane_temps[2], lane_temps[2])
    199   elif len(lanes) == 4:
    200     emitter.EmitVPadd('u32', low, lane_temps[0], lane_temps[1])
    201     emitter.EmitVPadd('u32', high, lane_temps[2], lane_temps[3])
    202   else:
    203     raise ConfigurationError('Unexpected number of aggregators to reduce: %d' %
    204                              len(lanes))
    205 
    206   emitter.EmitVMul('i32', temp, temp, emitter.Lane(multiplier, 0))
    207   emitter.EmitVAdd('i32', temp, temp, offset)
    208 
    209   if len(lanes) == 1:
    210     emitter.EmitVStore('1.32', emitter.Lane(low, 0),
    211                        emitter.Dereference(output_address, None))
    212   elif len(lanes) == 2:
    213     emitter.EmitVStore('1.32', low, emitter.Dereference(output_address, 64))
    214   elif len(lanes) == 3:
    215     emitter.EmitVStore('1.32', low,
    216                        emitter.DereferenceIncrement(output_address, 64))
    217     emitter.EmitVStore('1.32', emitter.Lane(high, 0),
    218                        emitter.Dereference(output_address, None))
    219   elif len(lanes) == 4:
    220     emitter.EmitVStoreA('1.32', [low, high],
    221                         emitter.DereferenceIncrement(output_address, 64))
    222 
    223 
    224 def GenerateZipNx8(emitter, zip_lanes, leftovers, aligned):
    225   """Emit the zip function for a given number of rows and row size leftovers."""
    226   if leftovers < 0 or leftovers > 7:
    227     raise ConfigurationError('Leftovers should be between 0 and 7 inclusive.')
    228   if zip_lanes < 1 or zip_lanes > 4:
    229     raise ConfigurationError('Zip_lanes should should be 1, 2, 3 or 4.')
    230 
    231   name = BuildName(zip_lanes, leftovers, aligned)
    232 
    233   emitter.EmitFunctionBeginA(
    234       name, [['const std::uint8_t*', 'source'], ['std::int32_t', 'count'],
    235              ['std::int32_t', 'stride'], ['std::uint8_t*', 'destination'],
    236              ['std::int32_t', 'multiplicative_offset'],
    237              ['std::int32_t', 'additive_offset']], 'void')
    238   emitter.EmitAssert('count %% 8 == %d' % leftovers)
    239   emitter.EmitAssert('count <= 2048')
    240   emitter.EmitAssert('count >= 8')
    241   emitter.EmitAssert('reinterpret_cast<std::uintptr_t>(destination) % 8 == 0')
    242   if aligned:
    243     emitter.EmitAssert('reinterpret_cast<std::uintptr_t>(source) % 8 == 0')
    244     if zip_lanes > 1:
    245       emitter.EmitAssert('stride % 8 == 0')
    246   emitter.EmitAsmBegin()
    247 
    248   registers = neon_emitter.NeonRegisters()
    249 
    250   count = registers.MapParameter('count')
    251   output_address = registers.MapParameter('destination')
    252 
    253   lanes = GenerateZipLanes(emitter, registers, zip_lanes,
    254                            registers.MapParameter('source'),
    255                            registers.MapParameter('stride'))
    256 
    257   if leftovers:
    258     emitter.EmitSub(count, count, emitter.ImmediateConstant(leftovers))
    259 
    260   GenerateClearAggregators(emitter, lanes)
    261 
    262   emitter.EmitNewline()
    263   emitter.EmitNumericalLabel(1)
    264   emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))
    265 
    266   GenerateLoadAggregateStore(emitter, lanes, output_address, 64 if aligned else
    267                              None)
    268 
    269   emitter.EmitNewline()
    270   emitter.EmitBneBack(1)
    271 
    272   if leftovers:
    273     GenerateLeftoverLoadAggregateStore(emitter, leftovers, lanes,
    274                                        output_address)
    275 
    276   GenerateAggregatorReduction(emitter, registers, lanes, output_address,
    277                               registers.MapParameter('multiplicative_offset'),
    278                               registers.MapParameter('additive_offset'))
    279 
    280   emitter.EmitAsmEnd(registers.MappedParameters(), [],
    281                      registers.Clobbers() + ['cc', 'memory'])
    282   emitter.EmitFunctionEnd()
    283 
    284 
    285 def GenerateFunctions(emitter):
    286   for aligned in [True, False]:
    287     for lanes in range(1, 5):
    288       for leftovers in range(0, 8):
    289         GenerateZipNx8(emitter, lanes, leftovers, aligned)
    290         emitter.EmitNewline()
    291