Home | History | Annotate | Download | only in generators

Lines Matching refs:output_type

173 def BuildName(output_type, aligned, rows, cols, leftover):
174 name = BuildMainGemmName(output_type) + '_%d_%d_%d' % (rows, cols, leftover)
187 def GetGemmParameters(output_type, extra_params=None):
192 if output_type is _QUANTIZED_8BIT:
196 elif output_type is _FULL_32BIT:
198 elif output_type is _FULL_FLOAT:
201 raise ConfigurationError('Unsupported output type: %s' % output_type)
205 def GetStridedGemmParameters(output_type):
206 return GetGemmParameters(output_type, [['std::int32_t', 'result_stride']])
209 def GenerateGemm(emitter, output_type, aligned, rows, cols, leftovers):
212 BuildName(output_type, aligned, rows, cols, leftovers),
213 GetStridedGemmParameters(output_type), 'void')
219 if output_type is _QUANTIZED_8BIT:
223 elif output_type is _FULL_32BIT:
227 elif output_type is _FULL_FLOAT:
232 raise ConfigurationError('Unknown output type: %s' % output_type)
237 def GenerateGemmCall(emitter, output_type, aligned, m_mod, n_mod, leftovers):
240 BuildName(output_type, aligned, m_mod, n_mod, leftovers)),
241 [p for (unused_t, p) in GetStridedGemmParameters(output_type)])
244 def GenerateGemmSwitch3(emitter, output_type, aligned, m_mod, n_mod):
251 GenerateGemmCall(emitter, output_type, aligned, m_mod, n_mod, leftovers)
258 def GenerateGemmSwitch2(emitter, output_type, aligned, m_mod):
265 GenerateGemmSwitch3(emitter, output_type, aligned, m_mod, n_mod)
272 def GenerateGemmSwitch1(emitter, output_type, aligned):
279 GenerateGemmSwitch2(emitter, output_type, aligned, m_mod)
286 def BuildMainGemmName(output_type):
287 if output_type is _QUANTIZED_8BIT:
289 elif output_type is _FULL_32BIT:
291 elif output_type is _FULL_FLOAT:
294 raise ConfigurationError('Unsupported output type: %s' % output_type)
297 def BuildStridedMainGemmName(output_type):
298 return BuildMainGemmName(output_type) + '_strided'
301 def GenerateMainGemmFunction(emitter, output_type):
304 BuildStridedMainGemmName(output_type),
305 GetStridedGemmParameters(output_type), 'void')
313 if output_type is _QUANTIZED_8BIT:
326 GenerateGemmSwitch1(emitter, output_type, True)
328 GenerateGemmSwitch1(emitter, output_type, False)
333 def GenerateWrapperGemmFunction(emitter, output_type):
335 BuildMainGemmName(output_type), GetGemmParameters(output_type), 'void')
337 BuildStridedMainGemmName(output_type),
338 [p for (unused_t, p) in GetGemmParameters(output_type)] + ['n'])
344 for output_type in [_QUANTIZED_8BIT, _FULL_32BIT, _FULL_FLOAT]:
349 GenerateGemm(emitter, output_type, aligned, rows, cols, leftover)
354 for output_type in [_QUANTIZED_8BIT, _FULL_32BIT, _FULL_FLOAT]:
355 GenerateMainGemmFunction(emitter, output_type)
358 GenerateWrapperGemmFunction(emitter, output_type)