Home | History | Annotate | Download | only in generators
      1 """Generates the whole gemm header.
      2 
      3 """
      4 
      5 import cc_emitter
      6 import mul_Nx8_Mx8_neon
      7 import neon_emitter
      8 import qnt_Nx8_neon
      9 import zip_Nx8_neon
     10 
     11 _HEADER_COPYRIGHT = """// Copyright 2015 Google Inc. All Rights Reserved.
     12 //
     13 // Licensed under the Apache License, Version 2.0 (the "License");
     14 // you may not use this file except in compliance with the License.
     15 // You may obtain a copy of the License at
     16 //
     17 //     http://www.apache.org/licenses/LICENSE-2.0
     18 //
     19 // Unless required by applicable law or agreed to in writing, software
     20 // distributed under the License is distributed on an "AS IS" BASIS,
     21 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     22 // See the License for the specific language governing permissions and
     23 // limitations under the License.
     24 //
     25 // single_thread_gemm.h: programatically generated GEMM library header.
     26 """
     27 
     28 _QUANTIZED_8BIT = 'quantized_8bit'
     29 _FULL_32BIT = 'full_32bit'
     30 _FULL_FLOAT = 'full_float'
     31 
     32 
     33 class Error(Exception):
     34   """Module level error."""
     35 
     36 
     37 class ConfigurationError(Error):
     38   """Runtime configuration error."""
     39 
     40 
     41 def GenerateCommonTempsCountersAndConsts(emitter, rows):
     42   emitter.EmitDeclare('const std::int32_t', 'row_chunks', 'm / 3')
     43   emitter.EmitDeclare('const std::int32_t', 'col_chunks', 'n / 3')
     44   emitter.EmitDeclare('const std::int32_t', 'padded_k', '((k + 7) / 8) * 8')
     45   emitter.EmitDeclare('const std::int32_t', 'chunk_size', 'k * 3')
     46   emitter.EmitDeclare('const std::int32_t', 'zipped_chunk_size',
     47                       '(padded_k + 16) * 3')
     48   emitter.EmitDeclare('const std::int32_t', 'zipped_rhs_size',
     49                       '(padded_k + 16) * n')
     50   emitter.EmitDeclare('const std::uint8_t*', 'lhs_chunk', 'lhs')
     51   emitter.EmitDeclare('const std::uint8_t*', 'rhs_chunk', 'rhs')
     52   emitter.EmitDeclare('std::uint8_t*', 'zipped_lhs', 'scratch')
     53   emitter.EmitDeclare(
     54       'std::int32_t*', 'zipped_lhs_3_offsets',
     55       'reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3)')
     56   if rows is not 0:
     57     emitter.EmitDeclare(
     58         'std::int32_t*', 'zipped_lhs_%d_offsets' % rows,
     59         'reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * %d)' % rows)
     60   emitter.EmitDeclare('std::uint8_t*', 'zipped_rhs',
     61                       'scratch + zipped_chunk_size')
     62   emitter.EmitDeclare('std::uint8_t*', 'zipped_rhs_chunk', 'zipped_rhs')
     63   emitter.EmitDeclare('const std::int32_t', 'result_chunk_stride',
     64                       'result_stride * 3')
     65   emitter.EmitNewline()
     66 
     67 
     68 def GenerateQuantized8BitTempsCountersAndConsts(emitter, rows):
     69   """Generates all the boilerplate variables for the q8 gemm function."""
     70   GenerateCommonTempsCountersAndConsts(emitter, rows)
     71   emitter.EmitDeclare('const std::int32_t', 'const_offset',
     72                       'lhs_offset * rhs_offset * k + result_offset')
     73   emitter.EmitDeclare('const std::int32_t', 'rounding_offset',
     74                       '(1 << (shift - 1))')
     75   emitter.EmitDeclare('std::int32_t*', 'temp_result',
     76                       'reinterpret_cast<std::int32_t*>('
     77                       'scratch + zipped_chunk_size + zipped_rhs_size)')
     78   emitter.EmitDeclare('std::uint8_t*', 'result_chunk', 'result')
     79   emitter.EmitDeclare('std::int32_t*', 'mul_result_chunk', 'temp_result')
     80   emitter.EmitDeclare('const std::int32_t', 'mul_result_chunk_stride_bytes',
     81                       '((n * 4 + 7) / 8) * 8')
     82   emitter.EmitNewline()
     83 
     84 
     85 def GenerateFullTempsCountersAndConsts(emitter, result_type, rows):
     86   """Generates all the boilerplate variables for the int32 and float gemms."""
     87   GenerateCommonTempsCountersAndConsts(emitter, rows)
     88   emitter.EmitDeclare('const std::int32_t', 'const_offset',
     89                       'lhs_offset * rhs_offset * k')
     90   emitter.EmitDeclare(result_type, 'result_chunk', 'result')
     91   emitter.EmitDeclare(result_type, 'mul_result_chunk', 'result')
     92   emitter.EmitDeclare('const std::int32_t', 'mul_result_chunk_stride_bytes',
     93                       'result_stride * 4')
     94   emitter.EmitNewline()
     95 
     96 
     97 def ZipName(rows, leftovers, aligned):
     98   return zip_Nx8_neon.BuildName(rows, leftovers, aligned)
     99 
    100 
    101 def GenerateZipRhs(emitter, aligned, cols, leftovers):
    102   """Emits the code responsible for zipping the rhs matrix."""
    103   emitter.EmitOpenBracket('for (int i = 0; i < col_chunks; ++i)')
    104   emitter.EmitCall(
    105       ZipName(3, leftovers, aligned),
    106       ['rhs_chunk', 'k', 'k', 'zipped_rhs_chunk', 'lhs_offset', 0])
    107   emitter.EmitAssignIncrement('rhs_chunk', 'chunk_size')
    108   emitter.EmitAssignIncrement('zipped_rhs_chunk', 'zipped_chunk_size')
    109   emitter.EmitCloseBracket()
    110 
    111   if cols is not 0:
    112     emitter.EmitCall(
    113         ZipName(cols, leftovers, aligned),
    114         ['rhs_chunk', 'k', 'k', 'zipped_rhs_chunk', 'lhs_offset', 0])
    115   emitter.EmitNewline()
    116 
    117 
    118 def MulName(result_type, lhs_add, rhs_add, rows, cols):
    119   return mul_Nx8_Mx8_neon.BuildName(result_type, lhs_add, rhs_add, rows, cols)
    120 
    121 
    122 def GetMulParams(result_type):
    123   params = ['zipped_lhs', 'zipped_rhs_chunk', 'padded_k', 'mul_result_chunk',
    124             'mul_result_chunk_stride_bytes']
    125   if result_type is 'float':
    126     params.append('result_scale')
    127   return params
    128 
    129 
    130 def GenerateMulRows(emitter, result, result_type, lhs_add, rhs_add, aligned,
    131                     rows, cols, leftovers):
    132   """Emits code responsible for multiplication of one horizontal lhs strip."""
    133   emitter.EmitCall(
    134       ZipName(rows, leftovers, aligned),
    135       ['lhs_chunk', 'k', 'k', 'zipped_lhs', 'rhs_offset', 'const_offset'])
    136   emitter.EmitAssign('zipped_rhs_chunk', 'zipped_rhs')
    137   emitter.EmitAssign('mul_result_chunk', result)
    138 
    139   emitter.EmitOpenBracket('for (int j = 0; j < col_chunks; ++j)')
    140 
    141   emitter.EmitCall(
    142       MulName(result_type, lhs_add, rhs_add, rows, 3),
    143       GetMulParams(result_type))
    144   emitter.EmitAssignIncrement('zipped_rhs_chunk', 'zipped_chunk_size')
    145   emitter.EmitAssignIncrement('mul_result_chunk', 3)
    146 
    147   emitter.EmitCloseBracket()
    148 
    149   if cols is not 0:
    150     emitter.EmitCall(
    151         MulName(result_type, lhs_add, rhs_add, rows, cols),
    152         GetMulParams(result_type))
    153 
    154 
    155 def GenerateQuantized8BitMul(emitter, aligned, rows, cols, leftovers):
    156   """Emits code for all lhs strips & leftover rows. Quantize after mul code."""
    157   emitter.EmitOpenBracket('for (int i = 0; i < row_chunks; ++i)')
    158   GenerateMulRows(emitter, 'temp_result', 'int32', False, True, aligned, 3,
    159                   cols, leftovers)
    160   emitter.EmitCall(
    161       BuildMultiQuantizeName(aligned, 3),
    162       ['temp_result', 'n', 'mul_result_chunk_stride_bytes',
    163        'zipped_lhs_3_offsets', 'result_chunk', 'result_stride',
    164        'multiplicative_offset', 'rounding_offset', '-shift'])
    165   emitter.EmitAssignIncrement('lhs_chunk', 'chunk_size')
    166   emitter.EmitAssignIncrement('result_chunk', 'result_chunk_stride')
    167   emitter.EmitCloseBracket()
    168   emitter.EmitNewline()
    169 
    170   if rows is not 0:
    171     GenerateMulRows(emitter, 'temp_result', 'int32', False, True, aligned, rows,
    172                     cols, leftovers)
    173     emitter.EmitCall(
    174         BuildMultiQuantizeName(aligned, rows),
    175         ['temp_result', 'n', 'mul_result_chunk_stride_bytes',
    176          'zipped_lhs_%d_offsets' % rows, 'result_chunk', 'result_stride',
    177          'multiplicative_offset', 'rounding_offset', '-shift'])
    178 
    179 
    180 def GenerateFullMul(emitter, result_type, aligned, rows, cols, leftovers):
    181   emitter.EmitOpenBracket('for (int i = 0; i < row_chunks; ++i)')
    182   GenerateMulRows(emitter, 'result_chunk', result_type, True, True, aligned, 3,
    183                   cols, leftovers)
    184   emitter.EmitAssignIncrement('lhs_chunk', 'chunk_size')
    185   emitter.EmitAssignIncrement('result_chunk', 'result_chunk_stride')
    186   emitter.EmitCloseBracket()
    187   emitter.EmitNewline()
    188 
    189   if rows is not 0:
    190     GenerateMulRows(emitter, 'result_chunk', result_type, True, True, aligned,
    191                     rows, cols, leftovers)
    192 
    193 
    194 def BuildName(output_type, aligned, rows, cols, leftover):
    195   name = BuildMainGemmName(output_type) + '_%d_%d_%d' % (rows, cols, leftover)
    196   if aligned:
    197     name += '_aligned'
    198   return name
    199 
    200 
    201 def GetCommonGemmParameters():
    202   return [['std::uint8_t*', 'scratch'], ['const std::uint8_t*', 'lhs'],
    203           ['const std::uint8_t*', 'rhs'], ['std::int32_t', 'm'],
    204           ['std::int32_t', 'n'], ['std::int32_t', 'k'],
    205           ['std::int32_t', 'lhs_offset'], ['std::int32_t', 'rhs_offset']]
    206 
    207 
    208 def GetGemmParameters(output_type, extra_params=None):
    209   """Prepares a (type, parameter) array for the gemm functions."""
    210   if extra_params is None:
    211     extra_params = []
    212   params = GetCommonGemmParameters()
    213   if output_type is _QUANTIZED_8BIT:
    214     params += [['std::int32_t', 'result_offset'],
    215                ['std::int32_t', 'multiplicative_offset'],
    216                ['std::int32_t', 'shift'], ['std::uint8_t*', 'result']]
    217   elif output_type is _FULL_32BIT:
    218     params += [['std::int32_t*', 'result']]
    219   elif output_type is _FULL_FLOAT:
    220     params += [['float', 'result_scale'], ['float*', 'result']]
    221   else:
    222     raise ConfigurationError('Unsupported output type: %s' % output_type)
    223   return params + extra_params
    224 
    225 
    226 def GetStridedGemmParameters(output_type):
    227   return GetGemmParameters(output_type, [['std::int32_t', 'result_stride']])
    228 
    229 
    230 def GenerateGemm(emitter, output_type, aligned, rows, cols, leftovers):
    231   """Build one gemm function for given row, col, and depth leftovers."""
    232   emitter.EmitFunctionBeginA(
    233       BuildName(output_type, aligned, rows, cols, leftovers),
    234       GetStridedGemmParameters(output_type), 'void')
    235 
    236   emitter.EmitAssert('m %% 3 == %d' % rows)
    237   emitter.EmitAssert('n %% 3 == %d' % cols)
    238   emitter.EmitAssert('k %% 8 == %d' % leftovers)
    239 
    240   if output_type is _QUANTIZED_8BIT:
    241     GenerateQuantized8BitTempsCountersAndConsts(emitter, rows)
    242     GenerateZipRhs(emitter, aligned, cols, leftovers)
    243     GenerateQuantized8BitMul(emitter, aligned, rows, cols, leftovers)
    244   elif output_type is _FULL_32BIT:
    245     GenerateFullTempsCountersAndConsts(emitter, 'std::int32_t*', rows)
    246     GenerateZipRhs(emitter, aligned, cols, leftovers)
    247     GenerateFullMul(emitter, 'int32', aligned, rows, cols, leftovers)
    248   elif output_type is _FULL_FLOAT:
    249     GenerateFullTempsCountersAndConsts(emitter, 'float*', rows)
    250     GenerateZipRhs(emitter, aligned, cols, leftovers)
    251     GenerateFullMul(emitter, 'float', aligned, rows, cols, leftovers)
    252   else:
    253     raise ConfigurationError('Unknown output type: %s' % output_type)
    254 
    255   emitter.EmitFunctionEnd()
    256 
    257 
    258 def BuildMultiQuantizeName(aligned, rows):
    259   name = 'multi_qnt_%dx8' % rows
    260   if aligned:
    261     name = '%s_aligned' % name
    262   return name
    263 
    264 
    265 def GenerateMultiQuantize(emitter, aligned, rows):
    266   """Emit main quantization code that switches between optimized versions."""
    267   name = BuildMultiQuantizeName(aligned, rows)
    268   emitter.EmitFunctionBeginA(
    269       name,
    270       [['const std::int32_t*', 'source'], ['std::int32_t', 'count'],
    271        ['std::int32_t', 'stride'], ['const std::int32_t*', 'offsets'],
    272        ['std::uint8_t*', 'destination'], ['std::int32_t', 'destination_stride'],
    273        ['std::int32_t', 'multiplicative_offset'],
    274        ['std::int32_t', 'rounding_offset'], ['std::int32_t', 'shift']], 'void')
    275   emitter.EmitSwitch('count % 8')
    276 
    277   for leftovers in range(0, 8):
    278     emitter.EmitCase(leftovers)
    279     emitter.PushIndent()
    280     emitter.EmitCall(
    281         qnt_Nx8_neon.BuildName(rows, leftovers, aligned),
    282         ['source', 'count', 'stride', 'offsets', 'destination',
    283          'destination_stride', 'multiplicative_offset', 'rounding_offset',
    284          'shift'])
    285     emitter.EmitBreak()
    286     emitter.PopIndent()
    287 
    288   emitter.EmitSwitchEnd()
    289   emitter.EmitFunctionEnd()
    290 
    291 
    292 def GenerateGemmCall(emitter, output_type, aligned, m_mod, n_mod, leftovers):
    293   emitter.EmitCall(
    294       emitter.Scope('internal',
    295                     BuildName(output_type, aligned, m_mod, n_mod, leftovers)),
    296       [p for (unused_t, p) in GetStridedGemmParameters(output_type)])
    297 
    298 
    299 def GenerateGemmSwitch3(emitter, output_type, aligned, m_mod, n_mod):
    300   """Third level of main switch, choose optimized version on depth leftover."""
    301   emitter.EmitSwitch('k % 8')
    302 
    303   for leftovers in range(0, 8):
    304     emitter.EmitCase(leftovers)
    305     emitter.PushIndent()
    306     GenerateGemmCall(emitter, output_type, aligned, m_mod, n_mod, leftovers)
    307     emitter.EmitBreak()
    308     emitter.PopIndent()
    309 
    310   emitter.EmitSwitchEnd()
    311 
    312 
    313 def GenerateGemmSwitch2(emitter, output_type, aligned, m_mod):
    314   """Second level of main switch, choose optimized version on cols leftover."""
    315   emitter.EmitSwitch('n % 3')
    316 
    317   for n_mod in range(0, 3):
    318     emitter.EmitCase(n_mod)
    319     emitter.PushIndent()
    320     GenerateGemmSwitch3(emitter, output_type, aligned, m_mod, n_mod)
    321     emitter.EmitBreak()
    322     emitter.PopIndent()
    323 
    324   emitter.EmitSwitchEnd()
    325 
    326 
    327 def GenerateGemmSwitch1(emitter, output_type, aligned):
    328   """First level of main switch, choose optimized version on rows leftover."""
    329   emitter.EmitSwitch('m % 3')
    330 
    331   for m_mod in range(0, 3):
    332     emitter.EmitCase(m_mod)
    333     emitter.PushIndent()
    334     GenerateGemmSwitch2(emitter, output_type, aligned, m_mod)
    335     emitter.EmitBreak()
    336     emitter.PopIndent()
    337 
    338   emitter.EmitSwitchEnd()
    339 
    340 
    341 def BuildMainGemmName(output_type):
    342   if output_type is _QUANTIZED_8BIT:
    343     return 'gemm_q8'
    344   elif output_type is _FULL_32BIT:
    345     return 'gemm_i32'
    346   elif output_type is _FULL_FLOAT:
    347     return 'gemm_f'
    348   else:
    349     raise ConfigurationError('Unsupported output type: %s' % output_type)
    350 
    351 
    352 def BuildStridedMainGemmName(output_type):
    353   return BuildMainGemmName(output_type) + '_strided'
    354 
    355 
    356 def GenerateMainGemmFunction(emitter, output_type):
    357   """Emit high level gemm function that switches between optimized versions."""
    358   emitter.EmitFunctionBeginA(
    359       BuildStridedMainGemmName(output_type),
    360       GetStridedGemmParameters(output_type), 'void')
    361 
    362   emitter.EmitDeclare('const bool', 'lhs_aligned',
    363                       '((reinterpret_cast<std::uintptr_t>(lhs) % 8) == 0)')
    364   emitter.EmitDeclare('const bool', 'rhs_aligned',
    365                       '((reinterpret_cast<std::uintptr_t>(rhs) % 8) == 0)')
    366   emitter.EmitDeclare('const bool', 'k_aligned', '((k % 8) == 0)')
    367 
    368   if output_type is _QUANTIZED_8BIT:
    369     emitter.EmitDeclare('const bool', 'result_aligned',
    370                         '((reinterpret_cast<std::uintptr_t>(result) % 8) == 0)')
    371     emitter.EmitDeclare('const bool', 'result_stride_aligned',
    372                         '((result_stride % 8) == 0)')
    373     emitter.EmitDeclare('const bool', 'aligned',
    374                         'lhs_aligned && rhs_aligned && result_aligned '
    375                         '&& k_aligned && result_stride_aligned')
    376   else:
    377     emitter.EmitDeclare('const bool', 'aligned',
    378                         'lhs_aligned && rhs_aligned && k_aligned')
    379 
    380   emitter.EmitIf('aligned')
    381   GenerateGemmSwitch1(emitter, output_type, True)
    382   emitter.EmitElse()
    383   GenerateGemmSwitch1(emitter, output_type, False)
    384   emitter.EmitEndif()
    385   emitter.EmitFunctionEnd()
    386 
    387 
    388 def GenerateWrapperGemmFunction(emitter, output_type):
    389   emitter.EmitFunctionBeginA(
    390       BuildMainGemmName(output_type), GetGemmParameters(output_type), 'void')
    391   emitter.EmitCall(
    392       BuildStridedMainGemmName(output_type),
    393       [p for (unused_t, p) in GetGemmParameters(output_type)] + ['n'])
    394   emitter.EmitFunctionEnd()
    395 
    396 
    397 def GenerateInternalFunctions(emitter):
    398   """Generate all the functions hidden in the internal namespace."""
    399   zip_Nx8_neon.GenerateFunctions(neon_emitter.NeonEmitter())
    400   emitter.EmitNewline()
    401 
    402   mul_Nx8_Mx8_neon.GenerateFunctions(neon_emitter.NeonEmitter(), 'int32', False,
    403                                      True)
    404   emitter.EmitNewline()
    405 
    406   mul_Nx8_Mx8_neon.GenerateFunctions(neon_emitter.NeonEmitter(), 'int32', True,
    407                                      True)
    408   emitter.EmitNewline()
    409 
    410   mul_Nx8_Mx8_neon.GenerateFunctions(neon_emitter.NeonEmitter(), 'float', True,
    411                                      True)
    412   emitter.EmitNewline()
    413 
    414   qnt_Nx8_neon.GenerateFunctions(neon_emitter.NeonEmitter())
    415   emitter.EmitNewline()
    416 
    417   for aligned in [True, False]:
    418     for rows in range(1, 4):
    419       GenerateMultiQuantize(emitter, aligned, rows)
    420       emitter.EmitNewline()
    421 
    422   for output_type in [_QUANTIZED_8BIT, _FULL_32BIT, _FULL_FLOAT]:
    423     for aligned in [True, False]:
    424       for rows in range(0, 3):
    425         for cols in range(0, 3):
    426           for leftover in range(0, 8):
    427             GenerateGemm(emitter, output_type, aligned, rows, cols, leftover)
    428             emitter.EmitNewline()
    429 
    430 
    431 def Main():
    432   """Generate the single threaded meta gemm library."""
    433   emitter = cc_emitter.CCEmitter()
    434 
    435   emitter.EmitCodeNoSemicolon(_HEADER_COPYRIGHT)
    436   emitter.EmitHeaderBegin('gemmlowp_meta_single_thread_gemm')
    437 
    438   emitter.EmitPreprocessor1('ifdef', 'GEMMLOWP_NEON_32')
    439   emitter.EmitNewline()
    440 
    441   emitter.EmitInclude('<cassert>')
    442   emitter.EmitNewline()
    443 
    444   emitter.EmitNamespaceBegin('gemmlowp')
    445   emitter.EmitNamespaceBegin('meta')
    446   emitter.EmitNamespaceBegin('internal')
    447   emitter.EmitNewline()
    448 
    449   GenerateInternalFunctions(emitter)
    450 
    451   emitter.EmitNamespaceEnd()
    452   emitter.EmitNewline()
    453 
    454   GenerateMainGemmFunction(emitter, _QUANTIZED_8BIT)
    455   emitter.EmitNewline()
    456   GenerateMainGemmFunction(emitter, _FULL_32BIT)
    457   emitter.EmitNewline()
    458   GenerateMainGemmFunction(emitter, _FULL_FLOAT)
    459   emitter.EmitNewline()
    460   GenerateWrapperGemmFunction(emitter, _QUANTIZED_8BIT)
    461   emitter.EmitNewline()
    462   GenerateWrapperGemmFunction(emitter, _FULL_32BIT)
    463   emitter.EmitNewline()
    464   GenerateWrapperGemmFunction(emitter, _FULL_FLOAT)
    465   emitter.EmitNewline()
    466 
    467   emitter.EmitNamespaceEnd()
    468   emitter.EmitNamespaceEnd()
    469   emitter.EmitNewline()
    470 
    471   emitter.EmitPreprocessor('else')
    472   emitter.EmitPreprocessor1('warning',
    473                             '"Meta gemm fast-path requires GEMMLOWP_NEON_32!"')
    474   emitter.EmitPreprocessor('endif')
    475   emitter.EmitNewline()
    476 
    477   emitter.EmitHeaderEnd()
    478 
    479 
    480 if __name__ == '__main__':
    481   Main()
    482