Home | History | Annotate | Download | only in generators
      1 # Copyright 2016 The Gemmlowp Authors. All rights reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #    http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 """."""
     15 
     16 import common
     17 
     18 
     19 def _AlignForLanes(lanes_count):
     20   if lanes_count is 8 or lanes_count is 4:
     21     return 256
     22   elif lanes_count is 6 or lanes_count is 2:
     23     return 128
     24   else:
     25     return 64
     26 
     27 
     28 def _AlignForSums(lanes_count):
     29   if lanes_count is 8:
     30     return 256
     31   elif lanes_count in [2, 4, 6]:
     32     return 128
     33   else:
     34     return 64
     35 
     36 
     37 def _GenerateInputs(emitter, registers, lanes_count, input_address, stride):
     38   """."""
     39   inputs = []
     40   last_address_register = input_address
     41   for i in range(lanes_count):
     42     if not i:
     43       inputs.append(input_address)
     44     else:
     45       address_register = registers.GeneralRegister()
     46       inputs.append(address_register)
     47       emitter.EmitAdd(address_register, last_address_register, stride)
     48       last_address_register = address_register
     49   return inputs
     50 
     51 
     52 def _GenerateClear(emitter, clear_type, block):
     53   for row in block:
     54     emitter.EmitVMov(clear_type, row, emitter.ImmediateConstant(0))
     55 
     56 
     57 def _GenerateLoadAggregateStore(emitter, registers, lanes_count, elements_count,
     58                                 aggregators, inputs, output):
     59   """Emit inner loop code for reading N lanes and interweaving them."""
     60   emitter.EmitNewline()
     61   emitter.EmitComment('Load Aggregate Store: %dx%d.' % (lanes_count,
     62                                                         elements_count))
     63 
     64   block = [registers.DoubleRegister() for unused_i in range(lanes_count)]
     65 
     66   if elements_count is not 8:
     67     _GenerateClear(emitter, 'i8', block)
     68 
     69   for (row, input_address) in zip(block, inputs):
     70     emitter.EmitVLoadE(8, elements_count, row, input_address, None)
     71 
     72   for (aggregator, row) in zip(aggregators, block):
     73     emitter.EmitVAddw('u8', aggregator, aggregator, row)
     74 
     75   emitter.EmitVStoreAE(8, 8 * lanes_count, block, output,
     76                        _AlignForLanes(lanes_count))
     77 
     78   registers.FreeRegisters(block)
     79 
     80 
     81 def _LoadMemoryParameter(emitter, registers, name, source):
     82   register = registers.GeneralRegister()
     83   emitter.EmitLdr(register, registers.MapMemoryParameter(name, source))
     84   return register
     85 
     86 
     87 def _GenerateAggregatorReductionLowRegisters(emitter, registers,
     88                                              aggregators, output_address):
     89   emitter.EmitNewline()
     90   emitter.EmitComment('Aggregator Reduction.')
     91   _GenerateAggregatorReduction(
     92       emitter, registers, aggregators, output_address,
     93       _LoadMemoryParameter(emitter, registers, 'multiplicative_sum_offset',
     94                            'params.multiplicative_sum_offset'),
     95       _LoadMemoryParameter(emitter, registers, 'additive_sum_offset',
     96                            'params.additive_sum_offset'))
     97 
     98 
     99 def _GenerateAggregatorReductionHighRegisters(emitter, registers,
    100                                               aggregators, output_address):
    101   emitter.EmitNewline()
    102   emitter.EmitComment('Aggregator Reduction.')
    103   _GenerateAggregatorReduction(
    104       emitter, registers, aggregators, output_address,
    105       registers.MapParameter('multiplicative_sum_offset',
    106                              'params.multiplicative_sum_offset'),
    107       registers.MapParameter('additive_sum_offset',
    108                              'params.additive_sum_offset'))
    109 
    110 
    111 def _GenerateAggregatorReduction(emitter, registers, aggregators,
    112                                  output_address, multiplicative_sum_offset,
    113                                  additive_sum_offset):
    114   """Reduce 4 lane sum aggregators to 1 value and store the sums."""
    115   multiplier = registers.DoubleRegister()
    116   emitter.EmitVMov('32',
    117                    emitter.Lane(32, multiplier, 0), multiplicative_sum_offset)
    118 
    119   offset = registers.QuadRegister()
    120   emitter.EmitVDup('32', offset, additive_sum_offset)
    121 
    122   for aggregator in aggregators:
    123     emitter.EmitVPaddl('u16', aggregator, aggregator)
    124 
    125   reduced_count = (len(aggregators) + 3) / 4
    126   reduced = aggregators[:reduced_count]
    127 
    128   emitter.EmitVSumReduce('u32', len(aggregators), 4, reduced, aggregators)
    129 
    130   for temp in reduced:
    131     emitter.EmitVMulScalar('i32', temp, temp, emitter.Lane(32, multiplier, 0))
    132 
    133   for temp in reduced:
    134     emitter.EmitVAdd('i32', temp, temp, offset)
    135 
    136   emitter.EmitVStoreA(1, 32, reduced,
    137                       emitter.Dereference(output_address,
    138                                           _AlignForSums(len(aggregators))))
    139 
    140 
    141 class RowMajorWithSumUInt8x8(common.StreamGenerator):
    142   """."""
    143 
    144   def __init__(self, emitter, asm_emitter):
    145     common.StreamGenerator.__init__(self, emitter, 'RowMajorWithSum')
    146     self.asm_emitter = asm_emitter
    147 
    148   def EmitPack(self, in_type, lanes_count, pack_size, leftovers):
    149     assert pack_size is 8
    150     assert in_type is 'uint8_t'
    151 
    152     registers = self.asm_emitter.CreateRegisters()
    153 
    154     self.emitter.EmitDeclare('int', 'params_count_copy', 'params.count')
    155 
    156     self.asm_emitter.PushIndent(self.emitter.indent)
    157     self.asm_emitter.EmitAsmBegin()
    158 
    159     count = registers.MapOutputParameter('count', 'params_count_copy')
    160     output = registers.MapOutputParameter('out')
    161     inputs = _GenerateInputs(self.asm_emitter, registers, lanes_count,
    162                              registers.MapOutputParameter('in'),
    163                              registers.MapParameter('stride', 'params.stride'))
    164     aggregators = [registers.QuadRegister(8) for unused_i in range(lanes_count)]
    165 
    166     _GenerateClear(self.asm_emitter, 'i16', aggregators)
    167 
    168     if leftovers:
    169       self.asm_emitter.EmitNewline()
    170       self.asm_emitter.EmitComment('Reduce count by leftovers.')
    171       self.asm_emitter.EmitSubs(count, count,
    172                                 self.asm_emitter.ImmediateConstant(leftovers))
    173       self.asm_emitter.EmitBeqFront(2)
    174 
    175     self.asm_emitter.EmitNewline()
    176     self.asm_emitter.EmitNumericalLabel(1)
    177     self.asm_emitter.EmitSubs(count, count,
    178                               self.asm_emitter.ImmediateConstant(8))
    179 
    180     _GenerateLoadAggregateStore(self.asm_emitter, registers, lanes_count, 8,
    181                                 aggregators, inputs, output)
    182 
    183     self.asm_emitter.EmitNewline()
    184     self.asm_emitter.EmitBneBack(1)
    185 
    186     if leftovers:
    187       self.asm_emitter.EmitNewline()
    188       self.asm_emitter.EmitNumericalLabel(2)
    189       _GenerateLoadAggregateStore(self.asm_emitter, registers, lanes_count,
    190                                   leftovers, aggregators, inputs, output)
    191 
    192     registers.FreeRegisters(inputs)
    193 
    194     if len(inputs) <= 6:
    195       _GenerateAggregatorReductionHighRegisters(
    196           self.asm_emitter, registers, aggregators, output)
    197     else:
    198       _GenerateAggregatorReductionLowRegisters(
    199           self.asm_emitter, registers, aggregators, output)
    200 
    201     self.asm_emitter.EmitAsmEnd(registers)
    202     self.asm_emitter.PopIndent(len(self.emitter.indent))
    203 
    204 
    205 def _GenerateColLoadAggregateStore(emitter, registers, lanes_count,
    206                                    elements_count, aggregators, input_address,
    207                                    stride, output):
    208   """Emit inner loop code for reading N col lanes and interweaving them."""
    209   emitter.EmitNewline()
    210   emitter.EmitComment('Load Aggregate Store - column major %dx%d' %
    211                       (lanes_count, elements_count))
    212 
    213   block = [registers.DoubleRegister() for unused_i in range(lanes_count)]
    214 
    215   if elements_count is not 8:
    216     _GenerateClear(emitter, 'i8', block)
    217 
    218   block = emitter.EmitLoadColBlock(registers, 8, lanes_count, elements_count,
    219                                    block, input_address, stride)
    220 
    221   for (aggregator, row) in zip(aggregators, block):
    222     emitter.EmitVAddw('u8', aggregator, aggregator, row)
    223 
    224   emitter.EmitVStoreAE(8, 8 * lanes_count, block, output,
    225                        _AlignForLanes(lanes_count))
    226 
    227   registers.FreeRegisters(block)
    228 
    229 
    230 class ColumnMajorWithSumUInt8x8(common.StreamGenerator):
    231   """."""
    232 
    233   def __init__(self, emitter, asm_emitter):
    234     common.StreamGenerator.__init__(self, emitter, 'ColumnMajorWithSum')
    235     self.asm_emitter = asm_emitter
    236 
    237   def EmitPack(self, in_type, lanes_count, pack_size, leftovers):
    238     assert pack_size is 8
    239     assert in_type is 'uint8_t'
    240 
    241     registers = self.asm_emitter.CreateRegisters()
    242 
    243     self.emitter.EmitDeclare('int', 'params_count_copy', 'params.count')
    244     self.emitter.EmitDeclare('int', 'params_stride_copy', 'params.stride')
    245 
    246     self.asm_emitter.PushIndent(self.emitter.indent)
    247     self.asm_emitter.EmitAsmBegin()
    248 
    249     count = registers.MapOutputParameter('count', 'params_count_copy')
    250     input_address = registers.MapOutputParameter('in')
    251     output_address = registers.MapOutputParameter('out')
    252     aggregators = [registers.QuadRegister(8) for unused_i in range(lanes_count)]
    253     stride = registers.MapOutputParameter('stride', 'params_stride_copy')
    254 
    255     self.asm_emitter.EmitColBlockStride(lanes_count, stride, stride)
    256 
    257     _GenerateClear(self.asm_emitter, 'i16', aggregators)
    258 
    259     if leftovers:
    260       self.asm_emitter.EmitNewline()
    261       self.asm_emitter.EmitComment('Reduce count by leftovers.')
    262       self.asm_emitter.EmitSubs(count, count,
    263                                 self.asm_emitter.ImmediateConstant(leftovers))
    264       self.asm_emitter.EmitBeqFront(2)
    265 
    266     self.asm_emitter.EmitNewline()
    267     self.asm_emitter.EmitNumericalLabel(1)
    268     self.asm_emitter.EmitSubs(count, count,
    269                               self.asm_emitter.ImmediateConstant(8))
    270 
    271     _GenerateColLoadAggregateStore(self.asm_emitter, registers, lanes_count, 8,
    272                                    aggregators, input_address, stride,
    273                                    output_address)
    274 
    275     self.asm_emitter.EmitNewline()
    276     self.asm_emitter.EmitBneBack(1)
    277 
    278     if leftovers:
    279       self.asm_emitter.EmitNewline()
    280       self.asm_emitter.EmitNumericalLabel(2)
    281       _GenerateColLoadAggregateStore(self.asm_emitter, registers, lanes_count,
    282                                      leftovers, aggregators, input_address,
    283                                      stride, output_address)
    284 
    285 
    286     _GenerateAggregatorReductionHighRegisters(
    287         self.asm_emitter, registers, aggregators, output_address)
    288 
    289     self.asm_emitter.EmitAsmEnd(registers)
    290     self.asm_emitter.PopIndent(len(self.emitter.indent))
    291 
    292 
    293 def GenerateUInt8x8Streams(cc_emitter, asm_emitter, lanes_count):
    294   row_major_with_sum = RowMajorWithSumUInt8x8(cc_emitter, asm_emitter)
    295   column_major_with_sum = ColumnMajorWithSumUInt8x8(cc_emitter, asm_emitter)
    296 
    297   for lanes_count in range(1, 1 + lanes_count):
    298     for leftovers in range(8):
    299       row_major_with_sum.SpecializeStream('uint8_t', lanes_count, 8, leftovers)
    300 
    301   for lanes_count in range(1, 1 + lanes_count):
    302     for leftovers in range(8):
    303       column_major_with_sum.SpecializeStream('uint8_t', lanes_count, 8,
    304                                              leftovers)
    305