"""Zip primitive used by the GEMM function.

Takes 1 to 4 rows of data and interleaves them in 8 byte chunks. Pads to a
multiple of 8 length with zeros. Calculates row sums and appends those at the
end.
"""

import neon_emitter


class Error(Exception):
  """Module level error."""


class ConfigurationError(Error):
  """Unsupported configuration."""


class ZipLane(object):
  """Per-row state for the zip operation.

  Attributes:
    input_address: general register holding the row's current read address.
    load: double register receiving 8 input bytes per iteration.
    aggregator: quad register accumulating the row's running byte sum.
  """

  def __init__(self, input_address, load, aggregator):
    self.input_address = input_address
    self.load = load
    self.aggregator = aggregator


def GenerateZipLanes(emitter, registers, zip_lanes, input_address, stride):
  """Prepares read lanes for the zip operation.

  Args:
    emitter: ARM/NEON emitter.
    registers: ARM/NEON registers state.
    zip_lanes: number of lanes to prepare.
    input_address: register that contains the input address for the first lane.
    stride: memory stride for lane inputs.

  Returns:
    Array of ZipLane objects.
  """
  lanes = []
  last_address_register = input_address
  for i in range(0, zip_lanes):
    if not i:
      # First lane reads directly from the caller-supplied address register.
      lanes.append(ZipLane(input_address, registers.DoubleRegister(),
                           registers.QuadRegister(2)))
    else:
      # Each subsequent lane starts one stride past the previous lane.
      address_register = registers.GeneralRegister()
      lanes.append(ZipLane(address_register, registers.DoubleRegister(),
                           registers.QuadRegister(2)))
      emitter.EmitAdd(address_register, last_address_register, stride)
      last_address_register = address_register
  return lanes


def BuildName(zip_lanes, leftovers, aligned):
  """Returns the generated function name encoding lanes/leftovers/alignment."""
  name = 'zip_%dx8' % zip_lanes
  if leftovers:
    name += '_%d' % leftovers
  if aligned:
    name += '_aligned'
  return name


def GenerateClearAggregators(emitter, lanes):
  """Zeroes every lane's sum aggregator register."""
  for lane in lanes:
    emitter.EmitVMov('i16', lane.aggregator, emitter.ImmediateConstant(0))


def GenerateLoadAggregateStore(emitter, lanes, output_address, alignment):
  """Emit inner loop code for reading N lanes and interweaving them.

  Loads 8 bytes per lane, adds them (widened) into the per-lane aggregators
  and stores all lanes' chunks interleaved at the output address.

  Args:
    emitter: ARM/NEON emitter.
    lanes: array of ZipLane objects to process.
    output_address: register holding the destination address.
    alignment: input load alignment hint, or None for unaligned.
  """
  emitter.EmitNewline()
  emitter.EmitComment('Load Aggregate Store.')

  for lane in lanes:
    emitter.EmitVLoad(
        '1.8', lane.load,
        emitter.DereferenceIncrement(lane.input_address, alignment))

  store_registers = []
  for lane in lanes:
    # Widening add: accumulate u8 inputs into the u16 aggregator lanes.
    emitter.EmitVAddw('u8', lane.aggregator, lane.aggregator, lane.load)
    store_registers.append(lane.load)

  emitter.EmitVStoreA('1.8', store_registers,
                      emitter.DereferenceIncrement(output_address, 64))


def GenerateLeftoverLoadAggregateStore(emitter, leftovers, lanes,
                                       output_address):
  """Handle leftovers when count is not a multiple of 8.

  Loads the remaining 1 to 7 bytes of every lane (zero padded up to 8 bytes),
  accumulates them into the row aggregators and stores the padded chunks.

  Args:
    emitter: ARM/NEON emitter.
    leftovers: remaining byte count per lane, must be 1 to 7.
    lanes: array of ZipLane objects to process.
    output_address: register holding the destination address.

  Raises:
    ConfigurationError: if leftovers is outside the supported 1-7 range.
  """
  if leftovers < 1 or leftovers > 7:
    raise ConfigurationError('Unsupported leftover num: %d' % leftovers)

  emitter.EmitNewline()
  emitter.EmitComment('Leftover Load Aggregate Store.')

  # Clear load registers so the unread tail bytes act as zero padding.
  for lane in lanes:
    emitter.EmitVMov('i8', lane.load, emitter.ImmediateConstant(0))

  # Decompose leftovers into at most one 32 bit, one 16 bit and one 8 bit
  # lane load (e.g. 7 = 4 + 2 + 1). Every load except the last one
  # post-increments the lane address so the following chunk continues where
  # the previous one ended.
  offset = 0  # Bytes already consumed; determines the destination lane index.
  remaining = leftovers
  for bits, size in ((32, 4), (16, 2), (8, 1)):
    if remaining < size:
      continue
    remaining -= size
    for lane in lanes:
      if remaining:
        address = emitter.DereferenceIncrement(lane.input_address, None)
      else:
        address = emitter.Dereference(lane.input_address, None)
      # The element lane index is the byte offset in units of element size.
      emitter.EmitVLoad('1.%d' % bits,
                        emitter.Lane(lane.load, offset // size), address)
    offset += size

  # Aggregate.
  store_registers = []
  for lane in lanes:
    emitter.EmitVAddw('u8', lane.aggregator, lane.aggregator, lane.load)
    store_registers.append(lane.load)

  # Store.
  emitter.EmitVStoreA('1.8', store_registers,
                      emitter.DereferenceIncrement(output_address, 64))


def GenerateAggregatorReduction(emitter, registers, lanes, output_address,
                                multiplicative_offset, additive_offset):
  """Reduce 4 lane sum aggregators to 1 value and store the sums.

  Each lane's u16 aggregator is pairwise-reduced to a single u32 sum, the
  sums are scaled by multiplicative_offset, biased by additive_offset, and
  stored contiguously after the zipped data.

  Args:
    emitter: ARM/NEON emitter.
    registers: ARM/NEON registers state.
    lanes: array of 1 to 4 ZipLane objects holding the aggregators.
    output_address: register holding the destination address for the sums.
    multiplicative_offset: register with the per-sum multiplier.
    additive_offset: register with the per-sum additive bias.

  Raises:
    ConfigurationError: if the number of lanes is not 1 to 4.
  """
  emitter.EmitNewline()
  emitter.EmitComment('Aggregator Reduction.')

  multiplier = registers.DoubleRegister()
  emitter.EmitVMov('32', emitter.Lane(multiplier, 0), multiplicative_offset)
  offset = registers.QuadRegister()
  emitter.EmitVDup('32', offset, additive_offset)

  # Widen pairwise: u16 aggregator lanes -> u32 partial sums.
  lane_temps = []
  for lane in lanes:
    emitter.EmitVPaddl('u16', lane.aggregator, lane.aggregator)

  # Pairwise-add each aggregator's halves down to one double register.
  for lane in lanes:
    lane_temp = registers.DoubleRegister()
    lane_temps.append(lane_temp)
    emitter.EmitVPadd('u32', lane_temp, registers.Low(lane.aggregator),
                      registers.High(lane.aggregator))

  # Pack the per-lane sums into consecutive 32 bit elements of one quad.
  temp = registers.QuadRegister()
  low = registers.Low(temp)
  high = registers.High(temp)

  if len(lanes) == 1:
    emitter.EmitVPadd('u32', low, lane_temps[0], lane_temps[0])
  elif len(lanes) == 2:
    emitter.EmitVPadd('u32', low, lane_temps[0], lane_temps[1])
  elif len(lanes) == 3:
    emitter.EmitVPadd('u32', low, lane_temps[0], lane_temps[1])
    emitter.EmitVPadd('u32', high, lane_temps[2], lane_temps[2])
  elif len(lanes) == 4:
    emitter.EmitVPadd('u32', low, lane_temps[0], lane_temps[1])
    emitter.EmitVPadd('u32', high, lane_temps[2], lane_temps[3])
  else:
    raise ConfigurationError('Unexpected number of aggregators to reduce: %d' %
                             len(lanes))

  # sums = sums * multiplicative_offset + additive_offset.
  emitter.EmitVMul('i32', temp, temp, emitter.Lane(multiplier, 0))
  emitter.EmitVAdd('i32', temp, temp, offset)

  # Store only as many 32 bit sums as there are lanes.
  if len(lanes) == 1:
    emitter.EmitVStore('1.32', emitter.Lane(low, 0),
                       emitter.Dereference(output_address, None))
  elif len(lanes) == 2:
    emitter.EmitVStore('1.32', low, emitter.Dereference(output_address, 64))
  elif len(lanes) == 3:
    emitter.EmitVStore('1.32', low,
                       emitter.DereferenceIncrement(output_address, 64))
    emitter.EmitVStore('1.32', emitter.Lane(high, 0),
                       emitter.Dereference(output_address, None))
  elif len(lanes) == 4:
    emitter.EmitVStoreA('1.32', [low, high],
                        emitter.DereferenceIncrement(output_address, 64))


def GenerateZipNx8(emitter, zip_lanes, leftovers, aligned):
  """Emit the zip function for a given number of rows and row size leftovers.

  Args:
    emitter: ARM/NEON emitter.
    zip_lanes: number of input rows to interleave (1 to 4).
    leftovers: count % 8 the generated function is specialized for (0 to 7).
    aligned: whether the generated function may assume 8 byte source alignment.

  Raises:
    ConfigurationError: on unsupported zip_lanes or leftovers values.
  """
  if leftovers < 0 or leftovers > 7:
    raise ConfigurationError('Leftovers should be between 0 and 7 inclusive.')
  if zip_lanes < 1 or zip_lanes > 4:
    raise ConfigurationError('Zip_lanes should be 1, 2, 3 or 4.')

  name = BuildName(zip_lanes, leftovers, aligned)

  emitter.EmitFunctionBeginA(
      name, [['const std::uint8_t*', 'source'], ['std::int32_t', 'count'],
             ['std::int32_t', 'stride'], ['std::uint8_t*', 'destination'],
             ['std::int32_t', 'multiplicative_offset'],
             ['std::int32_t', 'additive_offset']], 'void')
  emitter.EmitAssert('count %% 8 == %d' % leftovers)
  emitter.EmitAssert('count <= 2048')
  emitter.EmitAssert('count >= 8')
  emitter.EmitAssert('reinterpret_cast<std::uintptr_t>(destination) % 8 == 0')
  if aligned:
    emitter.EmitAssert('reinterpret_cast<std::uintptr_t>(source) % 8 == 0')
    if zip_lanes > 1:
      emitter.EmitAssert('stride % 8 == 0')
  emitter.EmitAsmBegin()

  registers = neon_emitter.NeonRegisters()

  count = registers.MapParameter('count')
  output_address = registers.MapParameter('destination')

  lanes = GenerateZipLanes(emitter, registers, zip_lanes,
                           registers.MapParameter('source'),
                           registers.MapParameter('stride'))

  if leftovers:
    # Main loop only handles whole 8 byte chunks; peel the remainder off.
    emitter.EmitSub(count, count, emitter.ImmediateConstant(leftovers))

  GenerateClearAggregators(emitter, lanes)

  emitter.EmitNewline()
  emitter.EmitNumericalLabel(1)
  emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))

  GenerateLoadAggregateStore(emitter, lanes, output_address, 64 if aligned else
                             None)

  emitter.EmitNewline()
  emitter.EmitBneBack(1)

  if leftovers:
    GenerateLeftoverLoadAggregateStore(emitter, leftovers, lanes,
                                       output_address)

  GenerateAggregatorReduction(emitter, registers, lanes, output_address,
                              registers.MapParameter('multiplicative_offset'),
                              registers.MapParameter('additive_offset'))

  emitter.EmitAsmEnd(registers.MappedParameters(), [],
                     registers.Clobbers() + ['cc', 'memory'])
  emitter.EmitFunctionEnd()


def GenerateFunctions(emitter):
  """Emit all zip specializations: alignment x 1-4 lanes x 0-7 leftovers."""
  for aligned in [True, False]:
    for lanes in range(1, 5):
      for leftovers in range(0, 8):
        GenerateZipNx8(emitter, lanes, leftovers, aligned)
        emitter.EmitNewline()