Home | History | Annotate | Download | only in generators

Lines Matching full:emitter

41 def GenerateCommonTempsCountersAndConsts(emitter, rows):
42 emitter.EmitDeclare('const std::int32_t', 'row_chunks', 'm / 3')
43 emitter.EmitDeclare('const std::int32_t', 'col_chunks', 'n / 3')
44 emitter.EmitDeclare('const std::int32_t', 'padded_k', '((k + 7) / 8) * 8')
45 emitter.EmitDeclare('const std::int32_t', 'chunk_size', 'k * 3')
46 emitter.EmitDeclare('const std::int32_t', 'zipped_chunk_size',
48 emitter.EmitDeclare('const std::int32_t', 'zipped_rhs_size',
50 emitter.EmitDeclare('const std::uint8_t*', 'lhs_chunk', 'lhs')
51 emitter.EmitDeclare('const std::uint8_t*', 'rhs_chunk', 'rhs')
52 emitter.EmitDeclare('std::uint8_t*', 'zipped_lhs', 'scratch')
53 emitter.EmitDeclare(
57 emitter.EmitDeclare(
60 emitter.EmitDeclare('std::uint8_t*', 'zipped_rhs',
62 emitter.EmitDeclare('std::uint8_t*', 'zipped_rhs_chunk', 'zipped_rhs')
63 emitter.EmitDeclare('const std::int32_t', 'result_chunk_stride',
65 emitter.EmitNewline()
68 def GenerateQuantized8BitTempsCountersAndConsts(emitter, rows):
70 GenerateCommonTempsCountersAndConsts(emitter, rows)
71 emitter.EmitDeclare('const std::int32_t', 'const_offset',
73 emitter.EmitDeclare('const std::int32_t', 'rounding_offset',
75 emitter.EmitDeclare('std::int32_t*', 'temp_result',
78 emitter.EmitDeclare('std::uint8_t*', 'result_chunk', 'result')
79 emitter.EmitDeclare('std::int32_t*', 'mul_result_chunk', 'temp_result')
80 emitter.EmitDeclare('const std::int32_t', 'mul_result_chunk_stride_bytes',
82 emitter.EmitNewline()
85 def GenerateFullTempsCountersAndConsts(emitter, result_type, rows):
87 GenerateCommonTempsCountersAndConsts(emitter, rows)
88 emitter.EmitDeclare('const std::int32_t', 'const_offset',
90 emitter.EmitDeclare(result_type, 'result_chunk', 'result')
91 emitter.EmitDeclare(result_type, 'mul_result_chunk', 'result')
92 emitter.EmitDeclare('const std::int32_t', 'mul_result_chunk_stride_bytes',
94 emitter.EmitNewline()
101 def GenerateZipRhs(emitter, aligned, cols, leftovers):
103 emitter.EmitOpenBracket('for (int i = 0; i < col_chunks; ++i)')
104 emitter.EmitCall(
107 emitter.EmitAssignIncrement('rhs_chunk', 'chunk_size')
108 emitter.EmitAssignIncrement('zipped_rhs_chunk', 'zipped_chunk_size')
109 emitter.EmitCloseBracket()
112 emitter.EmitCall(
115 emitter.EmitNewline()
130 def GenerateMulRows(emitter, result, result_type, lhs_add, rhs_add, aligned,
133 emitter.EmitCall(
136 emitter.EmitAssign('zipped_rhs_chunk', 'zipped_rhs')
137 emitter.EmitAssign('mul_result_chunk', result)
139 emitter.EmitOpenBracket('for (int j = 0; j < col_chunks; ++j)')
141 emitter.EmitCall(
144 emitter.EmitAssignIncrement('zipped_rhs_chunk', 'zipped_chunk_size')
145 emitter.EmitAssignIncrement('mul_result_chunk', 3)
147 emitter.EmitCloseBracket()
150 emitter.EmitCall(
155 def GenerateQuantized8BitMul(emitter, aligned, rows, cols, leftovers):
157 emitter.EmitOpenBracket('for (int i = 0; i < row_chunks; ++i)')
158 GenerateMulRows(emitter, 'temp_result', 'int32', False, True, aligned, 3,
160 emitter.EmitCall(
165 emitter.EmitAssignIncrement('lhs_chunk', 'chunk_size')
166 emitter.EmitAssignIncrement('result_chunk', 'result_chunk_stride')
167 emitter.EmitCloseBracket()
168 emitter.EmitNewline()
171 GenerateMulRows(emitter, 'temp_result', 'int32', False, True, aligned, rows,
173 emitter.EmitCall(
180 def GenerateFullMul(emitter, result_type, aligned, rows, cols, leftovers):
181 emitter.EmitOpenBracket('for (int i = 0; i < row_chunks; ++i)')
182 GenerateMulRows(emitter, 'result_chunk', result_type, True, True, aligned, 3,
184 emitter.EmitAssignIncrement('lhs_chunk', 'chunk_size')
185 emitter.EmitAssignIncrement('result_chunk', 'result_chunk_stride')
186 emitter.EmitCloseBracket()
187 emitter.EmitNewline()
190 GenerateMulRows(emitter, 'result_chunk', result_type, True, True, aligned,
230 def GenerateGemm(emitter, output_type, aligned, rows, cols, leftovers):
232 emitter.EmitFunctionBeginA(
236 emitter.EmitAssert('m %% 3 == %d' % rows)
237 emitter.EmitAssert('n %% 3 == %d' % cols)
238 emitter.EmitAssert('k %% 8 == %d' % leftovers)
241 GenerateQuantized8BitTempsCountersAndConsts(emitter, rows)
242 GenerateZipRhs(emitter, aligned, cols, leftovers)
243 GenerateQuantized8BitMul(emitter, aligned, rows, cols, leftovers)
245 GenerateFullTempsCountersAndConsts(emitter, 'std::int32_t*', rows)
246 GenerateZipRhs(emitter, aligned, cols, leftovers)
247 GenerateFullMul(emitter, 'int32', aligned, rows, cols, leftovers)
249 GenerateFullTempsCountersAndConsts(emitter, 'float*', rows)
250 GenerateZipRhs(emitter, aligned, cols, leftovers)
251 GenerateFullMul(emitter, 'float', aligned, rows, cols, leftovers)
255 emitter.EmitFunctionEnd()
265 def GenerateMultiQuantize(emitter, aligned, rows):
268 emitter.EmitFunctionBeginA(
275 emitter.EmitSwitch('count % 8')
278 emitter.EmitCase(leftovers)
279 emitter.PushIndent()
280 emitter.EmitCall(
285 emitter.EmitBreak()
286 emitter.PopIndent()
288 emitter.EmitSwitchEnd()
289 emitter.EmitFunctionEnd()
292 def GenerateGemmCall(emitter, output_type, aligned, m_mod, n_mod, leftovers):
293 emitter.EmitCall(
294 emitter.Scope('internal',
299 def GenerateGemmSwitch3(emitter, output_type, aligned, m_mod, n_mod):
301 emitter.EmitSwitch('k % 8')
304 emitter.EmitCase(leftovers)
305 emitter.PushIndent()
306 GenerateGemmCall(emitter, output_type, aligned, m_mod, n_mod, leftovers)
307 emitter.EmitBreak()
308 emitter.PopIndent()
310 emitter.EmitSwitchEnd()
313 def GenerateGemmSwitch2(emitter, output_type, aligned, m_mod):
315 emitter.EmitSwitch('n % 3')
318 emitter.EmitCase(n_mod)
319 emitter.PushIndent()
320 GenerateGemmSwitch3(emitter, output_type, aligned, m_mod, n_mod)
321 emitter.EmitBreak()
322 emitter.PopIndent()
324 emitter.EmitSwitchEnd()
327 def GenerateGemmSwitch1(emitter, output_type, aligned):
329 emitter.EmitSwitch('m % 3')
332 emitter.EmitCase(m_mod)
333 emitter.PushIndent()
334 GenerateGemmSwitch2(emitter, output_type, aligned, m_mod)
335 emitter.EmitBreak()
336 emitter.PopIndent()
338 emitter.EmitSwitchEnd()
356 def GenerateMainGemmFunction(emitter, output_type):
358 emitter.EmitFunctionBeginA(
362 emitter.EmitDeclare('const bool', 'lhs_aligned',
364 emitter.EmitDeclare('const bool', 'rhs_aligned',
366 emitter.EmitDeclare('const bool', 'k_aligned', '((k % 8) == 0)')
369 emitter.EmitDeclare('const bool', 'result_aligned',
371 emitter.EmitDeclare('const bool', 'result_stride_aligned',
373 emitter.EmitDeclare('const bool', 'aligned',
377 emitter.EmitDeclare('const bool', 'aligned',
380 emitter.EmitIf('aligned')
381 GenerateGemmSwitch1(emitter, output_type, True)
382 emitter.EmitElse()
383 GenerateGemmSwitch1(emitter, output_type, False)
384 emitter.EmitEndif()
385 emitter.EmitFunctionEnd()
388 def GenerateWrapperGemmFunction(emitter, output_type):
389 emitter.EmitFunctionBeginA(
391 emitter.EmitCall(
394 emitter.EmitFunctionEnd()
397 def GenerateInternalFunctions(emitter):
400 emitter.EmitNewline()
404 emitter.EmitNewline()
408 emitter.EmitNewline()
412 emitter.EmitNewline()
415 emitter.EmitNewline()
419 GenerateMultiQuantize(emitter, aligned, rows)
420 emitter.EmitNewline()
427 GenerateGemm(emitter, output_type, aligned, rows, cols, leftover)
428 emitter.EmitNewline()
433 emitter = cc_emitter.CCEmitter()
435 emitter.EmitCodeNoSemicolon(_HEADER_COPYRIGHT)
436 emitter
438 emitter.EmitPreprocessor1('ifdef', 'GEMMLOWP_NEON_32')
439 emitter.EmitNewline()
441 emitter.EmitInclude('<cassert>')
442 emitter.EmitNewline()
444 emitter.EmitNamespaceBegin('gemmlowp')
445 emitter.EmitNamespaceBegin('meta')
446 emitter.EmitNamespaceBegin('internal')
447 emitter.EmitNewline()
449 GenerateInternalFunctions(emitter)
451 emitter.EmitNamespaceEnd()
452 emitter.EmitNewline()
454 GenerateMainGemmFunction(emitter, _QUANTIZED_8BIT)
455 emitter.EmitNewline()
456 GenerateMainGemmFunction(emitter, _FULL_32BIT)
457 emitter.EmitNewline()
458 GenerateMainGemmFunction(emitter, _FULL_FLOAT)
459 emitter.EmitNewline()
460 GenerateWrapperGemmFunction(emitter, _QUANTIZED_8BIT)
461 emitter.EmitNewline()
462 GenerateWrapperGemmFunction(emitter, _FULL_32BIT)
463 emitter.EmitNewline()
464 GenerateWrapperGemmFunction(emitter, _FULL_FLOAT)
465 emitter.EmitNewline()
467 emitter.EmitNamespaceEnd()
468 emitter.EmitNamespaceEnd()
469 emitter.EmitNewline()
471 emitter.EmitPreprocessor('else')
472 emitter.EmitPreprocessor1('warning',
474 emitter.EmitPreprocessor('endif')
475 emitter.EmitNewline()
477 emitter.EmitHeaderEnd()