Home | History | Annotate | Download | only in generators
      1 """Generates the specialized gemv functions."""
      2 
      3 import mul_1x8_Mx8_neon
      4 import mul_Nx8_Mx8_neon
      5 import qnt_Nx8_neon
      6 import zip_Nx8_neon
      7 
      8 _QUANTIZED_8BIT = 'quantized_8bit'
      9 _FULL_32BIT = 'full_32bit'
     10 _FULL_FLOAT = 'full_float'
     11 
     12 
class Error(Exception):
  """Module level error; base class for the exceptions defined here."""
     15 
     16 
class ConfigurationError(Error):
  """Runtime configuration error, e.g. an unsupported output type."""
     19 
     20 
     21 def GenerateCommonTempsCountersAndConsts(emitter):
     22   """Generates common gemv boilerplate variables."""
     23   emitter.EmitDeclare('const std::int32_t', 'col_chunks', 'n / 8')
     24   emitter.EmitDeclare('const std::int32_t', 'padded_k', '((k + 7) / 8) * 8')
     25   emitter.EmitDeclare('const std::int32_t', 'chunk_size', 'k * 4')
     26   emitter.EmitDeclare('const std::int32_t', 'zipped_chunk_size',
     27                       '(padded_k + 16) * 4')
     28   emitter.EmitDeclare('const std::uint8_t*', 'rhs_chunk', 'rhs')
     29   emitter.EmitDeclare('std::uint8_t*', 'zipped_lhs', 'scratch')
     30   emitter.EmitDeclare('std::int32_t*', 'zipped_lhs_offsets',
     31                       'reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k)')
     32   emitter.EmitDeclare('std::uint8_t*', 'zipped_rhs_1',
     33                       'scratch + padded_k + 16')
     34   emitter.EmitDeclare('std::uint8_t*', 'zipped_rhs_2',
     35                       'zipped_rhs_1 + zipped_chunk_size')
     36   emitter.EmitNewline()
     37 
     38 
     39 def GenerateQuantized8BitTempsCountersAndConsts(emitter):
     40   """Generates all the boilerplate variables for the q8 gemm function."""
     41   GenerateCommonTempsCountersAndConsts(emitter)
     42   emitter.EmitDeclare('const std::int32_t', 'const_offset',
     43                       'lhs_offset * rhs_offset * k + result_offset')
     44   emitter.EmitDeclare('const std::int32_t', 'rounding_offset',
     45                       '(1 << (shift - 1))')
     46   emitter.EmitDeclare('std::int32_t*', 'temp_result',
     47                       'reinterpret_cast<std::int32_t*>('
     48                       'zipped_rhs_2 + zipped_chunk_size)')
     49   emitter.EmitDeclare('std::int32_t*', 'mul_result_chunk', 'temp_result')
     50   emitter.EmitNewline()
     51 
     52 
     53 def GenerateFullTempsCountersAndConsts(emitter, result_type):
     54   """Generates all the boilerplate variables for the int32 and float gemms."""
     55   GenerateCommonTempsCountersAndConsts(emitter)
     56   emitter.EmitDeclare('const std::int32_t', 'const_offset',
     57                       'lhs_offset * rhs_offset * k')
     58   emitter.EmitDeclare(result_type, 'mul_result_chunk', 'result')
     59   emitter.EmitNewline()
     60 
     61 
     62 def GenerateZipVector(emitter, aligned, leftovers):
     63   emitter.EmitCall(
     64       zip_Nx8_neon.BuildName(1, leftovers, aligned),
     65       ['lhs', 'k', 'k', 'zipped_lhs', 'rhs_offset', 0])
     66 
     67 
     68 def GetMul2Params(result_type):
     69   params = ['zipped_lhs', 'zipped_rhs_1', 'zipped_rhs_2', 'padded_k',
     70             'mul_result_chunk']
     71   if result_type is 'float':
     72     params.append('result_scale')
     73   return params
     74 
     75 
     76 def GetMulParams(result_type):
     77   params = ['zipped_lhs', 'zipped_rhs_1', 'padded_k', 'mul_result_chunk', 0]
     78   if result_type is 'float':
     79     params.append('result_scale')
     80   return params
     81 
     82 
     83 def GenerateMulCols(emitter, result_type, lhs_add, rhs_add, aligned, cols,
     84                     leftovers):
     85   """Emits code responsible for multiplication of one horizontal lhs strip."""
     86   emitter.EmitOpenBracket('for (int i = 0; i < col_chunks; ++i)')
     87   emitter.EmitCall(
     88       zip_Nx8_neon.BuildName(4, leftovers, aligned),
     89       ['rhs_chunk', 'k', 'k', 'zipped_rhs_1', 'lhs_offset', 'const_offset'])
     90   emitter.EmitAssignIncrement('rhs_chunk', 'chunk_size')
     91 
     92   emitter.EmitCall(
     93       zip_Nx8_neon.BuildName(4, leftovers, aligned),
     94       ['rhs_chunk', 'k', 'k', 'zipped_rhs_2', 'lhs_offset', 'const_offset'])
     95   emitter.EmitAssignIncrement('rhs_chunk', 'chunk_size')
     96 
     97   emitter.EmitCall(
     98       mul_1x8_Mx8_neon.BuildName(result_type, lhs_add, rhs_add, 8),
     99       GetMul2Params(result_type))
    100 
    101   emitter.EmitAssignIncrement('mul_result_chunk', 8)
    102   emitter.EmitCloseBracket()
    103 
    104   if cols > 4:
    105     emitter.EmitCall(
    106         zip_Nx8_neon.BuildName(4, leftovers, aligned),
    107         ['rhs_chunk', 'k', 'k', 'zipped_rhs_1', 'lhs_offset', 'const_offset'])
    108     emitter.EmitAssignIncrement('rhs_chunk', 'chunk_size')
    109 
    110     emitter.EmitCall(
    111         zip_Nx8_neon.BuildName(cols - 4, leftovers, aligned),
    112         ['rhs_chunk', 'k', 'k', 'zipped_rhs_2', 'lhs_offset', 'const_offset'])
    113 
    114     emitter.EmitCall(
    115         mul_1x8_Mx8_neon.BuildName(result_type, lhs_add, rhs_add, cols),
    116         GetMul2Params(result_type))
    117   elif cols > 0:
    118     emitter.EmitCall(
    119         zip_Nx8_neon.BuildName(cols, leftovers, aligned),
    120         ['rhs_chunk', 'k', 'k', 'zipped_rhs_1', 'lhs_offset', 'const_offset'])
    121 
    122     emitter.EmitCall(
    123         mul_Nx8_Mx8_neon.BuildName(result_type, lhs_add, rhs_add, 1, cols),
    124         GetMulParams(result_type))
    125 
    126 
    127 def GenerateQuantized8BitMul(emitter, aligned, cols, leftovers):
    128   """Emits code for all lhs strips & leftover rows. Quantize after mul code."""
    129   GenerateMulCols(emitter, 'int32', False, True, aligned, cols, leftovers)
    130   emitter.EmitCall(
    131       qnt_Nx8_neon.BuildName(1, cols, aligned),
    132       ['temp_result', 'n', 0, 'zipped_lhs_offsets', 'result', 0,
    133        'multiplicative_offset', 'rounding_offset', '-shift'])
    134 
    135 
    136 def GenerateFullMul(emitter, result_type, aligned, cols, leftovers):
    137   GenerateMulCols(emitter, result_type, True, True, aligned, cols, leftovers)
    138 
    139 
    140 def BuildName(output_type, aligned, cols, leftover):
    141   name = BuildMainGemvName(output_type) + '_%d_%d' % (cols, leftover)
    142   if aligned:
    143     name += '_aligned'
    144   return name
    145 
    146 
    147 def GetCommonGemvParameters():
    148   return [['std::uint8_t*', 'scratch'], ['const std::uint8_t*', 'lhs'],
    149           ['const std::uint8_t*', 'rhs'], ['std::int32_t', 'n'],
    150           ['std::int32_t', 'k'], ['std::int32_t', 'lhs_offset'],
    151           ['std::int32_t', 'rhs_offset']]
    152 
    153 
    154 def GetGemvParameters(output_type):
    155   """Prepares a (type, parameter) array for the gemm functions."""
    156   params = GetCommonGemvParameters()
    157   if output_type is _QUANTIZED_8BIT:
    158     params += [['std::int32_t', 'result_offset'],
    159                ['std::int32_t', 'multiplicative_offset'],
    160                ['std::int32_t', 'shift'], ['std::uint8_t*', 'result']]
    161   elif output_type is _FULL_32BIT:
    162     params += [['std::int32_t*', 'result']]
    163   elif output_type is _FULL_FLOAT:
    164     params += [['float', 'result_scale'], ['float*', 'result']]
    165   else:
    166     raise ConfigurationError('Unsupported output type: %s' % output_type)
    167   return params
    168 
    169 
    170 def GenerateGemv(emitter, output_type, aligned, cols, leftovers):
    171   """Build one gemm function for given col, and depth leftovers."""
    172   emitter.EmitFunctionBeginA(
    173       BuildName(output_type, aligned, cols, leftovers),
    174       GetGemvParameters(output_type), 'void')
    175 
    176   emitter.EmitAssert('n %% 8 == %d' % cols)
    177   emitter.EmitAssert('k %% 8 == %d' % leftovers)
    178 
    179   if output_type is _QUANTIZED_8BIT:
    180     GenerateQuantized8BitTempsCountersAndConsts(emitter)
    181     GenerateZipVector(emitter, aligned, leftovers)
    182     GenerateQuantized8BitMul(emitter, aligned, cols, leftovers)
    183   elif output_type is _FULL_32BIT:
    184     GenerateFullTempsCountersAndConsts(emitter, 'std::int32_t*')
    185     GenerateZipVector(emitter, aligned, leftovers)
    186     GenerateFullMul(emitter, 'int32', aligned, cols, leftovers)
    187   elif output_type is _FULL_FLOAT:
    188     GenerateFullTempsCountersAndConsts(emitter, 'float*')
    189     GenerateZipVector(emitter, aligned, leftovers)
    190     GenerateFullMul(emitter, 'float', aligned, cols, leftovers)
    191   else:
    192     raise ConfigurationError('Unknown output type: %s' % output_type)
    193 
    194   emitter.EmitFunctionEnd()
    195 
    196 
    197 def GenerateGemvCall(emitter, output_type, aligned, m_mod, leftovers):
    198   emitter.EmitCall(
    199       emitter.Scope('internal',
    200                     BuildName(output_type, aligned, m_mod, leftovers)),
    201       [p for (unused_t, p) in GetGemvParameters(output_type)])
    202 
    203 
    204 def GenerateGemvSwitch2(emitter, output_type, aligned, n_mod):
    205   """Second level of main switch, choose optimized version on depth leftover."""
    206   emitter.EmitSwitch('k % 8')
    207 
    208   for leftovers in range(0, 8):
    209     emitter.EmitCase(leftovers)
    210     emitter.PushIndent()
    211     GenerateGemvCall(emitter, output_type, aligned, n_mod, leftovers)
    212     emitter.EmitBreak()
    213     emitter.PopIndent()
    214 
    215   emitter.EmitSwitchEnd()
    216 
    217 
    218 def GenerateGemvSwitch1(emitter, output_type, aligned):
    219   """First level of main switch, choose optimized version on cols leftover."""
    220   emitter.EmitSwitch('n % 8')
    221 
    222   for n_mod in range(0, 8):
    223     emitter.EmitCase(n_mod)
    224     emitter.PushIndent()
    225     GenerateGemvSwitch2(emitter, output_type, aligned, n_mod)
    226     emitter.EmitBreak()
    227     emitter.PopIndent()
    228 
    229   emitter.EmitSwitchEnd()
    230 
    231 
    232 def BuildMainGemvName(output_type):
    233   if output_type is _QUANTIZED_8BIT:
    234     return 'gemv_q8'
    235   elif output_type is _FULL_32BIT:
    236     return 'gemv_i32'
    237   elif output_type is _FULL_FLOAT:
    238     return 'gemv_f'
    239   else:
    240     raise ConfigurationError('Unsupported output type: %s' % output_type)
    241 
    242 
    243 def GenerateMainGemvFunction(emitter, output_type):
    244   """Emit high level gemv function that switches between optimized versions."""
    245   emitter.EmitFunctionBeginA(
    246       BuildMainGemvName(output_type), GetGemvParameters(output_type), 'void')
    247 
    248   emitter.EmitDeclare('const bool', 'lhs_aligned',
    249                       '((reinterpret_cast<std::uintptr_t>(lhs) % 8) == 0)')
    250   emitter.EmitDeclare('const bool', 'rhs_aligned',
    251                       '((reinterpret_cast<std::uintptr_t>(rhs) % 8) == 0)')
    252   emitter.EmitDeclare('const bool', 'k_aligned', '((k % 8) == 0)')
    253 
    254   if output_type is _QUANTIZED_8BIT:
    255     emitter.EmitDeclare('const bool', 'result_aligned',
    256                         '((reinterpret_cast<std::uintptr_t>(result) % 8) == 0)')
    257     emitter.EmitDeclare('const bool', 'aligned',
    258                         'lhs_aligned && rhs_aligned && result_aligned '
    259                         '&& k_aligned')
    260   else:
    261     emitter.EmitDeclare('const bool', 'aligned',
    262                         'lhs_aligned && rhs_aligned && k_aligned')
    263 
    264   emitter.EmitIf('aligned')
    265   GenerateGemvSwitch1(emitter, output_type, True)
    266   emitter.EmitElse()
    267   GenerateGemvSwitch1(emitter, output_type, False)
    268   emitter.EmitEndif()
    269   emitter.EmitFunctionEnd()
    270 
    271 
    272 def GenerateInternalFunctions(emitter):
    273   """Generate all the functions hidden in the internal namespace."""
    274   for output_type in [_QUANTIZED_8BIT, _FULL_32BIT, _FULL_FLOAT]:
    275     for aligned in [True, False]:
    276       for cols in range(0, 8):
    277         for leftover in range(0, 8):
    278           GenerateGemv(emitter, output_type, aligned, cols, leftover)
    279           emitter.EmitNewline()
    280 
    281 
    282 def GeneratePublicFunctions(emitter):
    283   for output_type in [_QUANTIZED_8BIT, _FULL_32BIT, _FULL_FLOAT]:
    284     GenerateMainGemvFunction(emitter, output_type)
    285     emitter.EmitNewline()
    286