Home | History | Annotate | Download | only in generators
      1 """Generates the specialized gemm functions."""
      2 
      3 import mul_Nx8_Mx8_neon
      4 import qnt_Nx8_neon
      5 import zip_Nx8_neon
      6 
# Identifiers of the output types supported by the gemm generators below.
_QUANTIZED_8BIT = 'quantized_8bit'  # Result is emitted as std::uint8_t*.
_FULL_32BIT = 'full_32bit'  # Result is emitted as std::int32_t*.
_FULL_FLOAT = 'full_float'  # Result is float*, with a result_scale parameter.
     10 
     11 
     12 class Error(Exception):
     13   """Module level error."""
     14 
     15 
     16 class ConfigurationError(Error):
     17   """Runtime configuration error."""
     18 
     19 
     20 def GenerateCommonTempsCountersAndConsts(emitter, rows):
     21   emitter.EmitDeclare('const std::int32_t', 'row_chunks', 'm / 3')
     22   emitter.EmitDeclare('const std::int32_t', 'col_chunks', 'n / 3')
     23   emitter.EmitDeclare('const std::int32_t', 'padded_k', '((k + 7) / 8) * 8')
     24   emitter.EmitDeclare('const std::int32_t', 'chunk_size', 'k * 3')
     25   emitter.EmitDeclare('const std::int32_t', 'zipped_chunk_size',
     26                       '(padded_k + 16) * 3')
     27   emitter.EmitDeclare('const std::int32_t', 'zipped_rhs_size',
     28                       '(padded_k + 16) * n')
     29   emitter.EmitDeclare('const std::uint8_t*', 'lhs_chunk', 'lhs')
     30   emitter.EmitDeclare('const std::uint8_t*', 'rhs_chunk', 'rhs')
     31   emitter.EmitDeclare('std::uint8_t*', 'zipped_lhs', 'scratch')
     32   emitter.EmitDeclare(
     33       'std::int32_t*', 'zipped_lhs_3_offsets',
     34       'reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3)')
     35   if rows is not 0:
     36     emitter.EmitDeclare(
     37         'std::int32_t*', 'zipped_lhs_%d_offsets' % rows,
     38         'reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * %d)' % rows)
     39   emitter.EmitDeclare('std::uint8_t*', 'zipped_rhs',
     40                       'scratch + zipped_chunk_size')
     41   emitter.EmitDeclare('std::uint8_t*', 'zipped_rhs_chunk', 'zipped_rhs')
     42   emitter.EmitDeclare('const std::int32_t', 'result_chunk_stride',
     43                       'result_stride * 3')
     44   emitter.EmitNewline()
     45 
     46 
     47 def GenerateQuantized8BitTempsCountersAndConsts(emitter, rows):
     48   """Generates all the boilerplate variables for the q8 gemm function."""
     49   GenerateCommonTempsCountersAndConsts(emitter, rows)
     50   emitter.EmitDeclare('const std::int32_t', 'const_offset',
     51                       'lhs_offset * rhs_offset * k + result_offset')
     52   emitter.EmitDeclare('const std::int32_t', 'rounding_offset',
     53                       '(1 << (shift - 1))')
     54   emitter.EmitDeclare('std::int32_t*', 'temp_result',
     55                       'reinterpret_cast<std::int32_t*>('
     56                       'scratch + zipped_chunk_size + zipped_rhs_size)')
     57   emitter.EmitDeclare('std::uint8_t*', 'result_chunk', 'result')
     58   emitter.EmitDeclare('std::int32_t*', 'mul_result_chunk', 'temp_result')
     59   emitter.EmitDeclare('const std::int32_t', 'mul_result_chunk_stride_bytes',
     60                       '((n * 4 + 7) / 8) * 8')
     61   emitter.EmitNewline()
     62 
     63 
     64 def GenerateFullTempsCountersAndConsts(emitter, result_type, rows):
     65   """Generates all the boilerplate variables for the int32 and float gemms."""
     66   GenerateCommonTempsCountersAndConsts(emitter, rows)
     67   emitter.EmitDeclare('const std::int32_t', 'const_offset',
     68                       'lhs_offset * rhs_offset * k')
     69   emitter.EmitDeclare(result_type, 'result_chunk', 'result')
     70   emitter.EmitDeclare(result_type, 'mul_result_chunk', 'result')
     71   emitter.EmitDeclare('const std::int32_t', 'mul_result_chunk_stride_bytes',
     72                       'result_stride * 4')
     73   emitter.EmitNewline()
     74 
     75 
     76 def ZipName(rows, leftovers, aligned):
     77   return zip_Nx8_neon.BuildName(rows, leftovers, aligned)
     78 
     79 
     80 def GenerateZipRhs(emitter, aligned, cols, leftovers):
     81   """Emits the code responsible for zipping the rhs matrix."""
     82   emitter.EmitOpenBracket('for (int i = 0; i < col_chunks; ++i)')
     83   emitter.EmitCall(
     84       ZipName(3, leftovers, aligned),
     85       ['rhs_chunk', 'k', 'k', 'zipped_rhs_chunk', 'lhs_offset', 0])
     86   emitter.EmitAssignIncrement('rhs_chunk', 'chunk_size')
     87   emitter.EmitAssignIncrement('zipped_rhs_chunk', 'zipped_chunk_size')
     88   emitter.EmitCloseBracket()
     89 
     90   if cols is not 0:
     91     emitter.EmitCall(
     92         ZipName(cols, leftovers, aligned),
     93         ['rhs_chunk', 'k', 'k', 'zipped_rhs_chunk', 'lhs_offset', 0])
     94   emitter.EmitNewline()
     95 
     96 
     97 def MulName(result_type, lhs_add, rhs_add, rows, cols):
     98   return mul_Nx8_Mx8_neon.BuildName(result_type, lhs_add, rhs_add, rows, cols)
     99 
    100 
    101 def GetMulParams(result_type):
    102   params = ['zipped_lhs', 'zipped_rhs_chunk', 'padded_k', 'mul_result_chunk',
    103             'mul_result_chunk_stride_bytes']
    104   if result_type is 'float':
    105     params.append('result_scale')
    106   return params
    107 
    108 
    109 def GenerateMulRows(emitter, result, result_type, lhs_add, rhs_add, aligned,
    110                     rows, cols, leftovers):
    111   """Emits code responsible for multiplication of one horizontal lhs strip."""
    112   emitter.EmitCall(
    113       ZipName(rows, leftovers, aligned),
    114       ['lhs_chunk', 'k', 'k', 'zipped_lhs', 'rhs_offset', 'const_offset'])
    115   emitter.EmitAssign('zipped_rhs_chunk', 'zipped_rhs')
    116   emitter.EmitAssign('mul_result_chunk', result)
    117 
    118   emitter.EmitOpenBracket('for (int j = 0; j < col_chunks; ++j)')
    119 
    120   emitter.EmitCall(
    121       MulName(result_type, lhs_add, rhs_add, rows, 3),
    122       GetMulParams(result_type))
    123   emitter.EmitAssignIncrement('zipped_rhs_chunk', 'zipped_chunk_size')
    124   emitter.EmitAssignIncrement('mul_result_chunk', 3)
    125 
    126   emitter.EmitCloseBracket()
    127 
    128   if cols is not 0:
    129     emitter.EmitCall(
    130         MulName(result_type, lhs_add, rhs_add, rows, cols),
    131         GetMulParams(result_type))
    132 
    133 
    134 def GenerateQuantized8BitMul(emitter, aligned, rows, cols, leftovers):
    135   """Emits code for all lhs strips & leftover rows. Quantize after mul code."""
    136   emitter.EmitOpenBracket('for (int i = 0; i < row_chunks; ++i)')
    137   GenerateMulRows(emitter, 'temp_result', 'int32', False, True, aligned, 3,
    138                   cols, leftovers)
    139   emitter.EmitCall(
    140       qnt_Nx8_neon.BuildMultiQuantizeName(aligned, 3),
    141       ['temp_result', 'n', 'mul_result_chunk_stride_bytes',
    142        'zipped_lhs_3_offsets', 'result_chunk', 'result_stride',
    143        'multiplicative_offset', 'rounding_offset', '-shift'])
    144   emitter.EmitAssignIncrement('lhs_chunk', 'chunk_size')
    145   emitter.EmitAssignIncrement('result_chunk', 'result_chunk_stride')
    146   emitter.EmitCloseBracket()
    147   emitter.EmitNewline()
    148 
    149   if rows is not 0:
    150     GenerateMulRows(emitter, 'temp_result', 'int32', False, True, aligned, rows,
    151                     cols, leftovers)
    152     emitter.EmitCall(
    153         qnt_Nx8_neon.BuildMultiQuantizeName(aligned, rows),
    154         ['temp_result', 'n', 'mul_result_chunk_stride_bytes',
    155          'zipped_lhs_%d_offsets' % rows, 'result_chunk', 'result_stride',
    156          'multiplicative_offset', 'rounding_offset', '-shift'])
    157 
    158 
    159 def GenerateFullMul(emitter, result_type, aligned, rows, cols, leftovers):
    160   emitter.EmitOpenBracket('for (int i = 0; i < row_chunks; ++i)')
    161   GenerateMulRows(emitter, 'result_chunk', result_type, True, True, aligned, 3,
    162                   cols, leftovers)
    163   emitter.EmitAssignIncrement('lhs_chunk', 'chunk_size')
    164   emitter.EmitAssignIncrement('result_chunk', 'result_chunk_stride')
    165   emitter.EmitCloseBracket()
    166   emitter.EmitNewline()
    167 
    168   if rows is not 0:
    169     GenerateMulRows(emitter, 'result_chunk', result_type, True, True, aligned,
    170                     rows, cols, leftovers)
    171 
    172 
    173 def BuildName(output_type, aligned, rows, cols, leftover):
    174   name = BuildMainGemmName(output_type) + '_%d_%d_%d' % (rows, cols, leftover)
    175   if aligned:
    176     name += '_aligned'
    177   return name
    178 
    179 
    180 def GetCommonGemmParameters():
    181   return [['std::uint8_t*', 'scratch'], ['const std::uint8_t*', 'lhs'],
    182           ['const std::uint8_t*', 'rhs'], ['std::int32_t', 'm'],
    183           ['std::int32_t', 'n'], ['std::int32_t', 'k'],
    184           ['std::int32_t', 'lhs_offset'], ['std::int32_t', 'rhs_offset']]
    185 
    186 
    187 def GetGemmParameters(output_type, extra_params=None):
    188   """Prepares a (type, parameter) array for the gemm functions."""
    189   if extra_params is None:
    190     extra_params = []
    191   params = GetCommonGemmParameters()
    192   if output_type is _QUANTIZED_8BIT:
    193     params += [['std::int32_t', 'result_offset'],
    194                ['std::int32_t', 'multiplicative_offset'],
    195                ['std::int32_t', 'shift'], ['std::uint8_t*', 'result']]
    196   elif output_type is _FULL_32BIT:
    197     params += [['std::int32_t*', 'result']]
    198   elif output_type is _FULL_FLOAT:
    199     params += [['float', 'result_scale'], ['float*', 'result']]
    200   else:
    201     raise ConfigurationError('Unsupported output type: %s' % output_type)
    202   return params + extra_params
    203 
    204 
    205 def GetStridedGemmParameters(output_type):
    206   return GetGemmParameters(output_type, [['std::int32_t', 'result_stride']])
    207 
    208 
    209 def GenerateGemm(emitter, output_type, aligned, rows, cols, leftovers):
    210   """Build one gemm function for given row, col, and depth leftovers."""
    211   emitter.EmitFunctionBeginA(
    212       BuildName(output_type, aligned, rows, cols, leftovers),
    213       GetStridedGemmParameters(output_type), 'void')
    214 
    215   emitter.EmitAssert('m %% 3 == %d' % rows)
    216   emitter.EmitAssert('n %% 3 == %d' % cols)
    217   emitter.EmitAssert('k %% 8 == %d' % leftovers)
    218 
    219   if output_type is _QUANTIZED_8BIT:
    220     GenerateQuantized8BitTempsCountersAndConsts(emitter, rows)
    221     GenerateZipRhs(emitter, aligned, cols, leftovers)
    222     GenerateQuantized8BitMul(emitter, aligned, rows, cols, leftovers)
    223   elif output_type is _FULL_32BIT:
    224     GenerateFullTempsCountersAndConsts(emitter, 'std::int32_t*', rows)
    225     GenerateZipRhs(emitter, aligned, cols, leftovers)
    226     GenerateFullMul(emitter, 'int32', aligned, rows, cols, leftovers)
    227   elif output_type is _FULL_FLOAT:
    228     GenerateFullTempsCountersAndConsts(emitter, 'float*', rows)
    229     GenerateZipRhs(emitter, aligned, cols, leftovers)
    230     GenerateFullMul(emitter, 'float', aligned, rows, cols, leftovers)
    231   else:
    232     raise ConfigurationError('Unknown output type: %s' % output_type)
    233 
    234   emitter.EmitFunctionEnd()
    235 
    236 
    237 def GenerateGemmCall(emitter, output_type, aligned, m_mod, n_mod, leftovers):
    238   emitter.EmitCall(
    239       emitter.Scope('internal',
    240                     BuildName(output_type, aligned, m_mod, n_mod, leftovers)),
    241       [p for (unused_t, p) in GetStridedGemmParameters(output_type)])
    242 
    243 
    244 def GenerateGemmSwitch3(emitter, output_type, aligned, m_mod, n_mod):
    245   """Third level of main switch, choose optimized version on depth leftover."""
    246   emitter.EmitSwitch('k % 8')
    247 
    248   for leftovers in range(0, 8):
    249     emitter.EmitCase(leftovers)
    250     emitter.PushIndent()
    251     GenerateGemmCall(emitter, output_type, aligned, m_mod, n_mod, leftovers)
    252     emitter.EmitBreak()
    253     emitter.PopIndent()
    254 
    255   emitter.EmitSwitchEnd()
    256 
    257 
    258 def GenerateGemmSwitch2(emitter, output_type, aligned, m_mod):
    259   """Second level of main switch, choose optimized version on cols leftover."""
    260   emitter.EmitSwitch('n % 3')
    261 
    262   for n_mod in range(0, 3):
    263     emitter.EmitCase(n_mod)
    264     emitter.PushIndent()
    265     GenerateGemmSwitch3(emitter, output_type, aligned, m_mod, n_mod)
    266     emitter.EmitBreak()
    267     emitter.PopIndent()
    268 
    269   emitter.EmitSwitchEnd()
    270 
    271 
    272 def GenerateGemmSwitch1(emitter, output_type, aligned):
    273   """First level of main switch, choose optimized version on rows leftover."""
    274   emitter.EmitSwitch('m % 3')
    275 
    276   for m_mod in range(0, 3):
    277     emitter.EmitCase(m_mod)
    278     emitter.PushIndent()
    279     GenerateGemmSwitch2(emitter, output_type, aligned, m_mod)
    280     emitter.EmitBreak()
    281     emitter.PopIndent()
    282 
    283   emitter.EmitSwitchEnd()
    284 
    285 
    286 def BuildMainGemmName(output_type):
    287   if output_type is _QUANTIZED_8BIT:
    288     return 'gemm_q8'
    289   elif output_type is _FULL_32BIT:
    290     return 'gemm_i32'
    291   elif output_type is _FULL_FLOAT:
    292     return 'gemm_f'
    293   else:
    294     raise ConfigurationError('Unsupported output type: %s' % output_type)
    295 
    296 
    297 def BuildStridedMainGemmName(output_type):
    298   return BuildMainGemmName(output_type) + '_strided'
    299 
    300 
    301 def GenerateMainGemmFunction(emitter, output_type):
    302   """Emit high level gemm function that switches between optimized versions."""
    303   emitter.EmitFunctionBeginA(
    304       BuildStridedMainGemmName(output_type),
    305       GetStridedGemmParameters(output_type), 'void')
    306 
    307   emitter.EmitDeclare('const bool', 'lhs_aligned',
    308                       '((reinterpret_cast<std::uintptr_t>(lhs) % 8) == 0)')
    309   emitter.EmitDeclare('const bool', 'rhs_aligned',
    310                       '((reinterpret_cast<std::uintptr_t>(rhs) % 8) == 0)')
    311   emitter.EmitDeclare('const bool', 'k_aligned', '((k % 8) == 0)')
    312 
    313   if output_type is _QUANTIZED_8BIT:
    314     emitter.EmitDeclare('const bool', 'result_aligned',
    315                         '((reinterpret_cast<std::uintptr_t>(result) % 8) == 0)')
    316     emitter.EmitDeclare('const bool', 'result_stride_aligned',
    317                         '((result_stride % 8) == 0)')
    318     emitter.EmitDeclare('const bool', 'aligned',
    319                         'lhs_aligned && rhs_aligned && result_aligned '
    320                         '&& k_aligned && result_stride_aligned')
    321   else:
    322     emitter.EmitDeclare('const bool', 'aligned',
    323                         'lhs_aligned && rhs_aligned && k_aligned')
    324 
    325   emitter.EmitIf('aligned')
    326   GenerateGemmSwitch1(emitter, output_type, True)
    327   emitter.EmitElse()
    328   GenerateGemmSwitch1(emitter, output_type, False)
    329   emitter.EmitEndif()
    330   emitter.EmitFunctionEnd()
    331 
    332 
    333 def GenerateWrapperGemmFunction(emitter, output_type):
    334   emitter.EmitFunctionBeginA(
    335       BuildMainGemmName(output_type), GetGemmParameters(output_type), 'void')
    336   emitter.EmitCall(
    337       BuildStridedMainGemmName(output_type),
    338       [p for (unused_t, p) in GetGemmParameters(output_type)] + ['n'])
    339   emitter.EmitFunctionEnd()
    340 
    341 
    342 def GenerateInternalFunctions(emitter):
    343   """Generate all the functions hidden in the internal namespace."""
    344   for output_type in [_QUANTIZED_8BIT, _FULL_32BIT, _FULL_FLOAT]:
    345     for aligned in [True, False]:
    346       for rows in range(0, 3):
    347         for cols in range(0, 3):
    348           for leftover in range(0, 8):
    349             GenerateGemm(emitter, output_type, aligned, rows, cols, leftover)
    350             emitter.EmitNewline()
    351 
    352 
    353 def GeneratePublicFunctions(emitter):
    354   for output_type in [_QUANTIZED_8BIT, _FULL_32BIT, _FULL_FLOAT]:
    355     GenerateMainGemmFunction(emitter, output_type)
    356     emitter.EmitNewline()
    357 
    358     GenerateWrapperGemmFunction(emitter, output_type)
    359     emitter.EmitNewline()
    360