Home | History | Annotate | Download | only in generators
      1 # Copyright 2016 The Gemmlowp Authors. All rights reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #    http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
"""Generators for quantized multiplication kernels operating on uint8 operands."""
     15 
     16 import common
     17 
     18 
     19 def _ReadParams(emitter, registers, input_address, elements, min_register):
     20   registers_count = (elements + 3) / 4
     21   registers = [
     22       registers.QuadRegister(min_register)
     23       for unused_i in range(registers_count)
     24   ]
     25   emitter.EmitVLoadAE(registers_count * 4, 32, registers, input_address, 64)
     26   return registers
     27 
     28 
     29 def _Duplicate(emitter, registers, rows, values):
     30   """Populate a grid of registers duplicating provided values."""
     31   duplicated = []
     32   for i in range(rows):
     33     if i is rows - 1:
     34       duplicated.append(values[0])
     35     else:
     36       duplicated.append(registers.QuadRegister())
     37 
     38     emitter.EmitVDup('32', duplicated[i],
     39                      emitter.Lane(32, values[i / 4], i % 4))
     40 
     41   return duplicated
     42 
     43 
     44 def _DuplicateGeneralRegister(emitter, registers, value, min_register):
     45   register = registers.QuadRegister(min_register)
     46   emitter.EmitVDup('32', register, value)
     47   return register
     48 
     49 
     50 class _StaticQuantizationUInt8Transformation(object):
     51   """Calculate quantized values and cast back to uint8."""
     52 
     53   def Prepare(self, emitter, registers, kernel_m, kernel_n, lhs, rhs):
     54     """Load parameters and prepare duplicated registers."""
     55     emitter.EmitNewline()
     56     emitter.EmitComment('StaticQuantization::Prepare')
     57 
     58     lhs_offset = _ReadParams(emitter, registers, lhs, kernel_m, 4)
     59     self.rhs_offsets = _ReadParams(emitter, registers, rhs, kernel_n, 4)
     60     self.multiplicative_offset = _DuplicateGeneralRegister(
     61         emitter, registers,
     62         registers.MapParameter('multiplicative_offset',
     63                                'params.kernel.multiplicative_offset'), 4)
     64     self.rounding_offset = _DuplicateGeneralRegister(
     65         emitter, registers,
     66         registers.MapParameter('rounding_offset',
     67                                'params.kernel.rounding_offset'), 4)
     68     self.shift = _DuplicateGeneralRegister(
     69         emitter, registers,
     70         registers.MapParameter('shift', 'params.kernel.shift'), 4)
     71     self.lhs_offsets = _Duplicate(emitter, registers, kernel_m, lhs_offset)
     72 
     73   def Transform(self, emitter, registers, data, unused_kernel_m,
     74                 unused_kernel_n):
     75     """Quantize the data."""
     76     emitter.EmitNewline()
     77     emitter.EmitComment('StaticQuantization::Transform')
     78 
     79     for (row, lhs_offset) in zip(data, self.lhs_offsets):
     80       for row_register in row:
     81         emitter.EmitVAdd('s32', row_register, row_register, lhs_offset)
     82 
     83     for row in data:
     84       for (row_register, rhs_offset_register) in zip(row, self.rhs_offsets):
     85         emitter.EmitVAdd('s32', row_register, row_register, rhs_offset_register)
     86 
     87     for row in data:
     88       for row_register in row:
     89         emitter.EmitVMul('i32', row_register, row_register,
     90                          self.multiplicative_offset)
     91 
     92     for row in data:
     93       for row_register in row:
     94         emitter.EmitVAdd('i32', row_register, row_register,
     95                          self.rounding_offset)
     96 
     97     for row in data:
     98       for row_register in row:
     99         emitter.EmitVShl('s32', row_register, row_register, self.shift)
    100 
    101     if len(data[0]) is 1:
    102       for row in data:
    103         emitter.EmitVQmovn('s32', row[0], row[0])
    104 
    105       for row in data:
    106         emitter.EmitVQmovun('s16', row[0], row[0])
    107 
    108       return data
    109     elif len(data[0]) is 2:
    110       results = []
    111       for row in data:
    112         emitter.EmitVQmovn2('s32', row[0], row[0], row[1])
    113         registers.FreeRegister(row[1])
    114         results.append([row[0]])
    115 
    116       for row in results:
    117         emitter.EmitVQmovun('s16', row[0], row[0])
    118 
    119       return results
    120     else:
    121       assert False
    122 
    123   def Type(self):
    124     return 8
    125 
    126 
    127 class _StaticQuantizationInt32Transformation(object):
    128   """."""
    129 
    130   def Prepare(self, emitter, registers, kernel_m, kernel_n, lhs, rhs):
    131     emitter.EmitNewline()
    132     emitter.EmitComment('StaticQuantizationInt32::Prepare')
    133 
    134     lhs_offset = _ReadParams(emitter, registers, lhs, kernel_m, 4)
    135     self.rhs_offsets = _ReadParams(emitter, registers, rhs, kernel_n, 4)
    136     self.lhs_offsets = _Duplicate(emitter, registers, kernel_m, lhs_offset)
    137 
    138   def Transform(self, emitter, unused_registers, data, unused_kernel_m,
    139                 unused_kernel_n):
    140     """Quantize data and output as int32."""
    141     emitter.EmitNewline()
    142     emitter.EmitComment('StaticQuantizationInt32::Transform')
    143 
    144     for (row, lhs_offset) in zip(data, self.lhs_offsets):
    145       for row_register in row:
    146         emitter.EmitVAdd('s32', row_register, row_register, lhs_offset)
    147 
    148     for row in data:
    149       for (row_register, rhs_offsets_register) in zip(row, self.rhs_offsets):
    150         emitter.EmitVAdd('s32', row_register, row_register,
    151                          rhs_offsets_register)
    152 
    153     return data
    154 
    155   def Type(self):
    156     return 32
    157 
    158 
    159 class _StaticQuantizationFloatTransformation(object):
    160   """."""
    161 
    162   def Prepare(self, emitter, registers, kernel_m, kernel_n, lhs, rhs):
    163     emitter.EmitNewline()
    164     emitter.EmitComment('StaticQuantizationFloat::Prepare')
    165 
    166     lhs_offset = _ReadParams(emitter, registers, lhs, kernel_m, 4)
    167     self.rhs_offsets = _ReadParams(emitter, registers, rhs, kernel_n, 4)
    168     self.scale = _DuplicateGeneralRegister(
    169         emitter, registers,
    170         registers.MapParameter('scale', 'params.kernel.scale'), 4)
    171     self.lhs_offsets = _Duplicate(emitter, registers, kernel_m, lhs_offset)
    172 
    173   def Transform(self, emitter, unused_registers, data, unused_kernel_m,
    174                 unused_kernel_n):
    175     """Quantize data and output as float."""
    176     emitter.EmitNewline()
    177     emitter.EmitComment('StaticQuantizationFloat::Transform')
    178 
    179     for (row, lhs_offset) in zip(data, self.lhs_offsets):
    180       for row_register in row:
    181         emitter.EmitVAdd('s32', row_register, row_register, lhs_offset)
    182 
    183     for row in data:
    184       for (row_register, rhs_offsets_register) in zip(row, self.rhs_offsets):
    185         emitter.EmitVAdd('s32', row_register, row_register,
    186                          rhs_offsets_register)
    187 
    188     for row in data:
    189       for row_register in row:
    190         emitter.EmitVCvt('f32', 's32', row_register, row_register)
    191 
    192     for row in data:
    193       for row_register in row:
    194         emitter.EmitVMul('f32', row_register, row_register, self.scale)
    195 
    196     return data
    197 
    198   def Type(self):
    199     return 32
    200 
    201 
    202 class _RowMajorOutput(object):
    203   """Output data in row major layout."""
    204 
    205   def Prepare(self, emitter, registers, kernel_m, unused_kernel_n,
    206               unused_data_type):
    207     """Prepare strided load addresses."""
    208     emitter.EmitNewline()
    209     emitter.EmitComment('RowMajorOutput::Prepare')
    210 
    211     stride = registers.MapParameter('stride', 'params.output_stream.stride')
    212 
    213     self.outputs = []
    214     self.outputs.append(registers.MapOutputParameter('result'))
    215 
    216     for unused_i in range(kernel_m - 1):
    217       register = registers.GeneralRegister()
    218       emitter.EmitAdd(register, self.outputs[-1], stride)
    219       self.outputs.append(register)
    220 
    221   def Output(self, emitter, unused_registers, data, data_type, unused_kernel_m,
    222              kernel_n):
    223     emitter.EmitNewline()
    224     emitter.EmitComment('RowMajorOutput::Output')
    225 
    226     for (datum, output) in zip(data, self.outputs):
    227       emitter.EmitVStoreAE(data_type, kernel_n, datum, output, None)
    228 
    229 
    230 def _GenerateAndClearAggregators(emitter, registers, count):
    231   """Prepare aggregators and emit aggregator clear code."""
    232   emitter.EmitNewline()
    233   emitter.EmitComment('Clear aggregators.')
    234   aggregators = [registers.QuadRegister() for unused_i in range(count)]
    235   for i in range(count):
    236     if i < 3:
    237       emitter.EmitVMov('i32', aggregators[i], emitter.ImmediateConstant(0))
    238     else:
    239       emitter.EmitVMov('i32', aggregators[i], aggregators[i - 3])
    240   return aggregators
    241 
    242 
def _Generate3x3LoadMultiplyAggregate(emitter, registers, aggregators, lhs, rhs,
                                      count):
  """Emit inner loop for 3 rows x 3 cols multiplication.

  The emitted loop body multiplies each of the three lhs registers with each
  of the three rhs registers ('u8' VMull) and accumulates every product into
  its aggregator with a 'u16' VPadal. Aggregators are indexed row * 3 + col.
  Loads, prefetches and the counter update are interleaved with the
  arithmetic; the statement order below is the emitted instruction order.

  Args:
    emitter: assembly emitter.
    registers: register allocator.
    aggregators: the nine accumulator registers.
    lhs: register addressing the lhs data.
    rhs: register addressing the rhs data.
    count: loop counter register; decremented by 8 per iteration.
  """
  emitter.EmitNewline()
  emitter.EmitComment('3x3 lanes loop.')
  emitter.EmitNumericalLabel(1)
  emitter.EmitNewline()

  lhs_load = [registers.DoubleRegister() for unused_i in range(3)]
  rhs_load = [registers.DoubleRegister() for unused_i in range(3)]
  # Only four product temporaries at first; a fifth is allocated further down.
  temp = [registers.QuadRegister() for unused_i in range(4)]

  # All rhs registers are loaded at once; lhs registers are loaded one by one,
  # each right after a multiply.
  emitter.EmitVLoadA(1, 8, rhs_load, emitter.DereferenceIncrement(rhs, 64))
  emitter.EmitVLoad(1, 8, lhs_load[0], emitter.DereferenceIncrement(lhs, 64))

  emitter.EmitVMull('u8', temp[0], lhs_load[0], rhs_load[0])
  emitter.EmitVLoad(1, 8, lhs_load[1], emitter.DereferenceIncrement(lhs, 64))

  emitter.EmitVMull('u8', temp[1], lhs_load[0], rhs_load[1])
  emitter.EmitVLoad(1, 8, lhs_load[2], emitter.DereferenceIncrement(lhs, 64))

  emitter.EmitVMull('u8', temp[2], lhs_load[0], rhs_load[2])
  emitter.EmitPldOffset(lhs, emitter.ImmediateConstant(64))

  emitter.EmitVMull('u8', temp[3], lhs_load[1], rhs_load[0])
  emitter.EmitPldOffset(rhs, emitter.ImmediateConstant(64))

  # Accumulate row 0 (all three columns) and row 1, column 0.
  emitter.EmitVPadal('u16', aggregators[0], temp[0])
  emitter.EmitVPadal('u16', aggregators[1], temp[1])
  emitter.EmitVPadal('u16', aggregators[2], temp[2])
  emitter.EmitVPadal('u16', aggregators[3], temp[3])

  # temp[0] and temp[1] are reused for the rest of row 1.
  emitter.EmitVMull('u8', temp[0], lhs_load[1], rhs_load[1])
  emitter.EmitVMull('u8', temp[1], lhs_load[1], rhs_load[2])

  # lhs rows 0 and 1 are not read again, so their registers are released
  # before the fifth temporary is allocated.
  # NOTE(review): presumably the new quad register reuses that storage.
  registers.FreeRegisters([lhs_load[0], lhs_load[1]])
  temp.append(registers.QuadRegister())

  emitter.EmitVMull('u8', temp[2], lhs_load[2], rhs_load[0])
  emitter.EmitVMull('u8', temp[3], lhs_load[2], rhs_load[1])

  emitter.EmitNewline()
  emitter.EmitComment('Subtract counter.')
  emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))
  emitter.EmitNewline()

  emitter.EmitVMull('u8', temp[4], lhs_load[2], rhs_load[2])

  # Accumulate the remainder of row 1 and all of row 2.
  emitter.EmitVPadal('u16', aggregators[4], temp[0])
  emitter.EmitVPadal('u16', aggregators[5], temp[1])
  emitter.EmitVPadal('u16', aggregators[6], temp[2])
  emitter.EmitVPadal('u16', aggregators[7], temp[3])
  emitter.EmitVPadal('u16', aggregators[8], temp[4])

  emitter.EmitNewline()
  emitter.EmitComment('Loop break.')
  # Branch back to numerical label 1 emitted at the top.
  emitter.EmitBgtBack(1)

  # lhs_load[0] and lhs_load[1] were already freed above.
  registers.FreeRegisters(temp + [lhs_load[2]] + rhs_load)
    302 
    303 
def _Generate2x4LoadMultiplyAggregate(emitter, registers, aggregators, lhs, rhs,
                                      count):
  """Emit inner loop for 2 rows x 4 cols multiplication.

  The emitted loop body multiplies each of the two lhs registers with each of
  the four rhs registers ('u8' VMull) and accumulates every product into its
  aggregator with a 'u16' VPadal. Aggregators are indexed row * 4 + col.
  The statement order below is the emitted instruction order.

  Args:
    emitter: assembly emitter.
    registers: register allocator.
    aggregators: the eight accumulator registers.
    lhs: register addressing the lhs data.
    rhs: register addressing the rhs data.
    count: loop counter register; decremented by 8 per iteration.
  """
  emitter.EmitNewline()
  emitter.EmitComment('2x4 lanes loop.')
  emitter.EmitNumericalLabel(1)
  emitter.EmitNewline()

  lhs_load = [registers.DoubleRegister() for unused_i in range(2)]
  rhs_load = [registers.DoubleRegister() for unused_i in range(4)]
  # Five temporaries for eight products: temp[0..2] are written twice.
  temp = [registers.QuadRegister() for unused_i in range(5)]

  emitter.EmitVLoadA(1, 8, rhs_load, emitter.DereferenceIncrement(rhs, 256))
  emitter.EmitVLoad(1, 8, lhs_load[0], emitter.DereferenceIncrement(lhs, 64))

  emitter.EmitVMull('u8', temp[0], lhs_load[0], rhs_load[0])
  emitter.EmitVLoad(1, 8, lhs_load[1], emitter.DereferenceIncrement(lhs, 64))

  emitter.EmitVMull('u8', temp[1], lhs_load[0], rhs_load[1])
  emitter.EmitPldOffset(rhs, emitter.ImmediateConstant(64))

  emitter.EmitVMull('u8', temp[2], lhs_load[0], rhs_load[2])
  emitter.EmitPldOffset(lhs, emitter.ImmediateConstant(64))

  emitter.EmitVMull('u8', temp[3], lhs_load[0], rhs_load[3])
  emitter.EmitVMull('u8', temp[4], lhs_load[1], rhs_load[0])

  # Accumulate the first three products so temp[0..2] can be overwritten.
  emitter.EmitVPadal('u16', aggregators[0], temp[0])
  emitter.EmitVPadal('u16', aggregators[1], temp[1])
  emitter.EmitVPadal('u16', aggregators[2], temp[2])

  # Remaining products of row 1 (columns 1..3).
  emitter.EmitVMull('u8', temp[0], lhs_load[1], rhs_load[1])
  emitter.EmitVMull('u8', temp[1], lhs_load[1], rhs_load[2])
  emitter.EmitVMull('u8', temp[2], lhs_load[1], rhs_load[3])

  emitter.EmitNewline()
  emitter.EmitComment('Subtract counter.')
  emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))

  emitter.EmitNewline()
  # Row 0 column 3, then row 1 columns 0..3.
  emitter.EmitVPadal('u16', aggregators[3], temp[3])
  emitter.EmitVPadal('u16', aggregators[4], temp[4])
  emitter.EmitVPadal('u16', aggregators[5], temp[0])
  emitter.EmitVPadal('u16', aggregators[6], temp[1])
  emitter.EmitVPadal('u16', aggregators[7], temp[2])

  emitter.EmitNewline()
  emitter.EmitComment('Loop break.')
  # Branch back to numerical label 1 emitted at the top.
  emitter.EmitBgtBack(1)

  registers.FreeRegisters(temp + lhs_load + rhs_load)
    355 
    356 
def _Generate1x8LoadMultiplyAggregate(emitter, registers, aggregators, lhs, rhs,
                                      count):
  """Emit inner loop for 1 rows x 8 cols multiplication.

  The rhs data is loaded in two batches of four registers into the same
  rhs_load registers. Each batch is multiplied with the single lhs register
  ('u8' VMull) and accumulated with 'u16' VPadal: the first batch into
  aggregators[0..3], the second into aggregators[4..7]. The statement order
  below is the emitted instruction order.

  Args:
    emitter: assembly emitter.
    registers: register allocator.
    aggregators: the eight accumulator registers.
    lhs: register addressing the lhs data.
    rhs: register addressing the rhs data.
    count: loop counter register; decremented by 8 per iteration.
  """
  emitter.EmitNewline()
  emitter.EmitComment('1x8 lanes loop.')
  emitter.EmitNumericalLabel(1)
  emitter.EmitNewline()

  lhs_load = registers.DoubleRegister()
  rhs_load = [registers.DoubleRegister() for unused_i in range(4)]
  # Five temporaries for eight products: temp[0..2] are written twice.
  temp = [registers.QuadRegister() for unused_i in range(5)]

  # First rhs batch and the lhs row.
  emitter.EmitVLoadAE(4 * 8, 8, rhs_load, rhs, 256)
  emitter.EmitVLoadE(8, 8, lhs_load, lhs, 64)

  emitter.EmitVMull('u8', temp[0], lhs_load, rhs_load[0])
  emitter.EmitVMull('u8', temp[1], lhs_load, rhs_load[1])
  emitter.EmitVMull('u8', temp[2], lhs_load, rhs_load[2])
  emitter.EmitVMull('u8', temp[3], lhs_load, rhs_load[3])

  # Second rhs batch overwrites rhs_load; the first batch has been consumed.
  emitter.EmitVLoadAE(4 * 8, 8, rhs_load, rhs, 256)

  emitter.EmitVPadal('u16', aggregators[0], temp[0])
  emitter.EmitVPadal('u16', aggregators[1], temp[1])
  emitter.EmitVPadal('u16', aggregators[2], temp[2])
  emitter.EmitVPadal('u16', aggregators[3], temp[3])

  emitter.EmitPldOffset(rhs, emitter.ImmediateConstant(256))

  # temp[4] takes the first product of the second batch; temp[0..2] are
  # reused for the other three.
  emitter.EmitVMull('u8', temp[4], lhs_load, rhs_load[0])
  emitter.EmitVMull('u8', temp[0], lhs_load, rhs_load[1])
  emitter.EmitVMull('u8', temp[1], lhs_load, rhs_load[2])
  emitter.EmitVMull('u8', temp[2], lhs_load, rhs_load[3])

  emitter.EmitPldOffset(lhs, emitter.ImmediateConstant(32))

  emitter.EmitNewline()
  emitter.EmitComment('Subtract counter.')
  emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))

  emitter.EmitNewline()
  emitter.EmitVPadal('u16', aggregators[4], temp[4])
  emitter.EmitVPadal('u16', aggregators[5], temp[0])
  emitter.EmitVPadal('u16', aggregators[6], temp[1])
  emitter.EmitVPadal('u16', aggregators[7], temp[2])

  emitter.EmitNewline()
  emitter.EmitComment('Loop break.')
  # Branch back to numerical label 1 emitted at the top.
  emitter.EmitBgtBack(1)

  registers.FreeRegisters(temp + [lhs_load] + rhs_load)
    408 
    409 
    410 def _GenerateNxMLoadMultiplyAggregate(emitter, registers, kernel_m, kernel_n,
    411                                       aggregators, lhs, rhs, count):
    412   """Emit inner loop for N rows x M cols multiplication."""
    413   emitter.EmitNewline()
    414   emitter.EmitComment('General NxM lanes loop.')
    415   emitter.EmitNumericalLabel(1)
    416   emitter.EmitNewline()
    417   emitter.EmitComment('Subtract counter.')
    418   emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))
    419   emitter.EmitNewline()
    420 
    421   lhs_load = [registers.DoubleRegister() for unused_i in range(kernel_m)]
    422   rhs_load = [registers.DoubleRegister() for unused_i in range(kernel_n)]
    423 
    424   emitter.EmitVLoadAE(8 * kernel_m, 8, lhs_load, lhs, 64)
    425   emitter.EmitVLoadAE(8 * kernel_n, 8, rhs_load, rhs, 64)
    426 
    427   emitter.EmitPldOffset(lhs, emitter.ImmediateConstant(64))
    428   emitter.EmitPldOffset(rhs, emitter.ImmediateConstant(64))
    429 
    430   results = [
    431       registers.QuadRegister() for unused_i in range(kernel_m * kernel_n)
    432   ]
    433 
    434   for row in range(kernel_m):
    435     for col in range(kernel_n):
    436       index = row * kernel_n + col
    437       emitter.EmitVMull('u8', results[index], rhs_load[col], lhs_load[row])
    438 
    439   for i in range(kernel_m * kernel_n):
    440     emitter.EmitVPadal('u16', aggregators[i], results[i])
    441 
    442   emitter.EmitNewline()
    443   emitter.EmitComment('Loop break.')
    444   emitter.EmitBgtBack(1)
    445 
    446   registers.FreeRegisters(lhs_load + rhs_load + results)
    447 
    448 
    449 def _Generate1xNLoadMultiplyAggregate(emitter, registers, kernel_n, aggregators,
    450                                       lhs, rhs, count):
    451   """Emit inner loop for 1 row x M cols multiplication."""
    452   assert kernel_n in [5, 6, 7, 8]
    453   emitter.EmitNewline()
    454   emitter.EmitComment('General 1xM lanes loop.')
    455   emitter.EmitNumericalLabel(1)
    456   emitter.EmitNewline()
    457   emitter.EmitComment('Subtract counter.')
    458   emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))
    459   emitter.EmitNewline()
    460 
    461   leftover = kernel_n - 4
    462 
    463   rhs_load = [registers.DoubleRegister() for unused_i in range(4)]
    464   lhs_load = registers.DoubleRegister()
    465 
    466   emitter.EmitVLoadAE(8 * 4, 8, rhs_load, rhs, 64)
    467   emitter.EmitVLoadE(8, 8, lhs_load, lhs, 64)
    468 
    469   emitter.EmitPldOffset(lhs, emitter.ImmediateConstant(64))
    470 
    471   results = [registers.QuadRegister() for unused_i in range(4)]
    472 
    473   for i in range(4):
    474     emitter.EmitVMull('u8', results[i], rhs_load[i], lhs_load)
    475 
    476   emitter.EmitVLoadAE(8 * leftover, 8, rhs_load, rhs, 64)
    477   emitter.EmitPldOffset(rhs, emitter.ImmediateConstant(128))
    478 
    479   for i in range(4):
    480     emitter.EmitVPadal('u16', aggregators[i], results[i])
    481 
    482   for i in range(leftover):
    483     emitter.EmitVMull('u8', results[i], rhs_load[i], lhs_load)
    484 
    485   for i in range(leftover):
    486     emitter.EmitVPadal('u16', aggregators[i + 4], results[i])
    487 
    488   emitter.EmitNewline()
    489   emitter.EmitComment('Loop break.')
    490   emitter.EmitBgtBack(1)
    491 
    492   registers.FreeRegisters([lhs_load] + rhs_load + results)
    493 
    494 
    495 def _GenerateMultiplyKernel(emitter, registers, kernel_m, kernel_n, lhs, rhs):
    496   """Main muliply loop. Pick best implementation for given kernel shape."""
    497   count = registers.MapParameter('count', 'params.kernel.count')
    498 
    499   aggregators = _GenerateAndClearAggregators(emitter, registers,
    500                                              kernel_m * kernel_n)
    501   if kernel_m is 3 and kernel_n is 3:
    502     _Generate3x3LoadMultiplyAggregate(emitter, registers, aggregators, lhs, rhs,
    503                                       count)
    504   elif kernel_m is 2 and kernel_n is 4:
    505     _Generate2x4LoadMultiplyAggregate(emitter, registers, aggregators, lhs, rhs,
    506                                       count)
    507   elif kernel_m is 1 and kernel_n is 8:
    508     _Generate1x8LoadMultiplyAggregate(emitter, registers, aggregators, lhs, rhs,
    509                                       count)
    510   elif kernel_m is 1 and kernel_n > 4:
    511     _Generate1xNLoadMultiplyAggregate(emitter, registers, kernel_n, aggregators,
    512                                       lhs, rhs, count)
    513   else:
    514     _GenerateNxMLoadMultiplyAggregate(emitter, registers, kernel_m, kernel_n,
    515                                       aggregators, lhs, rhs, count)
    516   return aggregators
    517 
    518 
    519 def _ReduceAggregators(emitter, aggregators):
    520   reduced_count = (len(aggregators) + 3) / 4
    521   reduced = aggregators[:reduced_count]
    522   emitter.EmitVSumReduce('u32', len(aggregators), 4, reduced, aggregators)
    523   return reduced
    524 
    525 
    526 def _GenerateAggregatorReduce(emitter, aggregators, kernel_m, kernel_n):
    527   emitter.EmitNewline()
    528   emitter.EmitComment('Reduce aggregators.')
    529   row_temps = []
    530   for i in range(kernel_m):
    531     row_temps.append(
    532         _ReduceAggregators(emitter, aggregators[i * kernel_n:(i + 1) *
    533                                                 kernel_n]))
    534   return row_temps
    535 
    536 
    537 class QuantizedMulKernel(common.MulKernelGenerator):
    538   """."""
    539 
    540   def __init__(self, cc_emitter, kernel_name, output_stream_name, asm_emitter,
    541                fused_transformation, output_strategy):
    542     common.MulKernelGenerator.__init__(self, cc_emitter, kernel_name,
    543                                        output_stream_name)
    544     self.asm_emitter = asm_emitter
    545     self.fused_transformation = fused_transformation
    546     self.output_strategy = output_strategy
    547 
    548   def EmitMultiply(self, in_type, out_type, kernel_m, kernel_n, pack_size):
    549     assert in_type is 'uint8_t'
    550     assert pack_size is 8
    551     assert kernel_m * kernel_n <= 9
    552 
    553     registers = self.asm_emitter.CreateRegisters()
    554 
    555     self.asm_emitter.PushIndent(self.emitter.indent)
    556     self.asm_emitter.EmitAsmBegin()
    557 
    558     lhs = registers.MapOutputParameter('lhs')
    559     rhs = registers.MapOutputParameter('rhs')
    560     self.asm_emitter.EmitPld(lhs)
    561     self.asm_emitter.EmitPld(rhs)
    562 
    563     aggregators = _GenerateMultiplyKernel(self.asm_emitter, registers, kernel_m,
    564                                           kernel_n, lhs, rhs)
    565 
    566     self.fused_transformation.Prepare(self.asm_emitter, registers, kernel_m,
    567                                       kernel_n, lhs, rhs)
    568 
    569     self.output_strategy.Prepare(self.asm_emitter, registers, kernel_m,
    570                                  kernel_n, self.fused_transformation.Type())
    571 
    572     reduced = _GenerateAggregatorReduce(self.asm_emitter, aggregators, kernel_m,
    573                                         kernel_n)
    574 
    575     transformed = self.fused_transformation.Transform(self.asm_emitter,
    576                                                       registers, reduced,
    577                                                       kernel_m, kernel_n)
    578 
    579     self.output_strategy.Output(self.asm_emitter, registers, transformed,
    580                                 self.fused_transformation.Type(), kernel_m,
    581                                 kernel_n)
    582 
    583     self.asm_emitter.EmitAsmEnd(registers)
    584     self.asm_emitter.PopIndent(len(self.emitter.indent))
    585 
    586 
    587 class QuantizedMulStaticRowMajor(QuantizedMulKernel):
    588   """."""
    589 
    590   def __init__(self, cc_emitter, asm_emitter):
    591     QuantizedMulKernel.__init__(self, cc_emitter, 'QuantizedStaticPreprocessed',
    592                                 'RowMajor', asm_emitter,
    593                                 _StaticQuantizationUInt8Transformation(),
    594                                 _RowMajorOutput())
    595 
    596 
    597 class QuantizedMulStaticAsInt32RowMajor(QuantizedMulKernel):
    598   """."""
    599 
    600   def __init__(self, cc_emitter, asm_emitter):
    601     QuantizedMulKernel.__init__(self, cc_emitter,
    602                                 'QuantizedStaticPreprocessedAsInt32',
    603                                 'RowMajor', asm_emitter,
    604                                 _StaticQuantizationInt32Transformation(),
    605                                 _RowMajorOutput())
    606 
    607 
    608 class QuantizedMulStaticAsFloatRowMajor(QuantizedMulKernel):
    609   """."""
    610 
    611   def __init__(self, cc_emitter, asm_emitter):
    612     QuantizedMulKernel.__init__(self, cc_emitter,
    613                                 'QuantizedStaticPreprocessedAsFloat',
    614                                 'RowMajor', asm_emitter,
    615                                 _StaticQuantizationFloatTransformation(),
    616                                 _RowMajorOutput())
    617 
    618 
    619 def GenerateKernels(cc_emitter, asm_emitter, shapes):
    620   """Generate the quantized multiplication kernels for uint8 operands."""
    621   quantized_mul_static_row_major = QuantizedMulStaticRowMajor(cc_emitter,
    622                                                               asm_emitter)
    623   quantized_mul_static_int32_row_major = QuantizedMulStaticAsInt32RowMajor(
    624       cc_emitter, asm_emitter)
    625 
    626   quantized_mul_static_float_row_major = QuantizedMulStaticAsFloatRowMajor(
    627       cc_emitter, asm_emitter)
    628 
    629   for shape in shapes:
    630     quantized_mul_static_row_major.SpecializeMulKernel('uint8_t', 'uint8_t',
    631                                                        shape[0], shape[1], 8)
    632   for shape in shapes:
    633     quantized_mul_static_int32_row_major.SpecializeMulKernel('uint8_t',
    634                                                              'int32_t',
    635                                                              shape[0], shape[1],
    636                                                              8)
    637 
    638   for shape in shapes:
    639     quantized_mul_static_float_row_major.SpecializeMulKernel('uint8_t', 'float',
    640                                                              shape[0], shape[1],
    641                                                              8)
    642