# Copyright 2016 The Gemmlowp Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generators for 1D transform kernels.

Defines the requantize, quantize, dequantize, min/max and bias-add kernel
generators. Each one drives an assembly emitter object to produce the body of
a specialized kernel; `GenerateKernels` specializes all of them for a list of
shapes.
"""

import common


def _RegisterCount(elements, lanes):
  """Return how many registers of `lanes` lanes are needed for `elements`.

  Uses floor division so the result is an int on both Python 2 and Python 3
  (it is passed to range()).
  """
  return (elements + lanes - 1) // lanes


def _DuplicateGeneralRegister(size, emitter, registers, value, min_register):
  """Allocate a quad register and emit a VDup of `value` into it.

  Args:
    size: element size forwarded to emitter.EmitVDup.
    emitter: assembly emitter.
    registers: register allocator.
    value: source operand of the VDup (a mapped parameter in this file).
    min_register: forwarded to registers.QuadRegister.

  Returns:
    The allocated quad register.
  """
  register = registers.QuadRegister(min_register)
  emitter.EmitVDup(size, register, value)
  return register


def _DuplicateGeneralMemoryRegister(size, emitter, registers, value,
                                    min_register):
  """Emit a Ldr of `value` into a scratch register, then VDup it.

  The scratch general register is released again before returning.

  Args:
    size: element size forwarded to emitter.EmitVDup.
    emitter: assembly emitter.
    registers: register allocator.
    value: source operand of the Ldr (a mapped memory parameter in this file).
    min_register: forwarded to registers.QuadRegister.

  Returns:
    The allocated quad register.
  """
  register = registers.QuadRegister(min_register)
  general = registers.GeneralRegister()
  emitter.EmitLdr(general, value)
  emitter.EmitVDup(size, register, general)
  registers.FreeRegister(general)
  return register


def _EmitNarrowToUint8(emitter, load):
  """Emit the VQmovn/VQmovun sequence shared by Quantize and Requantize.

  Narrows one to four registers of 's32' lanes in two steps ('s32', then
  's16'); the final result is written to load[0].

  Args:
    emitter: assembly emitter.
    load: list of 1-4 quad registers to narrow.

  Raises:
    AssertionError: if `load` does not hold between 1 and 4 registers.
  """
  count = len(load)
  if count == 1:
    emitter.EmitVQmovn('s32', load[0], load[0])
    emitter.EmitVQmovun('s16', load[0], load[0])
  elif count == 2:
    emitter.EmitVQmovn2('s32', load[0], load[0], load[1])
    emitter.EmitVQmovun('s16', load[0], load[0])
  elif count == 3:
    emitter.EmitVQmovn2('s32', load[0], load[0], load[1])
    emitter.EmitVQmovn('s32', load[2], load[2])
    emitter.EmitVQmovun2('s16', load[0], load[0], load[2])
  elif count == 4:
    emitter.EmitVQmovn2('s32', load[0], load[0], load[1])
    emitter.EmitVQmovn2('s32', load[2], load[2], load[3])
    emitter.EmitVQmovun2('s16', load[0], load[0], load[2])
  else:
    # Explicit raise: a bare `assert False` would be stripped under -O.
    raise AssertionError('Unsupported register count: %d' % count)


class MinMaxTransformation(object):
  """Emits a uint8 -> uint8 transform bounded by params.min / params.max."""

  def Check(self, in_type, out_type, kernel_size, leftovers):
    """Assert that the requested specialization is supported."""
    # Compare by value: identity checks on str/int literals depend on
    # interning and are not guaranteed to hold.
    assert in_type == 'uint8_t'
    assert out_type == 'uint8_t'
    assert kernel_size == 16
    assert leftovers < 16

  def Prepare(self, emitter, registers, unused_kernel_size):
    """Duplicate the min and max parameters into vector registers."""
    emitter.EmitNewline()
    emitter.EmitComment('MinMax::Prepare')

    self.min = _DuplicateGeneralRegister(8, emitter, registers,
                                         registers.MapParameter('min',
                                                                'params.min'),
                                         4)
    self.max = _DuplicateGeneralRegister(8, emitter, registers,
                                         registers.MapParameter('max',
                                                                'params.max'),
                                         4)

  def Transform(self, emitter, registers, input_address, elements,
                output_address):
    """Generate the MinMax transform inner loop code."""
    emitter.EmitNewline()
    emitter.EmitComment('MinMax::Transform')
    register_count = _RegisterCount(elements, 16)
    load = [registers.QuadRegister() for unused_i in range(register_count)]
    emitter.EmitVLoadAE(8, elements, load, input_address, None)
    emitter.EmitPldOffset(input_address, emitter.ImmediateConstant(16))

    # VMax against the min register, then VMin against the max register.
    for register in load:
      emitter.EmitVMax('u8', register, register, self.min)

    for register in load:
      emitter.EmitVMin('u8', register, register, self.max)

    emitter.EmitNewline()
    emitter.EmitVStoreAE(8, elements, load, output_address, None)
    emitter.EmitPld(output_address)
    registers.FreeRegisters(load)


class DequantizeTransformation(object):
  """Emits a uint8 -> float transform using the range_* parameters."""

  def Check(self, in_type, out_type, kernel_size, leftovers):
    """Assert that the requested specialization is supported."""
    assert in_type == 'uint8_t'
    assert out_type == 'float'
    assert kernel_size == 16
    assert leftovers < 16

  def Prepare(self, emitter, registers, unused_kernel_size):
    """Duplicate quantization offsets to vector registers."""
    emitter.EmitNewline()
    emitter.EmitComment('Dequantize::Prepare')

    self.range_min = _DuplicateGeneralRegister(
        32, emitter, registers,
        registers.MapParameter('range_min', 'params.range_min'), 4)
    self.range_offset = _DuplicateGeneralRegister(
        32, emitter, registers,
        registers.MapParameter('range_offset', 'params.range_offset'), 4)
    self.range_scale = _DuplicateGeneralRegister(
        32, emitter, registers,
        registers.MapParameter('range_scale', 'params.range_scale'), 4)

  def Transform(self, emitter, registers, input_address, elements,
                output_address):
    """Emit the dequantization inner loop."""
    emitter.EmitNewline()
    emitter.EmitComment('Dequantize::Transform')
    register_count = _RegisterCount(elements, 4)
    load = [registers.QuadRegister() for unused_i in range(register_count)]
    emitter.EmitVLoadAE(8, elements, load, input_address, None)
    emitter.EmitPldOffset(input_address, emitter.ImmediateConstant(32))

    # Widen the loaded bytes in two steps ('u8', then 's16'), spreading them
    # over all allocated registers.
    if len(load) == 1:
      emitter.EmitVMovl('u8', load[0], load[0])
      emitter.EmitVMovl('s16', load[0], load[0])
    elif len(load) == 2:
      emitter.EmitVMovl('u8', load[0], load[0])
      emitter.EmitVMovl2('s16', load[0], load[1], load[0])
    elif len(load) == 3:
      emitter.EmitVMovl2('u8', load[0], load[1], load[0])
      emitter.EmitVMovl('s16', load[2], load[1])
      emitter.EmitVMovl2('s16', load[0], load[1], load[0])
    elif len(load) == 4:
      emitter.EmitVMovl2('u8', load[0], load[1], load[0])
      emitter.EmitVMovl2('s16', load[2], load[3], load[1])
      emitter.EmitVMovl2('s16', load[0], load[1], load[0])
    else:
      # Explicit raise: a bare `assert False` would be stripped under -O.
      raise AssertionError('Unsupported register count: %d' % len(load))

    for register in load:
      emitter.EmitVCvt('f32', 's32', register, register)

    for register in load:
      emitter.EmitVSub('f32', register, register, self.range_offset)

    for register in load:
      emitter.EmitVMul('f32', register, register, self.range_scale)

    for register in load:
      emitter.EmitVAdd('f32', register, register, self.range_min)

    emitter.EmitNewline()
    emitter.EmitVStoreAE(32, elements, load, output_address, None)
    emitter.EmitPld(output_address)
    registers.FreeRegisters(load)


class QuantizeTransformation(object):
  """Emits a float -> uint8 transform using the range_* parameters."""

  def Check(self, in_type, out_type, kernel_size, leftovers):
    """Assert that the requested specialization is supported."""
    assert in_type == 'float'
    assert out_type == 'uint8_t'
    assert kernel_size == 16
    assert leftovers < 16

  def Prepare(self, emitter, registers, unused_kernel_size):
    """Duplicate quantization offsets to vector registers."""
    emitter.EmitNewline()
    emitter.EmitComment('Quantize::Prepare')

    self.range_min = _DuplicateGeneralRegister(
        32, emitter, registers,
        registers.MapParameter('range_min', 'params.range_min'), 4)
    self.range_offset = _DuplicateGeneralRegister(
        32, emitter, registers,
        registers.MapParameter('range_offset', 'params.range_offset'), 4)
    self.range_scale = _DuplicateGeneralRegister(
        32, emitter, registers,
        registers.MapParameter('range_scale', 'params.range_scale'), 4)

  def Transform(self, emitter, registers, input_address, elements,
                output_address):
    """Emit quantization inner loop code."""
    emitter.EmitNewline()
    emitter.EmitComment('Quantize::Transform')
    register_count = _RegisterCount(elements, 4)
    load = [registers.QuadRegister() for unused_i in range(register_count)]
    emitter.EmitVLoadAE(32, elements, load, input_address, None)
    emitter.EmitPldOffset(input_address, emitter.ImmediateConstant(64))

    for register in load:
      emitter.EmitVSub('f32', register, register, self.range_min)

    for register in load:
      emitter.EmitVMul('f32', register, register, self.range_scale)

    for register in load:
      emitter.EmitVAdd('f32', register, register, self.range_offset)

    for register in load:
      emitter.EmitVCvt('s32', 'f32', register, register)

    _EmitNarrowToUint8(emitter, load)

    emitter.EmitNewline()
    emitter.EmitVStoreAE(8, elements, load, output_address, None)
    emitter.EmitPld(output_address)
    registers.FreeRegisters(load)


class RequantizeTransformation(object):
  """Emits an int32 -> uint8 transform between two quantization ranges."""

  def Check(self, in_type, out_type, kernel_size, leftovers):
    """Assert that the requested specialization is supported."""
    assert in_type == 'int32_t'
    assert out_type == 'uint8_t'
    assert kernel_size == 16
    assert leftovers < 16

  def Prepare(self, emitter, registers, unused_kernel_size):
    """Duplicate quantization parameters to vector registers."""
    emitter.EmitNewline()
    emitter.EmitComment('Requantize::Prepare')

    self.range_min_delta = _DuplicateGeneralRegister(
        32, emitter, registers,
        registers.MapParameter('input_range_min', 'params.input_range_min'), 4)
    self.output_range_min = _DuplicateGeneralRegister(
        32, emitter, registers,
        registers.MapParameter('output_range_min', 'params.output_range_min'),
        4)
    self.input_range_offset = _DuplicateGeneralRegister(
        32, emitter, registers,
        registers.MapParameter('input_range_offset',
                               'params.input_range_offset'), 4)
    self.input_range_scale = _DuplicateGeneralRegister(
        32, emitter, registers,
        registers.MapParameter('input_range_scale', 'params.input_range_scale'),
        4)
    self.one_over_output_range_scale = _DuplicateGeneralRegister(
        32, emitter, registers,
        registers.MapParameter('one_over_output_range_scale',
                               'params.one_over_output_range_scale'), 4)
    # range_min_delta starts as input_range_min; subtract output_range_min
    # once here so the inner loop needs a single add.
    emitter.EmitVSub('f32', self.range_min_delta, self.range_min_delta,
                     self.output_range_min)

  def Transform(self, emitter, registers, input_address, elements,
                output_address):
    """Emit requantization inner loop code."""
    emitter.EmitNewline()
    emitter.EmitComment('Requantize::Transform')
    register_count = _RegisterCount(elements, 4)
    load = [registers.QuadRegister() for unused_i in range(register_count)]
    emitter.EmitVLoadAE(32, elements, load, input_address, None)
    emitter.EmitPldOffset(input_address, emitter.ImmediateConstant(64))

    for register in load:
      emitter.EmitVCvt('f32', 's32', register, register)

    for register in load:
      emitter.EmitVSub('f32', register, register, self.input_range_offset)

    for register in load:
      emitter.EmitVMul('f32', register, register, self.input_range_scale)

    for register in load:
      emitter.EmitVAdd('f32', register, register, self.range_min_delta)

    for register in load:
      emitter.EmitVMul('f32', register, register,
                       self.one_over_output_range_scale)

    for register in load:
      emitter.EmitVCvt('s32', 'f32', register, register)

    _EmitNarrowToUint8(emitter, load)

    emitter.EmitNewline()
    emitter.EmitVStoreAE(8, elements, load, output_address, None)
    emitter.EmitPld(output_address)
    registers.FreeRegisters(load)


class BaseTransform(common.Transform1DKernelGenerator):
  """1D kernel generator that delegates the loop body to a transformation.

  The transformation object must provide Check, Prepare and Transform.
  """

  def __init__(self, cc_emitter, kernel_name, asm_emitter, transformation):
    common.Transform1DKernelGenerator.__init__(self, cc_emitter, kernel_name)
    self.asm_emitter = asm_emitter
    self.transformation = transformation

  def EmitTransform(self, in_type, out_type, kernel_size, leftovers):
    """Emit the asm block: prepare, main loop, optional leftovers pass."""
    self.transformation.Check(in_type, out_type, kernel_size, leftovers)

    registers = self.asm_emitter.CreateRegisters()

    self.emitter.EmitDeclare('int', 'params_count_copy', 'params.count')

    self.asm_emitter.PushIndent(self.emitter.indent)
    self.asm_emitter.EmitAsmBegin()

    count = registers.MapOutputParameter('count', 'params_count_copy')
    input_address = registers.MapOutputParameter('input')
    output_address = registers.MapOutputParameter('output')

    self.transformation.Prepare(self.asm_emitter, registers, kernel_size)

    if leftovers:
      # Take the leftovers out of the count first; if nothing remains, branch
      # forward to label 2 and skip the main loop.
      self.asm_emitter.EmitNewline()
      self.asm_emitter.EmitComment('Reduce count by leftovers.')
      self.asm_emitter.EmitSubs(count, count,
                                self.asm_emitter.ImmediateConstant(leftovers))
      self.asm_emitter.EmitBeqFront(2)

    self.asm_emitter.EmitNewline()
    self.asm_emitter.EmitNumericalLabel(1)
    self.asm_emitter.EmitSubs(count, count,
                              self.asm_emitter.ImmediateConstant(kernel_size))

    self.transformation.Transform(self.asm_emitter, registers, input_address,
                                  kernel_size, output_address)

    self.asm_emitter.EmitNewline()
    self.asm_emitter.EmitBneBack(1)

    if leftovers:
      self.asm_emitter.EmitNumericalLabel(2)
      self.asm_emitter.EmitNewline()
      self.asm_emitter.EmitComment('Handle leftovers.')
      self.transformation.Transform(self.asm_emitter, registers, input_address,
                                    leftovers, output_address)

    self.asm_emitter.EmitAsmEnd(registers)
    self.asm_emitter.PopIndent(len(self.emitter.indent))


class Requantize(BaseTransform):
  """Kernel generator named 'Requantize' using RequantizeTransformation."""

  def __init__(self, cc_emitter, asm_emitter):
    BaseTransform.__init__(self, cc_emitter, 'Requantize', asm_emitter,
                           RequantizeTransformation())


class Quantize(BaseTransform):
  """Kernel generator named 'Quantize' using QuantizeTransformation."""

  def __init__(self, cc_emitter, asm_emitter):
    BaseTransform.__init__(self, cc_emitter, 'Quantize', asm_emitter,
                           QuantizeTransformation())


class Dequantize(BaseTransform):
  """Kernel generator named 'Dequantize' using DequantizeTransformation."""

  def __init__(self, cc_emitter, asm_emitter):
    BaseTransform.__init__(self, cc_emitter, 'Dequantize', asm_emitter,
                           DequantizeTransformation())


class MinMax(BaseTransform):
  """Kernel generator named 'MinMax<type>' using MinMaxTransformation."""

  def __init__(self, numerical_type, cc_emitter, asm_emitter):
    BaseTransform.__init__(self, cc_emitter, 'MinMax<%s>' % numerical_type,
                           asm_emitter, MinMaxTransformation())


class BiasAdd(common.Transform1DKernelGenerator):
  """Kernel generator named 'BiasAdd<type>': adds a bias to each input row."""

  def __init__(self, bias_type, cc_emitter, asm_emitter):
    common.Transform1DKernelGenerator.__init__(self, cc_emitter,
                                               'BiasAdd<%s>' % bias_type)
    self.asm_emitter = asm_emitter

  def EmitTransform(self, in_type, out_type, kernel_size, leftovers):
    """Emit the asm block: prepare, then one _ProcessRow per row."""
    assert in_type == 'uint8_t'
    assert out_type == 'int32_t'
    assert kernel_size == 16
    assert leftovers < 16

    registers = self.asm_emitter.CreateRegisters()

    self.emitter.EmitDeclare('int', 'params_rows_copy', 'params.rows')

    self.asm_emitter.PushIndent(self.emitter.indent)
    self.asm_emitter.EmitAsmBegin()

    self._Prepare(self.asm_emitter, registers)

    rows = registers.MapParameter('rows', 'params_rows_copy')

    self.asm_emitter.EmitNumericalLabel(1)

    self._ProcessRow(self.asm_emitter, registers, kernel_size, leftovers)

    self.asm_emitter.EmitSubs(rows, rows, self.asm_emitter.ImmediateConstant(1))
    self.asm_emitter.EmitBneBack(1)

    self.asm_emitter.EmitAsmEnd(registers)
    self.asm_emitter.PopIndent(len(self.emitter.indent))

  def _Prepare(self, emitter, registers):
    """Load each range parameter from memory and duplicate it."""
    self.input_range_min = _DuplicateGeneralMemoryRegister(
        32, emitter, registers,
        registers.MapMemoryParameter('input_range_min',
                                     'params.input_range_min'), 8)
    self.input_range_scale = _DuplicateGeneralMemoryRegister(
        32, emitter, registers,
        registers.MapMemoryParameter('input_range_scale',
                                     'params.input_range_scale'), 8)
    self.bias_range_min = _DuplicateGeneralMemoryRegister(
        32, emitter, registers,
        registers.MapMemoryParameter('bias_range_min', 'params.bias_range_min'),
        8)
    self.bias_range_scale = _DuplicateGeneralMemoryRegister(
        32, emitter, registers,
        registers.MapMemoryParameter('bias_range_scale',
                                     'params.bias_range_scale'), 8)
    self.output_range_min = _DuplicateGeneralMemoryRegister(
        32, emitter, registers,
        registers.MapMemoryParameter('output_range_min',
                                     'params.output_range_min'), 8)
    self.one_over_output_range_scale = _DuplicateGeneralMemoryRegister(
        32, emitter, registers,
        registers.MapMemoryParameter('one_over_output_range_scale',
                                     'params.one_over_output_range_scale'), 8)
    self.output_range_offset = _DuplicateGeneralMemoryRegister(
        32, emitter, registers,
        registers.MapMemoryParameter('output_range_offset',
                                     'params.output_range_offset'), 8)

  def _ProcessRow(self, emitter, registers, kernel_size, leftovers):
    """Emit the per-row loop: main kernel loop plus optional leftovers."""
    const_count = registers.MapParameter('count', 'params.count')
    const_bias = registers.MapParameter('bias', 'params.bias')

    count = registers.GeneralRegister()
    bias = registers.GeneralRegister()

    input_address = registers.MapOutputParameter('input')
    output_address = registers.MapOutputParameter('output')

    # Work on copies of count and bias; they are re-initialised from the
    # mapped parameters at the start of every row.
    emitter.EmitMov(count, const_count)
    emitter.EmitMov(bias, const_bias)

    if leftovers:
      emitter.EmitSubs(count, count, emitter.ImmediateConstant(leftovers))
      emitter.EmitBeqFront(3)

    emitter.EmitNumericalLabel(2)
    emitter.EmitSubs(count, count, emitter.ImmediateConstant(kernel_size))

    self._BiasAdd(emitter, registers, kernel_size, input_address, bias,
                  output_address)

    emitter.EmitBneBack(2)

    if leftovers:
      emitter.EmitNumericalLabel(3)
      self._BiasAdd(emitter, registers, leftovers, input_address, bias,
                    output_address)

  def _BiasAdd(self, emitter, registers, elements, input_address, bias,
               output_address):
    """Emit the bias-add code for `elements` values."""
    emitter.EmitNewline()
    emitter.EmitComment('BiasAdd::Transform')
    register_count = _RegisterCount(elements, 4)

    load_input = [
        registers.QuadRegister() for unused_i in range(register_count)
    ]
    load_bias = [registers.QuadRegister() for unused_i in range(register_count)]

    emitter.EmitVLoadAE(8, elements, load_input, input_address, None)
    emitter.EmitVLoadAE(8, elements, load_bias, bias, None)
    emitter.EmitPldOffset(input_address, emitter.ImmediateConstant(32))

    # Widen input and bias in lockstep; the two streams stay interleaved in
    # the emitted code.
    if len(load_input) == 1:
      emitter.EmitVMovl('u8', load_input[0], load_input[0])
      emitter.EmitVMovl('u8', load_bias[0], load_bias[0])
      emitter.EmitVMovl('s16', load_input[0], load_input[0])
      emitter.EmitVMovl('s16', load_bias[0], load_bias[0])
    elif len(load_input) == 2:
      emitter.EmitVMovl('u8', load_input[0], load_input[0])
      emitter.EmitVMovl('u8', load_bias[0], load_bias[0])
      emitter.EmitVMovl2('s16', load_input[0], load_input[1], load_input[0])
      emitter.EmitVMovl2('s16', load_bias[0], load_bias[1], load_bias[0])
    elif len(load_input) == 3:
      emitter.EmitVMovl2('u8', load_input[0], load_input[1], load_input[0])
      emitter.EmitVMovl2('u8', load_bias[0], load_bias[1], load_bias[0])
      emitter.EmitVMovl('s16', load_input[2], load_input[1])
      emitter.EmitVMovl('s16', load_bias[2], load_bias[1])
      emitter.EmitVMovl2('s16', load_input[0], load_input[1], load_input[0])
      emitter.EmitVMovl2('s16', load_bias[0], load_bias[1], load_bias[0])
    elif len(load_input) == 4:
      emitter.EmitVMovl2('u8', load_input[0], load_input[1], load_input[0])
      emitter.EmitVMovl2('u8', load_bias[0], load_bias[1], load_bias[0])
      emitter.EmitVMovl2('s16', load_input[2], load_input[3], load_input[1])
      emitter.EmitVMovl2('s16', load_bias[2], load_bias[3], load_bias[1])
      emitter.EmitVMovl2('s16', load_input[0], load_input[1], load_input[0])
      emitter.EmitVMovl2('s16', load_bias[0], load_bias[1], load_bias[0])
    else:
      # Explicit raise: a bare `assert False` would be stripped under -O.
      raise AssertionError(
          'Unsupported register count: %d' % len(load_input))

    for register in load_input + load_bias:
      emitter.EmitVCvt('f32', 's32', register, register)

    for register in load_input:
      emitter.EmitVMul('f32', register, register, self.input_range_scale)

    for register in load_bias:
      emitter.EmitVMul('f32', register, register, self.bias_range_scale)

    for register in load_input:
      emitter.EmitVAdd('f32', register, register, self.input_range_min)

    for register in load_bias:
      emitter.EmitVAdd('f32', register, register, self.bias_range_min)

    # Accumulate the bias into the input registers.
    for (register_1, register_2) in zip(load_input, load_bias):
      emitter.EmitVAdd('f32', register_1, register_1, register_2)

    for register in load_input:
      emitter.EmitVSub('f32', register, register, self.output_range_min)

    for register in load_input:
      emitter.EmitVMul('f32', register, register,
                       self.one_over_output_range_scale)

    for register in load_input:
      emitter.EmitVAdd('f32', register, register, self.output_range_offset)

    for register in load_input:
      emitter.EmitVCvt('s32', 'f32', register, register)

    emitter.EmitNewline()
    emitter.EmitVStoreAE(32, elements, load_input, output_address, None)
    emitter.EmitPld(output_address)
    registers.FreeRegisters(load_input + load_bias)


def GenerateKernels(cc_emitter, asm_emitter, shapes):
  """Generate the quantization/dequantization/requantization kernels.

  Args:
    cc_emitter: emitter passed to every kernel generator.
    asm_emitter: assembly emitter passed to every kernel generator.
    shapes: iterable of indexable pairs; shape[0] and shape[1] are forwarded
      to SpecializeTransform1DKernel as its last two arguments.
  """
  requantize = Requantize(cc_emitter, asm_emitter)
  quantize = Quantize(cc_emitter, asm_emitter)
  dequantize = Dequantize(cc_emitter, asm_emitter)
  minmax = MinMax('uint8_t', cc_emitter, asm_emitter)
  biasadd = BiasAdd('uint8_t', cc_emitter, asm_emitter)

  # One pass per kernel family so the specializations are emitted grouped by
  # family rather than by shape.
  for shape in shapes:
    requantize.SpecializeTransform1DKernel('int32_t', 'uint8_t', shape[0],
                                           shape[1])

  for shape in shapes:
    quantize.SpecializeTransform1DKernel('float', 'uint8_t', shape[0], shape[1])

  for shape in shapes:
    dequantize.SpecializeTransform1DKernel('uint8_t', 'float', shape[0],
                                           shape[1])

  for shape in shapes:
    minmax.SpecializeTransform1DKernel('uint8_t', 'uint8_t', shape[0], shape[1])

  for shape in shapes:
    biasadd.SpecializeTransform1DKernel('uint8_t', 'int32_t', shape[0],
                                        shape[1])