1 //===-- AArch64PBQPRegAlloc.cpp - AArch64 specific PBQP constraints -------===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // This file contains the AArch64 / Cortex-A57 specific register allocation 10 // constraints for use by the PBQP register allocator. 11 // 12 // It is essentially a transcription of what is contained in 13 // AArch64A57FPLoadBalancing, which tries to use a balanced 14 // mix of odd and even D-registers when performing a critical sequence of 15 // independent, non-quadword FP/ASIMD floating-point multiply-accumulates. 16 //===----------------------------------------------------------------------===// 17 18 #define DEBUG_TYPE "aarch64-pbqp" 19 20 #include "AArch64.h" 21 #include "AArch64PBQPRegAlloc.h" 22 #include "AArch64RegisterInfo.h" 23 #include "llvm/CodeGen/LiveIntervalAnalysis.h" 24 #include "llvm/CodeGen/MachineBasicBlock.h" 25 #include "llvm/CodeGen/MachineFunction.h" 26 #include "llvm/CodeGen/MachineRegisterInfo.h" 27 #include "llvm/CodeGen/RegAllocPBQP.h" 28 #include "llvm/Support/Debug.h" 29 #include "llvm/Support/ErrorHandling.h" 30 #include "llvm/Support/raw_ostream.h" 31 32 using namespace llvm; 33 34 namespace { 35 36 #ifndef NDEBUG 37 bool isFPReg(unsigned reg) { 38 return AArch64::FPR32RegClass.contains(reg) || 39 AArch64::FPR64RegClass.contains(reg) || 40 AArch64::FPR128RegClass.contains(reg); 41 } 42 #endif 43 44 bool isOdd(unsigned reg) { 45 switch (reg) { 46 default: 47 llvm_unreachable("Register is not from the expected class !"); 48 case AArch64::S1: 49 case AArch64::S3: 50 case AArch64::S5: 51 case AArch64::S7: 52 case AArch64::S9: 53 case AArch64::S11: 54 case AArch64::S13: 55 case AArch64::S15: 56 case AArch64::S17: 57 case AArch64::S19: 58 case AArch64::S21: 59 case AArch64::S23: 60 case AArch64::S25: 61 case AArch64::S27: 62 case AArch64::S29: 63 case AArch64::S31: 64 case AArch64::D1: 65 case AArch64::D3: 66 case AArch64::D5: 67 case AArch64::D7: 68 case AArch64::D9: 69 case AArch64::D11: 70 case AArch64::D13: 71 case AArch64::D15: 72 case AArch64::D17: 73 case AArch64::D19: 74 case AArch64::D21: 75 case AArch64::D23: 76 case AArch64::D25: 77 case AArch64::D27: 78 case AArch64::D29: 79 case AArch64::D31: 80 case AArch64::Q1: 81 case AArch64::Q3: 82 case AArch64::Q5: 83 case AArch64::Q7: 84 case AArch64::Q9: 85 case AArch64::Q11: 86 case AArch64::Q13: 87 case AArch64::Q15: 88 case AArch64::Q17: 89 case AArch64::Q19: 90 case AArch64::Q21: 91 case AArch64::Q23: 92 case AArch64::Q25: 93 case AArch64::Q27: 94 case AArch64::Q29: 95 case AArch64::Q31: 96 return true; 97 case AArch64::S0: 98 case AArch64::S2: 99 case AArch64::S4: 100 case AArch64::S6: 101 case AArch64::S8: 102 case AArch64::S10: 103 case AArch64::S12: 104 case AArch64::S14: 105 case AArch64::S16: 106 case AArch64::S18: 107 case AArch64::S20: 108 case AArch64::S22: 109 case AArch64::S24: 110 case AArch64::S26: 111 case AArch64::S28: 112 case AArch64::S30: 113 case AArch64::D0: 114 case AArch64::D2: 115 case AArch64::D4: 116 case AArch64::D6: 117 case AArch64::D8: 118 case AArch64::D10: 119 case AArch64::D12: 120 case AArch64::D14: 121 case AArch64::D16: 122 case AArch64::D18: 123 case AArch64::D20: 124 case AArch64::D22: 125 case AArch64::D24: 126 case AArch64::D26: 127 case AArch64::D28: 128 case AArch64::D30: 129 case AArch64::Q0: 130 case AArch64::Q2: 131 case AArch64::Q4: 132 case AArch64::Q6: 133 case AArch64::Q8: 134 case AArch64::Q10: 135 case AArch64::Q12: 136 case AArch64::Q14: 137 case AArch64::Q16: 138 case AArch64::Q18: 139 case AArch64::Q20: 140 case AArch64::Q22: 141 case AArch64::Q24: 142 case AArch64::Q26: 143 case AArch64::Q28: 144 case AArch64::Q30: 145 return false; 146 147 } 148 } 149 150 bool haveSameParity(unsigned reg1, unsigned reg2) { 151 assert(isFPReg(reg1) && "Expecting an FP register for reg1"); 152 assert(isFPReg(reg2) && "Expecting an FP register for reg2"); 153 154 return isOdd(reg1) == isOdd(reg2); 155 } 156 157 } 158 159 bool A57ChainingConstraint::addIntraChainConstraint(PBQPRAGraph &G, unsigned Rd, 160 unsigned Ra) { 161 if (Rd == Ra) 162 return false; 163 164 LiveIntervals &LIs = G.getMetadata().LIS; 165 166 if (TRI->isPhysicalRegister(Rd) || TRI->isPhysicalRegister(Ra)) { 167 DEBUG(dbgs() << "Rd is a physical reg:" << TRI->isPhysicalRegister(Rd) 168 << '\n'); 169 DEBUG(dbgs() << "Ra is a physical reg:" << TRI->isPhysicalRegister(Ra) 170 << '\n'); 171 return false; 172 } 173 174 PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(Rd); 175 PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(Ra); 176 177 const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRdAllowed = 178 &G.getNodeMetadata(node1).getAllowedRegs(); 179 const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRaAllowed = 180 &G.getNodeMetadata(node2).getAllowedRegs(); 181 182 PBQPRAGraph::EdgeId edge = G.findEdge(node1, node2); 183 184 // The edge does not exist. Create one with the appropriate interference 185 // costs. 186 if (edge == G.invalidEdgeId()) { 187 const LiveInterval &ld = LIs.getInterval(Rd); 188 const LiveInterval &la = LIs.getInterval(Ra); 189 bool livesOverlap = ld.overlaps(la); 190 191 PBQPRAGraph::RawMatrix costs(vRdAllowed->size() + 1, 192 vRaAllowed->size() + 1, 0); 193 for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) { 194 unsigned pRd = (*vRdAllowed)[i]; 195 for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) { 196 unsigned pRa = (*vRaAllowed)[j]; 197 if (livesOverlap && TRI->regsOverlap(pRd, pRa)) 198 costs[i + 1][j + 1] = std::numeric_limits<PBQP::PBQPNum>::infinity(); 199 else 200 costs[i + 1][j + 1] = haveSameParity(pRd, pRa) ? 0.0 : 1.0; 201 } 202 } 203 G.addEdge(node1, node2, std::move(costs)); 204 return true; 205 } 206 207 if (G.getEdgeNode1Id(edge) == node2) { 208 std::swap(node1, node2); 209 std::swap(vRdAllowed, vRaAllowed); 210 } 211 212 // Enforce minCost(sameParity(RaClass)) > maxCost(otherParity(RdClass)) 213 PBQPRAGraph::RawMatrix costs(G.getEdgeCosts(edge)); 214 for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) { 215 unsigned pRd = (*vRdAllowed)[i]; 216 217 // Get the maximum cost (excluding unallocatable reg) for same parity 218 // registers 219 PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min(); 220 for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) { 221 unsigned pRa = (*vRaAllowed)[j]; 222 if (haveSameParity(pRd, pRa)) 223 if (costs[i + 1][j + 1] != 224 std::numeric_limits<PBQP::PBQPNum>::infinity() && 225 costs[i + 1][j + 1] > sameParityMax) 226 sameParityMax = costs[i + 1][j + 1]; 227 } 228 229 // Ensure all registers with a different parity have a higher cost 230 // than sameParityMax 231 for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) { 232 unsigned pRa = (*vRaAllowed)[j]; 233 if (!haveSameParity(pRd, pRa)) 234 if (sameParityMax > costs[i + 1][j + 1]) 235 costs[i + 1][j + 1] = sameParityMax + 1.0; 236 } 237 } 238 G.updateEdgeCosts(edge, std::move(costs)); 239 240 return true; 241 } 242 243 void A57ChainingConstraint::addInterChainConstraint(PBQPRAGraph &G, unsigned Rd, 244 unsigned Ra) { 245 LiveIntervals &LIs = G.getMetadata().LIS; 246 247 // Do some Chain management 248 if (Chains.count(Ra)) { 249 if (Rd != Ra) { 250 DEBUG(dbgs() << "Moving acc chain from " << PrintReg(Ra, TRI) << " to " 251 << PrintReg(Rd, TRI) << '\n';); 252 Chains.remove(Ra); 253 Chains.insert(Rd); 254 } 255 } else { 256 DEBUG(dbgs() << "Creating new acc chain for " << PrintReg(Rd, TRI) 257 << '\n';); 258 Chains.insert(Rd); 259 } 260 261 PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(Rd); 262 263 const LiveInterval &ld = LIs.getInterval(Rd); 264 for (auto r : Chains) { 265 // Skip self 266 if (r == Rd) 267 continue; 268 269 const LiveInterval &lr = LIs.getInterval(r); 270 if (ld.overlaps(lr)) { 271 const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRdAllowed = 272 &G.getNodeMetadata(node1).getAllowedRegs(); 273 274 PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(r); 275 const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRrAllowed = 276 &G.getNodeMetadata(node2).getAllowedRegs(); 277 278 PBQPRAGraph::EdgeId edge = G.findEdge(node1, node2); 279 assert(edge != G.invalidEdgeId() && 280 "PBQP error ! The edge should exist !"); 281 282 DEBUG(dbgs() << "Refining constraint !\n";); 283 284 if (G.getEdgeNode1Id(edge) == node2) { 285 std::swap(node1, node2); 286 std::swap(vRdAllowed, vRrAllowed); 287 } 288 289 // Enforce that cost is higher with all other Chains of the same parity 290 PBQP::Matrix costs(G.getEdgeCosts(edge)); 291 for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) { 292 unsigned pRd = (*vRdAllowed)[i]; 293 294 // Get the maximum cost (excluding unallocatable reg) for all other 295 // parity registers 296 PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min(); 297 for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) { 298 unsigned pRa = (*vRrAllowed)[j]; 299 if (!haveSameParity(pRd, pRa)) 300 if (costs[i + 1][j + 1] != 301 std::numeric_limits<PBQP::PBQPNum>::infinity() && 302 costs[i + 1][j + 1] > sameParityMax) 303 sameParityMax = costs[i + 1][j + 1]; 304 } 305 306 // Ensure all registers with same parity have a higher cost 307 // than sameParityMax 308 for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) { 309 unsigned pRa = (*vRrAllowed)[j]; 310 if (haveSameParity(pRd, pRa)) 311 if (sameParityMax > costs[i + 1][j + 1]) 312 costs[i + 1][j + 1] = sameParityMax + 1.0; 313 } 314 } 315 G.updateEdgeCosts(edge, std::move(costs)); 316 } 317 } 318 } 319 320 static bool regJustKilledBefore(const LiveIntervals &LIs, unsigned reg, 321 const MachineInstr &MI) { 322 const LiveInterval &LI = LIs.getInterval(reg); 323 SlotIndex SI = LIs.getInstructionIndex(MI); 324 return LI.expiredAt(SI); 325 } 326 327 void A57ChainingConstraint::apply(PBQPRAGraph &G) { 328 const MachineFunction &MF = G.getMetadata().MF; 329 LiveIntervals &LIs = G.getMetadata().LIS; 330 331 TRI = MF.getSubtarget().getRegisterInfo(); 332 DEBUG(MF.dump()); 333 334 for (const auto &MBB: MF) { 335 Chains.clear(); // FIXME: really needed ? Could not work at MF level ? 336 337 for (const auto &MI: MBB) { 338 339 // Forget Chains which have expired 340 for (auto r : Chains) { 341 SmallVector<unsigned, 8> toDel; 342 if(regJustKilledBefore(LIs, r, MI)) { 343 DEBUG(dbgs() << "Killing chain " << PrintReg(r, TRI) << " at "; 344 MI.print(dbgs());); 345 toDel.push_back(r); 346 } 347 348 while (!toDel.empty()) { 349 Chains.remove(toDel.back()); 350 toDel.pop_back(); 351 } 352 } 353 354 switch (MI.getOpcode()) { 355 case AArch64::FMSUBSrrr: 356 case AArch64::FMADDSrrr: 357 case AArch64::FNMSUBSrrr: 358 case AArch64::FNMADDSrrr: 359 case AArch64::FMSUBDrrr: 360 case AArch64::FMADDDrrr: 361 case AArch64::FNMSUBDrrr: 362 case AArch64::FNMADDDrrr: { 363 unsigned Rd = MI.getOperand(0).getReg(); 364 unsigned Ra = MI.getOperand(3).getReg(); 365 366 if (addIntraChainConstraint(G, Rd, Ra)) 367 addInterChainConstraint(G, Rd, Ra); 368 break; 369 } 370 371 case AArch64::FMLAv2f32: 372 case AArch64::FMLSv2f32: { 373 unsigned Rd = MI.getOperand(0).getReg(); 374 addInterChainConstraint(G, Rd, Rd); 375 break; 376 } 377 378 default: 379 break; 380 } 381 } 382 } 383 } 384