Home | History | Annotate | Download | only in AArch64
      1 //===-- AArch64PBQPRegAlloc.cpp - AArch64 specific PBQP constraints -------===//
      2 //
      3 //                     The LLVM Compiler Infrastructure
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 //===----------------------------------------------------------------------===//
      9 // This file contains the AArch64 / Cortex-A57 specific register allocation
     10 // constraints for use by the PBQP register allocator.
     11 //
     12 // It is essentially a transcription of what is contained in
     13 // AArch64A57FPLoadBalancing, which tries to use a balanced
     14 // mix of odd and even D-registers when performing a critical sequence of
     15 // independent, non-quadword FP/ASIMD floating-point multiply-accumulates.
     16 //===----------------------------------------------------------------------===//
     17 
     18 #define DEBUG_TYPE "aarch64-pbqp"
     19 
     20 #include "AArch64.h"
     21 #include "AArch64PBQPRegAlloc.h"
     22 #include "AArch64RegisterInfo.h"
     23 #include "llvm/CodeGen/LiveIntervalAnalysis.h"
     24 #include "llvm/CodeGen/MachineBasicBlock.h"
     25 #include "llvm/CodeGen/MachineFunction.h"
     26 #include "llvm/CodeGen/MachineRegisterInfo.h"
     27 #include "llvm/CodeGen/RegAllocPBQP.h"
     28 #include "llvm/Support/Debug.h"
     29 #include "llvm/Support/ErrorHandling.h"
     30 #include "llvm/Support/raw_ostream.h"
     31 
     32 using namespace llvm;
     33 
     34 namespace {
     35 
     36 #ifndef NDEBUG
     37 bool isFPReg(unsigned reg) {
     38   return AArch64::FPR32RegClass.contains(reg) ||
     39          AArch64::FPR64RegClass.contains(reg) ||
     40          AArch64::FPR128RegClass.contains(reg);
     41 }
     42 #endif
     43 
     44 bool isOdd(unsigned reg) {
     45   switch (reg) {
     46   default:
     47     llvm_unreachable("Register is not from the expected class !");
     48   case AArch64::S1:
     49   case AArch64::S3:
     50   case AArch64::S5:
     51   case AArch64::S7:
     52   case AArch64::S9:
     53   case AArch64::S11:
     54   case AArch64::S13:
     55   case AArch64::S15:
     56   case AArch64::S17:
     57   case AArch64::S19:
     58   case AArch64::S21:
     59   case AArch64::S23:
     60   case AArch64::S25:
     61   case AArch64::S27:
     62   case AArch64::S29:
     63   case AArch64::S31:
     64   case AArch64::D1:
     65   case AArch64::D3:
     66   case AArch64::D5:
     67   case AArch64::D7:
     68   case AArch64::D9:
     69   case AArch64::D11:
     70   case AArch64::D13:
     71   case AArch64::D15:
     72   case AArch64::D17:
     73   case AArch64::D19:
     74   case AArch64::D21:
     75   case AArch64::D23:
     76   case AArch64::D25:
     77   case AArch64::D27:
     78   case AArch64::D29:
     79   case AArch64::D31:
     80   case AArch64::Q1:
     81   case AArch64::Q3:
     82   case AArch64::Q5:
     83   case AArch64::Q7:
     84   case AArch64::Q9:
     85   case AArch64::Q11:
     86   case AArch64::Q13:
     87   case AArch64::Q15:
     88   case AArch64::Q17:
     89   case AArch64::Q19:
     90   case AArch64::Q21:
     91   case AArch64::Q23:
     92   case AArch64::Q25:
     93   case AArch64::Q27:
     94   case AArch64::Q29:
     95   case AArch64::Q31:
     96     return true;
     97   case AArch64::S0:
     98   case AArch64::S2:
     99   case AArch64::S4:
    100   case AArch64::S6:
    101   case AArch64::S8:
    102   case AArch64::S10:
    103   case AArch64::S12:
    104   case AArch64::S14:
    105   case AArch64::S16:
    106   case AArch64::S18:
    107   case AArch64::S20:
    108   case AArch64::S22:
    109   case AArch64::S24:
    110   case AArch64::S26:
    111   case AArch64::S28:
    112   case AArch64::S30:
    113   case AArch64::D0:
    114   case AArch64::D2:
    115   case AArch64::D4:
    116   case AArch64::D6:
    117   case AArch64::D8:
    118   case AArch64::D10:
    119   case AArch64::D12:
    120   case AArch64::D14:
    121   case AArch64::D16:
    122   case AArch64::D18:
    123   case AArch64::D20:
    124   case AArch64::D22:
    125   case AArch64::D24:
    126   case AArch64::D26:
    127   case AArch64::D28:
    128   case AArch64::D30:
    129   case AArch64::Q0:
    130   case AArch64::Q2:
    131   case AArch64::Q4:
    132   case AArch64::Q6:
    133   case AArch64::Q8:
    134   case AArch64::Q10:
    135   case AArch64::Q12:
    136   case AArch64::Q14:
    137   case AArch64::Q16:
    138   case AArch64::Q18:
    139   case AArch64::Q20:
    140   case AArch64::Q22:
    141   case AArch64::Q24:
    142   case AArch64::Q26:
    143   case AArch64::Q28:
    144   case AArch64::Q30:
    145     return false;
    146 
    147   }
    148 }
    149 
    150 bool haveSameParity(unsigned reg1, unsigned reg2) {
    151   assert(isFPReg(reg1) && "Expecting an FP register for reg1");
    152   assert(isFPReg(reg2) && "Expecting an FP register for reg2");
    153 
    154   return isOdd(reg1) == isOdd(reg2);
    155 }
    156 
    157 }
    158 
    159 bool A57ChainingConstraint::addIntraChainConstraint(PBQPRAGraph &G, unsigned Rd,
    160                                                  unsigned Ra) {
    161   if (Rd == Ra)
    162     return false;
    163 
    164   LiveIntervals &LIs = G.getMetadata().LIS;
    165 
    166   if (TRI->isPhysicalRegister(Rd) || TRI->isPhysicalRegister(Ra)) {
    167     DEBUG(dbgs() << "Rd is a physical reg:" << TRI->isPhysicalRegister(Rd)
    168           << '\n');
    169     DEBUG(dbgs() << "Ra is a physical reg:" << TRI->isPhysicalRegister(Ra)
    170           << '\n');
    171     return false;
    172   }
    173 
    174   PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(Rd);
    175   PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(Ra);
    176 
    177   const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRdAllowed =
    178     &G.getNodeMetadata(node1).getAllowedRegs();
    179   const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRaAllowed =
    180     &G.getNodeMetadata(node2).getAllowedRegs();
    181 
    182   PBQPRAGraph::EdgeId edge = G.findEdge(node1, node2);
    183 
    184   // The edge does not exist. Create one with the appropriate interference
    185   // costs.
    186   if (edge == G.invalidEdgeId()) {
    187     const LiveInterval &ld = LIs.getInterval(Rd);
    188     const LiveInterval &la = LIs.getInterval(Ra);
    189     bool livesOverlap = ld.overlaps(la);
    190 
    191     PBQPRAGraph::RawMatrix costs(vRdAllowed->size() + 1,
    192                                  vRaAllowed->size() + 1, 0);
    193     for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
    194       unsigned pRd = (*vRdAllowed)[i];
    195       for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
    196         unsigned pRa = (*vRaAllowed)[j];
    197         if (livesOverlap && TRI->regsOverlap(pRd, pRa))
    198           costs[i + 1][j + 1] = std::numeric_limits<PBQP::PBQPNum>::infinity();
    199         else
    200           costs[i + 1][j + 1] = haveSameParity(pRd, pRa) ? 0.0 : 1.0;
    201       }
    202     }
    203     G.addEdge(node1, node2, std::move(costs));
    204     return true;
    205   }
    206 
    207   if (G.getEdgeNode1Id(edge) == node2) {
    208     std::swap(node1, node2);
    209     std::swap(vRdAllowed, vRaAllowed);
    210   }
    211 
    212   // Enforce minCost(sameParity(RaClass)) > maxCost(otherParity(RdClass))
    213   PBQPRAGraph::RawMatrix costs(G.getEdgeCosts(edge));
    214   for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
    215     unsigned pRd = (*vRdAllowed)[i];
    216 
    217     // Get the maximum cost (excluding unallocatable reg) for same parity
    218     // registers
    219     PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min();
    220     for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
    221       unsigned pRa = (*vRaAllowed)[j];
    222       if (haveSameParity(pRd, pRa))
    223         if (costs[i + 1][j + 1] !=
    224                 std::numeric_limits<PBQP::PBQPNum>::infinity() &&
    225             costs[i + 1][j + 1] > sameParityMax)
    226           sameParityMax = costs[i + 1][j + 1];
    227     }
    228 
    229     // Ensure all registers with a different parity have a higher cost
    230     // than sameParityMax
    231     for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
    232       unsigned pRa = (*vRaAllowed)[j];
    233       if (!haveSameParity(pRd, pRa))
    234         if (sameParityMax > costs[i + 1][j + 1])
    235           costs[i + 1][j + 1] = sameParityMax + 1.0;
    236     }
    237   }
    238   G.updateEdgeCosts(edge, std::move(costs));
    239 
    240   return true;
    241 }
    242 
    243 void A57ChainingConstraint::addInterChainConstraint(PBQPRAGraph &G, unsigned Rd,
    244                                                  unsigned Ra) {
    245   LiveIntervals &LIs = G.getMetadata().LIS;
    246 
    247   // Do some Chain management
    248   if (Chains.count(Ra)) {
    249     if (Rd != Ra) {
    250       DEBUG(dbgs() << "Moving acc chain from " << PrintReg(Ra, TRI) << " to "
    251                    << PrintReg(Rd, TRI) << '\n';);
    252       Chains.remove(Ra);
    253       Chains.insert(Rd);
    254     }
    255   } else {
    256     DEBUG(dbgs() << "Creating new acc chain for " << PrintReg(Rd, TRI)
    257                  << '\n';);
    258     Chains.insert(Rd);
    259   }
    260 
    261   PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(Rd);
    262 
    263   const LiveInterval &ld = LIs.getInterval(Rd);
    264   for (auto r : Chains) {
    265     // Skip self
    266     if (r == Rd)
    267       continue;
    268 
    269     const LiveInterval &lr = LIs.getInterval(r);
    270     if (ld.overlaps(lr)) {
    271       const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRdAllowed =
    272         &G.getNodeMetadata(node1).getAllowedRegs();
    273 
    274       PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(r);
    275       const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRrAllowed =
    276         &G.getNodeMetadata(node2).getAllowedRegs();
    277 
    278       PBQPRAGraph::EdgeId edge = G.findEdge(node1, node2);
    279       assert(edge != G.invalidEdgeId() &&
    280              "PBQP error ! The edge should exist !");
    281 
    282       DEBUG(dbgs() << "Refining constraint !\n";);
    283 
    284       if (G.getEdgeNode1Id(edge) == node2) {
    285         std::swap(node1, node2);
    286         std::swap(vRdAllowed, vRrAllowed);
    287       }
    288 
    289       // Enforce that cost is higher with all other Chains of the same parity
    290       PBQP::Matrix costs(G.getEdgeCosts(edge));
    291       for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
    292         unsigned pRd = (*vRdAllowed)[i];
    293 
    294         // Get the maximum cost (excluding unallocatable reg) for all other
    295         // parity registers
    296         PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min();
    297         for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) {
    298           unsigned pRa = (*vRrAllowed)[j];
    299           if (!haveSameParity(pRd, pRa))
    300             if (costs[i + 1][j + 1] !=
    301                     std::numeric_limits<PBQP::PBQPNum>::infinity() &&
    302                 costs[i + 1][j + 1] > sameParityMax)
    303               sameParityMax = costs[i + 1][j + 1];
    304         }
    305 
    306         // Ensure all registers with same parity have a higher cost
    307         // than sameParityMax
    308         for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) {
    309           unsigned pRa = (*vRrAllowed)[j];
    310           if (haveSameParity(pRd, pRa))
    311             if (sameParityMax > costs[i + 1][j + 1])
    312               costs[i + 1][j + 1] = sameParityMax + 1.0;
    313         }
    314       }
    315       G.updateEdgeCosts(edge, std::move(costs));
    316     }
    317   }
    318 }
    319 
    320 static bool regJustKilledBefore(const LiveIntervals &LIs, unsigned reg,
    321                                 const MachineInstr &MI) {
    322   const LiveInterval &LI = LIs.getInterval(reg);
    323   SlotIndex SI = LIs.getInstructionIndex(MI);
    324   return LI.expiredAt(SI);
    325 }
    326 
    327 void A57ChainingConstraint::apply(PBQPRAGraph &G) {
    328   const MachineFunction &MF = G.getMetadata().MF;
    329   LiveIntervals &LIs = G.getMetadata().LIS;
    330 
    331   TRI = MF.getSubtarget().getRegisterInfo();
    332   DEBUG(MF.dump());
    333 
    334   for (const auto &MBB: MF) {
    335     Chains.clear(); // FIXME: really needed ? Could not work at MF level ?
    336 
    337     for (const auto &MI: MBB) {
    338 
    339       // Forget Chains which have expired
    340       for (auto r : Chains) {
    341         SmallVector<unsigned, 8> toDel;
    342         if(regJustKilledBefore(LIs, r, MI)) {
    343           DEBUG(dbgs() << "Killing chain " << PrintReg(r, TRI) << " at ";
    344                 MI.print(dbgs()););
    345           toDel.push_back(r);
    346         }
    347 
    348         while (!toDel.empty()) {
    349           Chains.remove(toDel.back());
    350           toDel.pop_back();
    351         }
    352       }
    353 
    354       switch (MI.getOpcode()) {
    355       case AArch64::FMSUBSrrr:
    356       case AArch64::FMADDSrrr:
    357       case AArch64::FNMSUBSrrr:
    358       case AArch64::FNMADDSrrr:
    359       case AArch64::FMSUBDrrr:
    360       case AArch64::FMADDDrrr:
    361       case AArch64::FNMSUBDrrr:
    362       case AArch64::FNMADDDrrr: {
    363         unsigned Rd = MI.getOperand(0).getReg();
    364         unsigned Ra = MI.getOperand(3).getReg();
    365 
    366         if (addIntraChainConstraint(G, Rd, Ra))
    367           addInterChainConstraint(G, Rd, Ra);
    368         break;
    369       }
    370 
    371       case AArch64::FMLAv2f32:
    372       case AArch64::FMLSv2f32: {
    373         unsigned Rd = MI.getOperand(0).getReg();
    374         addInterChainConstraint(G, Rd, Rd);
    375         break;
    376       }
    377 
    378       default:
    379         break;
    380       }
    381     }
    382   }
    383 }
    384