1 """Zip primitive used by the GEMM function. 2 3 Takes 1 to 3 rows of data and interleaves them in 8 byte chunks. Pads to 4 multiply of 8 length with zeros. Calculates row sums and appends those at the 5 end. 6 """ 7 8 9 import neon_emitter 10 11 12 class Error(Exception): 13 """Module level error.""" 14 15 16 class ConfigurationError(Error): 17 """Unsupported configuration.""" 18 19 20 class ZipLane(object): 21 22 def __init__(self, input_address, load, aggregator): 23 self.input_address = input_address 24 self.load = load 25 self.aggregator = aggregator 26 27 28 def GenerateZipLanes(emitter, registers, zip_lanes, input_address, stride): 29 """Prepares read lanes for the zip operation. 30 31 Args: 32 emitter: ARM/NEON emitter. 33 registers: ARM/NEON registers state. 34 zip_lanes: number of lanes to prepare. 35 input_address: register that contains the input address for the first lane. 36 stride: memory stride for lane inputs. 37 38 Returns: 39 Array of ZipLane objects. 40 """ 41 lanes = [] 42 last_address_register = input_address 43 for i in range(0, zip_lanes): 44 if not i: 45 lanes.append(ZipLane(input_address, 46 registers.DoubleRegister(), 47 registers.QuadRegister(2))) 48 else: 49 address_register = registers.GeneralRegister() 50 lanes.append(ZipLane(address_register, 51 registers.DoubleRegister(), 52 registers.QuadRegister(2))) 53 emitter.EmitAdd(address_register, last_address_register, stride) 54 last_address_register = address_register 55 return lanes 56 57 58 def BuildName(zip_lanes, leftovers, aligned): 59 name = 'zip_%dx8' % zip_lanes 60 if leftovers: 61 name += '_%d' % leftovers 62 if aligned: 63 name += '_aligned' 64 return name 65 66 67 def GenerateClearAggregators(emitter, lanes): 68 for lane in lanes: 69 emitter.EmitVMov('i16', lane.aggregator, emitter.ImmediateConstant(0)) 70 71 72 def GenerateLoadAggregateStore(emitter, lanes, output_address, alignment): 73 """Emit inner loop code for reading N lanes and interweaving them.""" 74 emitter.EmitNewline() 75 emitter.EmitComment('Load Aggregate Store.') 76 77 for lane in lanes: 78 emitter.EmitVLoad( 79 '1.8', lane.load, 80 emitter.DereferenceIncrement(lane.input_address, alignment)) 81 82 store_registers = [] 83 for lane in lanes: 84 emitter.EmitVAddw('u8', lane.aggregator, lane.aggregator, lane.load) 85 store_registers.append(lane.load) 86 87 emitter.EmitVStoreA('1.8', store_registers, 88 emitter.DereferenceIncrement(output_address, 64)) 89 90 91 def GenerateLeftoverLoadAggregateStore( 92 emitter, leftovers, lanes, output_address): 93 """Handle leftovers when count is not a multiply of 8.""" 94 emitter.EmitNewline() 95 emitter.EmitComment('Leftover Load Aggregate Store.') 96 97 # Clear load registers. 98 for lane in lanes: 99 emitter.EmitVMov('i8', lane.load, emitter.ImmediateConstant(0)) 100 101 if leftovers == 1: 102 # Load 8 bits. 103 for lane in lanes: 104 emitter.EmitVLoad('1.8', emitter.Lane(lane.load, 0), 105 emitter.Dereference(lane.input_address, None)) 106 elif leftovers == 2: 107 # Load 16 bits. 108 for lane in lanes: 109 emitter.EmitVLoad('1.16', emitter.Lane(lane.load, 0), 110 emitter.Dereference(lane.input_address, None)) 111 elif leftovers == 3: 112 # Load 16 bits. 113 for lane in lanes: 114 emitter.EmitVLoad( 115 '1.16', emitter.Lane(lane.load, 0), 116 emitter.DereferenceIncrement(lane.input_address, None)) 117 # Load 8 bits. 118 for lane in lanes: 119 emitter.EmitVLoad('1.8', emitter.Lane(lane.load, 2), 120 emitter.Dereference(lane.input_address, None)) 121 elif leftovers == 4: 122 # Load 32 bits. 


def GenerateLeftoverLoadAggregateStore(
    emitter, leftovers, lanes, output_address):
  """Handle leftovers when count is not a multiple of 8."""
  emitter.EmitNewline()
  emitter.EmitComment('Leftover Load Aggregate Store.')

  # Clear load registers.
  for lane in lanes:
    emitter.EmitVMov('i8', lane.load, emitter.ImmediateConstant(0))

  if leftovers == 1:
    # Load 8 bits.
    for lane in lanes:
      emitter.EmitVLoad('1.8', emitter.Lane(lane.load, 0),
                        emitter.Dereference(lane.input_address, None))
  elif leftovers == 2:
    # Load 16 bits.
    for lane in lanes:
      emitter.EmitVLoad('1.16', emitter.Lane(lane.load, 0),
                        emitter.Dereference(lane.input_address, None))
  elif leftovers == 3:
    # Load 16 bits.
    for lane in lanes:
      emitter.EmitVLoad(
          '1.16', emitter.Lane(lane.load, 0),
          emitter.DereferenceIncrement(lane.input_address, None))
    # Load 8 bits.
    for lane in lanes:
      emitter.EmitVLoad('1.8', emitter.Lane(lane.load, 2),
                        emitter.Dereference(lane.input_address, None))
  elif leftovers == 4:
    # Load 32 bits.
    for lane in lanes:
      emitter.EmitVLoad('1.32', emitter.Lane(lane.load, 0),
                        emitter.Dereference(lane.input_address, None))
  elif leftovers == 5:
    # Load 32 bits.
    for lane in lanes:
      emitter.EmitVLoad(
          '1.32', emitter.Lane(lane.load, 0),
          emitter.DereferenceIncrement(lane.input_address, None))
    # Load 8 bits.
    for lane in lanes:
      emitter.EmitVLoad('1.8', emitter.Lane(lane.load, 4),
                        emitter.Dereference(lane.input_address, None))
  elif leftovers == 6:
    # Load 32 bits.
    for lane in lanes:
      emitter.EmitVLoad(
          '1.32', emitter.Lane(lane.load, 0),
          emitter.DereferenceIncrement(lane.input_address, None))
    # Load 16 bits.
    for lane in lanes:
      emitter.EmitVLoad('1.16', emitter.Lane(lane.load, 2),
                        emitter.Dereference(lane.input_address, None))
  elif leftovers == 7:
    # Load 32 bits.
    for lane in lanes:
      emitter.EmitVLoad(
          '1.32', emitter.Lane(lane.load, 0),
          emitter.DereferenceIncrement(lane.input_address, None))
    # Load 16 bits.
    for lane in lanes:
      emitter.EmitVLoad(
          '1.16', emitter.Lane(lane.load, 2),
          emitter.DereferenceIncrement(lane.input_address, None))
    # Load 8 bits.
    for lane in lanes:
      emitter.EmitVLoad('1.8', emitter.Lane(lane.load, 6),
                        emitter.Dereference(lane.input_address, None))
  else:
    raise ConfigurationError('Unsupported leftover num: %d' % leftovers)

  # Aggregate.
  store_registers = []
  for lane in lanes:
    emitter.EmitVAddw('u8', lane.aggregator, lane.aggregator, lane.load)
    store_registers.append(lane.load)

  # Store.
  emitter.EmitVStoreA('1.8', store_registers,
                      emitter.DereferenceIncrement(output_address, 64))
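

# The leftover branches above follow one pattern: a leftover count of 1-7
# decomposes into at most one 32-bit, one 16-bit and one 8-bit load, in that
# order. Below is a small sketch of that decomposition (illustration only;
# the name LeftoverLoadWidths is hypothetical). The returned byte offsets
# equal the Lane() indices used above times the element size in bytes.
def LeftoverLoadWidths(leftovers):
  loads = []
  byte_offset = 0
  for width_bytes in (4, 2, 1):
    if leftovers & width_bytes:
      loads.append((width_bytes * 8, byte_offset))
      byte_offset += width_bytes
  return loads  # e.g. LeftoverLoadWidths(7) == [(32, 0), (16, 4), (8, 6)]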


def GenerateAggregatorReduction(emitter,
                                registers,
                                lanes,
                                output_address,
                                multiplicative_offset,
                                additive_offset):
  """Reduce the 4-lane sum aggregators to one value each and store the sums."""
  emitter.EmitNewline()
  emitter.EmitComment('Aggregator Reduction.')

  multiplier = registers.DoubleRegister()
  emitter.EmitVMov('32', emitter.Lane(multiplier, 0), multiplicative_offset)
  offset = registers.QuadRegister()
  emitter.EmitVDup('32', offset, additive_offset)

  lane_temps = []
  for lane in lanes:
    emitter.EmitVPaddl('u16', lane.aggregator, lane.aggregator)

  for lane in lanes:
    lane_temp = registers.DoubleRegister()
    lane_temps.append(lane_temp)
    emitter.EmitVPadd('u32',
                      lane_temp,
                      registers.Low(lane.aggregator),
                      registers.High(lane.aggregator))

  temp = registers.QuadRegister()
  low = registers.Low(temp)
  high = registers.High(temp)

  if len(lanes) == 1:
    emitter.EmitVPadd('u32', low, lane_temps[0], lane_temps[0])
  elif len(lanes) == 2:
    emitter.EmitVPadd('u32', low, lane_temps[0], lane_temps[1])
  elif len(lanes) == 3:
    emitter.EmitVPadd('u32', low, lane_temps[0], lane_temps[1])
    emitter.EmitVPadd('u32', high, lane_temps[2], lane_temps[2])
  elif len(lanes) == 4:
    emitter.EmitVPadd('u32', low, lane_temps[0], lane_temps[1])
    emitter.EmitVPadd('u32', high, lane_temps[2], lane_temps[3])
  else:
    raise ConfigurationError(
        'Unexpected number of aggregators to reduce: %d' % len(lanes))

  emitter.EmitVMul('i32', temp, temp, emitter.Lane(multiplier, 0))
  emitter.EmitVAdd('i32', temp, temp, offset)

  if len(lanes) == 1:
    emitter.EmitVStore(
        '1.32', emitter.Lane(low, 0), emitter.Dereference(output_address, None))
  elif len(lanes) == 2:
    emitter.EmitVStore('1.32', low, emitter.Dereference(output_address, 64))
  elif len(lanes) == 3:
    emitter.EmitVStore(
        '1.32', low, emitter.DereferenceIncrement(output_address, 64))
    emitter.EmitVStore(
        '1.32', emitter.Lane(high, 0),
        emitter.Dereference(output_address, None))
  elif len(lanes) == 4:
    emitter.EmitVStore(
        '1.32', low, emitter.DereferenceIncrement(output_address, 64))
    emitter.EmitVStore('1.32', high, emitter.Dereference(output_address, 64))


def GenerateZipNx8(emitter, zip_lanes, leftovers, aligned):
  """Emit the zip function for a given number of rows and row size leftovers."""
  if leftovers < 0 or leftovers > 7:
    raise ConfigurationError('Leftovers should be between 0 and 7 inclusive.')
  if zip_lanes < 1 or zip_lanes > 3:
    raise ConfigurationError('Zip_lanes should be 1, 2 or 3.')

  name = BuildName(zip_lanes, leftovers, aligned)

  emitter.EmitFunctionBeginA(name,
                             [['const std::uint8_t*', 'source'],
                              ['std::int32_t', 'count'],
                              ['std::int32_t', 'stride'],
                              ['std::uint8_t*', 'destination'],
                              ['std::int32_t', 'multiplicative_offset'],
                              ['std::int32_t', 'additive_offset']],
                             'void')
  emitter.EmitAssert('count %% 8 == %d' % leftovers)
  emitter.EmitAssert('count <= 2048')
  emitter.EmitAssert('count >= 8')
  emitter.EmitAssert('reinterpret_cast<std::uintptr_t>(destination) % 8 == 0')
  if aligned:
    emitter.EmitAssert('reinterpret_cast<std::uintptr_t>(source) % 8 == 0')
    if zip_lanes > 1:
      emitter.EmitAssert('stride % 8 == 0')
  emitter.EmitAsmBegin()

  registers = neon_emitter.NeonRegisters()

  count = registers.MapParameter('count')
  output_address = registers.MapParameter('destination')

  lanes = GenerateZipLanes(emitter,
                           registers,
                           zip_lanes,
                           registers.MapParameter('source'),
                           registers.MapParameter('stride'))

  if leftovers:
    emitter.EmitSub(count, count, emitter.ImmediateConstant(leftovers))

  GenerateClearAggregators(emitter, lanes)

  emitter.EmitNewline()
  emitter.EmitNumericalLabel(1)
  emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))

  GenerateLoadAggregateStore(
      emitter, lanes, output_address, 64 if aligned else None)

  emitter.EmitNewline()
  emitter.EmitBneBack(1)

  if leftovers:
    GenerateLeftoverLoadAggregateStore(
        emitter, leftovers, lanes, output_address)

  GenerateAggregatorReduction(emitter,
                              registers,
                              lanes,
                              output_address,
                              registers.MapParameter('multiplicative_offset'),
                              registers.MapParameter('additive_offset'))

  emitter.EmitAsmEnd(registers.MappedParameters(),
                     [],
                     registers.Clobbers() + ['cc', 'memory'])
  emitter.EmitFunctionEnd()


def GenerateFunctions(emitter):
  for aligned in [True, False]:
    for lanes in range(1, 4):
      for leftovers in range(0, 8):
        GenerateZipNx8(emitter, lanes, leftovers, aligned)
        emitter.EmitNewline()
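

# A hypothetical driver, sketching how this module would typically be run to
# emit all 48 variants (2 alignments x 3 lane counts x 8 leftover sizes).
# NeonEmitter is assumed to be the emitter class exposed by the sibling
# neon_emitter module; the emitted functions are named per BuildName
# (e.g. zip_2x8_3_aligned) with the signature declared in GenerateZipNx8.
if __name__ == '__main__':
  GenerateFunctions(neon_emitter.NeonEmitter())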