Home | History | Annotate | Download | only in Scalar
      1 //======- GVNExpression.h - GVN Expression classes --------------*- C++ -*-===//
      2 //
      3 //                     The LLVM Compiler Infrastructure
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 //===----------------------------------------------------------------------===//
      9 /// \file
     10 ///
     11 /// The header file for the GVN pass that contains expression handling
     12 /// classes
     13 ///
     14 //===----------------------------------------------------------------------===//
     15 
     16 #ifndef LLVM_TRANSFORMS_SCALAR_GVNEXPRESSION_H
     17 #define LLVM_TRANSFORMS_SCALAR_GVNEXPRESSION_H
     18 
     19 #include "llvm/ADT/Hashing.h"
     20 #include "llvm/ADT/iterator_range.h"
     21 #include "llvm/Analysis/MemorySSA.h"
     22 #include "llvm/IR/Constant.h"
     23 #include "llvm/IR/Instructions.h"
     24 #include "llvm/IR/Value.h"
     25 #include "llvm/Support/Allocator.h"
     26 #include "llvm/Support/ArrayRecycler.h"
     27 #include "llvm/Support/Casting.h"
     28 #include "llvm/Support/Debug.h"
     29 #include "llvm/Support/raw_ostream.h"
     30 #include <algorithm>
     31 #include <cassert>
     32 #include <iterator>
     33 #include <utility>
     34 
     35 namespace llvm {
     36 
     37 namespace GVNExpression {
     38 
     39 enum ExpressionType {
     40   ET_Base,
     41   ET_Constant,
     42   ET_Variable,
     43   ET_Unknown,
     44   ET_BasicStart,
     45   ET_Basic,
     46   ET_AggregateValue,
     47   ET_Phi,
     48   ET_MemoryStart,
     49   ET_Call,
     50   ET_Load,
     51   ET_Store,
     52   ET_MemoryEnd,
     53   ET_BasicEnd
     54 };
     55 
     56 class Expression {
     57 private:
     58   ExpressionType EType;
     59   unsigned Opcode;
     60 
     61 public:
     62   Expression(ExpressionType ET = ET_Base, unsigned O = ~2U)
     63       : EType(ET), Opcode(O) {}
     64   Expression(const Expression &) = delete;
     65   Expression &operator=(const Expression &) = delete;
     66   virtual ~Expression();
     67 
     68   static unsigned getEmptyKey() { return ~0U; }
     69   static unsigned getTombstoneKey() { return ~1U; }
     70   bool operator!=(const Expression &Other) const { return !(*this == Other); }
     71   bool operator==(const Expression &Other) const {
     72     if (getOpcode() != Other.getOpcode())
     73       return false;
     74     if (getOpcode() == getEmptyKey() || getOpcode() == getTombstoneKey())
     75       return true;
     76     // Compare the expression type for anything but load and store.
     77     // For load and store we set the opcode to zero to make them equal.
     78     if (getExpressionType() != ET_Load && getExpressionType() != ET_Store &&
     79         getExpressionType() != Other.getExpressionType())
     80       return false;
     81 
     82     return equals(Other);
     83   }
     84 
     85   virtual bool equals(const Expression &Other) const { return true; }
     86 
     87   unsigned getOpcode() const { return Opcode; }
     88   void setOpcode(unsigned opcode) { Opcode = opcode; }
     89   ExpressionType getExpressionType() const { return EType; }
     90 
     91   // We deliberately leave the expression type out of the hash value.
     92   virtual hash_code getHashValue() const { return getOpcode(); }
     93 
     94   //
     95   // Debugging support
     96   //
     97   virtual void printInternal(raw_ostream &OS, bool PrintEType) const {
     98     if (PrintEType)
     99       OS << "etype = " << getExpressionType() << ",";
    100     OS << "opcode = " << getOpcode() << ", ";
    101   }
    102 
    103   void print(raw_ostream &OS) const {
    104     OS << "{ ";
    105     printInternal(OS, true);
    106     OS << "}";
    107   }
    108 
    109   LLVM_DUMP_METHOD void dump() const {
    110     print(dbgs());
    111     dbgs() << "\n";
    112   }
    113 };
    114 
    115 inline raw_ostream &operator<<(raw_ostream &OS, const Expression &E) {
    116   E.print(OS);
    117   return OS;
    118 }
    119 
    120 class BasicExpression : public Expression {
    121 private:
    122   typedef ArrayRecycler<Value *> RecyclerType;
    123   typedef RecyclerType::Capacity RecyclerCapacity;
    124   Value **Operands;
    125   unsigned MaxOperands;
    126   unsigned NumOperands;
    127   Type *ValueType;
    128 
    129 public:
    130   BasicExpression(unsigned NumOperands)
    131       : BasicExpression(NumOperands, ET_Basic) {}
    132   BasicExpression(unsigned NumOperands, ExpressionType ET)
    133       : Expression(ET), Operands(nullptr), MaxOperands(NumOperands),
    134         NumOperands(0), ValueType(nullptr) {}
    135   BasicExpression() = delete;
    136   BasicExpression(const BasicExpression &) = delete;
    137   BasicExpression &operator=(const BasicExpression &) = delete;
    138   ~BasicExpression() override;
    139 
    140   static bool classof(const Expression *EB) {
    141     ExpressionType ET = EB->getExpressionType();
    142     return ET > ET_BasicStart && ET < ET_BasicEnd;
    143   }
    144 
    145   /// \brief Swap two operands. Used during GVN to put commutative operands in
    146   /// order.
    147   void swapOperands(unsigned First, unsigned Second) {
    148     std::swap(Operands[First], Operands[Second]);
    149   }
    150 
    151   Value *getOperand(unsigned N) const {
    152     assert(Operands && "Operands not allocated");
    153     assert(N < NumOperands && "Operand out of range");
    154     return Operands[N];
    155   }
    156 
    157   void setOperand(unsigned N, Value *V) {
    158     assert(Operands && "Operands not allocated before setting");
    159     assert(N < NumOperands && "Operand out of range");
    160     Operands[N] = V;
    161   }
    162 
    163   unsigned getNumOperands() const { return NumOperands; }
    164 
    165   typedef Value **op_iterator;
    166   typedef Value *const *const_op_iterator;
    167   op_iterator op_begin() { return Operands; }
    168   op_iterator op_end() { return Operands + NumOperands; }
    169   const_op_iterator op_begin() const { return Operands; }
    170   const_op_iterator op_end() const { return Operands + NumOperands; }
    171   iterator_range<op_iterator> operands() {
    172     return iterator_range<op_iterator>(op_begin(), op_end());
    173   }
    174   iterator_range<const_op_iterator> operands() const {
    175     return iterator_range<const_op_iterator>(op_begin(), op_end());
    176   }
    177 
    178   void op_push_back(Value *Arg) {
    179     assert(NumOperands < MaxOperands && "Tried to add too many operands");
    180     assert(Operands && "Operandss not allocated before pushing");
    181     Operands[NumOperands++] = Arg;
    182   }
    183   bool op_empty() const { return getNumOperands() == 0; }
    184 
    185   void allocateOperands(RecyclerType &Recycler, BumpPtrAllocator &Allocator) {
    186     assert(!Operands && "Operands already allocated");
    187     Operands = Recycler.allocate(RecyclerCapacity::get(MaxOperands), Allocator);
    188   }
    189   void deallocateOperands(RecyclerType &Recycler) {
    190     Recycler.deallocate(RecyclerCapacity::get(MaxOperands), Operands);
    191   }
    192 
    193   void setType(Type *T) { ValueType = T; }
    194   Type *getType() const { return ValueType; }
    195 
    196   bool equals(const Expression &Other) const override {
    197     if (getOpcode() != Other.getOpcode())
    198       return false;
    199 
    200     const auto &OE = cast<BasicExpression>(Other);
    201     return getType() == OE.getType() && NumOperands == OE.NumOperands &&
    202            std::equal(op_begin(), op_end(), OE.op_begin());
    203   }
    204 
    205   hash_code getHashValue() const override {
    206     return hash_combine(this->Expression::getHashValue(), ValueType,
    207                         hash_combine_range(op_begin(), op_end()));
    208   }
    209 
    210   //
    211   // Debugging support
    212   //
    213   void printInternal(raw_ostream &OS, bool PrintEType) const override {
    214     if (PrintEType)
    215       OS << "ExpressionTypeBasic, ";
    216 
    217     this->Expression::printInternal(OS, false);
    218     OS << "operands = {";
    219     for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
    220       OS << "[" << i << "] = ";
    221       Operands[i]->printAsOperand(OS);
    222       OS << "  ";
    223     }
    224     OS << "} ";
    225   }
    226 };
    227 
    228 class op_inserter
    229     : public std::iterator<std::output_iterator_tag, void, void, void, void> {
    230 private:
    231   typedef BasicExpression Container;
    232   Container *BE;
    233 
    234 public:
    235   explicit op_inserter(BasicExpression &E) : BE(&E) {}
    236   explicit op_inserter(BasicExpression *E) : BE(E) {}
    237 
    238   op_inserter &operator=(Value *val) {
    239     BE->op_push_back(val);
    240     return *this;
    241   }
    242   op_inserter &operator*() { return *this; }
    243   op_inserter &operator++() { return *this; }
    244   op_inserter &operator++(int) { return *this; }
    245 };
    246 
    247 class MemoryExpression : public BasicExpression {
    248 private:
    249   const MemoryAccess *MemoryLeader;
    250 
    251 public:
    252   MemoryExpression(unsigned NumOperands, enum ExpressionType EType,
    253                    const MemoryAccess *MemoryLeader)
    254       : BasicExpression(NumOperands, EType), MemoryLeader(MemoryLeader){};
    255 
    256   MemoryExpression() = delete;
    257   MemoryExpression(const MemoryExpression &) = delete;
    258   MemoryExpression &operator=(const MemoryExpression &) = delete;
    259   static bool classof(const Expression *EB) {
    260     return EB->getExpressionType() > ET_MemoryStart &&
    261            EB->getExpressionType() < ET_MemoryEnd;
    262   }
    263   hash_code getHashValue() const override {
    264     return hash_combine(this->BasicExpression::getHashValue(), MemoryLeader);
    265   }
    266 
    267   bool equals(const Expression &Other) const override {
    268     if (!this->BasicExpression::equals(Other))
    269       return false;
    270     const MemoryExpression &OtherMCE = cast<MemoryExpression>(Other);
    271 
    272     return MemoryLeader == OtherMCE.MemoryLeader;
    273   }
    274 
    275   const MemoryAccess *getMemoryLeader() const { return MemoryLeader; }
    276   void setMemoryLeader(const MemoryAccess *ML) { MemoryLeader = ML; }
    277 };
    278 
    279 class CallExpression final : public MemoryExpression {
    280 private:
    281   CallInst *Call;
    282 
    283 public:
    284   CallExpression(unsigned NumOperands, CallInst *C,
    285                  const MemoryAccess *MemoryLeader)
    286       : MemoryExpression(NumOperands, ET_Call, MemoryLeader), Call(C) {}
    287   CallExpression() = delete;
    288   CallExpression(const CallExpression &) = delete;
    289   CallExpression &operator=(const CallExpression &) = delete;
    290   ~CallExpression() override;
    291 
    292   static bool classof(const Expression *EB) {
    293     return EB->getExpressionType() == ET_Call;
    294   }
    295 
    296   //
    297   // Debugging support
    298   //
    299   void printInternal(raw_ostream &OS, bool PrintEType) const override {
    300     if (PrintEType)
    301       OS << "ExpressionTypeCall, ";
    302     this->BasicExpression::printInternal(OS, false);
    303     OS << " represents call at ";
    304     Call->printAsOperand(OS);
    305   }
    306 };
    307 
    308 class LoadExpression final : public MemoryExpression {
    309 private:
    310   LoadInst *Load;
    311   unsigned Alignment;
    312 
    313 public:
    314   LoadExpression(unsigned NumOperands, LoadInst *L,
    315                  const MemoryAccess *MemoryLeader)
    316       : LoadExpression(ET_Load, NumOperands, L, MemoryLeader) {}
    317   LoadExpression(enum ExpressionType EType, unsigned NumOperands, LoadInst *L,
    318                  const MemoryAccess *MemoryLeader)
    319       : MemoryExpression(NumOperands, EType, MemoryLeader), Load(L) {
    320     Alignment = L ? L->getAlignment() : 0;
    321   }
    322   LoadExpression() = delete;
    323   LoadExpression(const LoadExpression &) = delete;
    324   LoadExpression &operator=(const LoadExpression &) = delete;
    325   ~LoadExpression() override;
    326 
    327   static bool classof(const Expression *EB) {
    328     return EB->getExpressionType() == ET_Load;
    329   }
    330 
    331   LoadInst *getLoadInst() const { return Load; }
    332   void setLoadInst(LoadInst *L) { Load = L; }
    333 
    334   unsigned getAlignment() const { return Alignment; }
    335   void setAlignment(unsigned Align) { Alignment = Align; }
    336 
    337   bool equals(const Expression &Other) const override;
    338 
    339   //
    340   // Debugging support
    341   //
    342   void printInternal(raw_ostream &OS, bool PrintEType) const override {
    343     if (PrintEType)
    344       OS << "ExpressionTypeLoad, ";
    345     this->BasicExpression::printInternal(OS, false);
    346     OS << " represents Load at ";
    347     Load->printAsOperand(OS);
    348     OS << " with MemoryLeader " << *getMemoryLeader();
    349   }
    350 };
    351 
    352 class StoreExpression final : public MemoryExpression {
    353 private:
    354   StoreInst *Store;
    355   Value *StoredValue;
    356 
    357 public:
    358   StoreExpression(unsigned NumOperands, StoreInst *S, Value *StoredValue,
    359                   const MemoryAccess *MemoryLeader)
    360       : MemoryExpression(NumOperands, ET_Store, MemoryLeader), Store(S),
    361         StoredValue(StoredValue) {}
    362   StoreExpression() = delete;
    363   StoreExpression(const StoreExpression &) = delete;
    364   StoreExpression &operator=(const StoreExpression &) = delete;
    365   ~StoreExpression() override;
    366 
    367   static bool classof(const Expression *EB) {
    368     return EB->getExpressionType() == ET_Store;
    369   }
    370 
    371   StoreInst *getStoreInst() const { return Store; }
    372   Value *getStoredValue() const { return StoredValue; }
    373 
    374   bool equals(const Expression &Other) const override;
    375 
    376   // Debugging support
    377   //
    378   void printInternal(raw_ostream &OS, bool PrintEType) const override {
    379     if (PrintEType)
    380       OS << "ExpressionTypeStore, ";
    381     this->BasicExpression::printInternal(OS, false);
    382     OS << " represents Store  " << *Store;
    383     OS << " with MemoryLeader " << *getMemoryLeader();
    384   }
    385 };
    386 
    387 class AggregateValueExpression final : public BasicExpression {
    388 private:
    389   unsigned MaxIntOperands;
    390   unsigned NumIntOperands;
    391   unsigned *IntOperands;
    392 
    393 public:
    394   AggregateValueExpression(unsigned NumOperands, unsigned NumIntOperands)
    395       : BasicExpression(NumOperands, ET_AggregateValue),
    396         MaxIntOperands(NumIntOperands), NumIntOperands(0),
    397         IntOperands(nullptr) {}
    398   AggregateValueExpression() = delete;
    399   AggregateValueExpression(const AggregateValueExpression &) = delete;
    400   AggregateValueExpression &
    401   operator=(const AggregateValueExpression &) = delete;
    402   ~AggregateValueExpression() override;
    403 
    404   static bool classof(const Expression *EB) {
    405     return EB->getExpressionType() == ET_AggregateValue;
    406   }
    407 
    408   typedef unsigned *int_arg_iterator;
    409   typedef const unsigned *const_int_arg_iterator;
    410 
    411   int_arg_iterator int_op_begin() { return IntOperands; }
    412   int_arg_iterator int_op_end() { return IntOperands + NumIntOperands; }
    413   const_int_arg_iterator int_op_begin() const { return IntOperands; }
    414   const_int_arg_iterator int_op_end() const {
    415     return IntOperands + NumIntOperands;
    416   }
    417   unsigned int_op_size() const { return NumIntOperands; }
    418   bool int_op_empty() const { return NumIntOperands == 0; }
    419   void int_op_push_back(unsigned IntOperand) {
    420     assert(NumIntOperands < MaxIntOperands &&
    421            "Tried to add too many int operands");
    422     assert(IntOperands && "Operands not allocated before pushing");
    423     IntOperands[NumIntOperands++] = IntOperand;
    424   }
    425 
    426   virtual void allocateIntOperands(BumpPtrAllocator &Allocator) {
    427     assert(!IntOperands && "Operands already allocated");
    428     IntOperands = Allocator.Allocate<unsigned>(MaxIntOperands);
    429   }
    430 
    431   bool equals(const Expression &Other) const override {
    432     if (!this->BasicExpression::equals(Other))
    433       return false;
    434     const AggregateValueExpression &OE = cast<AggregateValueExpression>(Other);
    435     return NumIntOperands == OE.NumIntOperands &&
    436            std::equal(int_op_begin(), int_op_end(), OE.int_op_begin());
    437   }
    438 
    439   hash_code getHashValue() const override {
    440     return hash_combine(this->BasicExpression::getHashValue(),
    441                         hash_combine_range(int_op_begin(), int_op_end()));
    442   }
    443 
    444   //
    445   // Debugging support
    446   //
    447   void printInternal(raw_ostream &OS, bool PrintEType) const override {
    448     if (PrintEType)
    449       OS << "ExpressionTypeAggregateValue, ";
    450     this->BasicExpression::printInternal(OS, false);
    451     OS << ", intoperands = {";
    452     for (unsigned i = 0, e = int_op_size(); i != e; ++i) {
    453       OS << "[" << i << "] = " << IntOperands[i] << "  ";
    454     }
    455     OS << "}";
    456   }
    457 };
    458 
    459 class int_op_inserter
    460     : public std::iterator<std::output_iterator_tag, void, void, void, void> {
    461 private:
    462   typedef AggregateValueExpression Container;
    463   Container *AVE;
    464 
    465 public:
    466   explicit int_op_inserter(AggregateValueExpression &E) : AVE(&E) {}
    467   explicit int_op_inserter(AggregateValueExpression *E) : AVE(E) {}
    468 
    469   int_op_inserter &operator=(unsigned int val) {
    470     AVE->int_op_push_back(val);
    471     return *this;
    472   }
    473   int_op_inserter &operator*() { return *this; }
    474   int_op_inserter &operator++() { return *this; }
    475   int_op_inserter &operator++(int) { return *this; }
    476 };
    477 
    478 class PHIExpression final : public BasicExpression {
    479 private:
    480   BasicBlock *BB;
    481 
    482 public:
    483   PHIExpression(unsigned NumOperands, BasicBlock *B)
    484       : BasicExpression(NumOperands, ET_Phi), BB(B) {}
    485   PHIExpression() = delete;
    486   PHIExpression(const PHIExpression &) = delete;
    487   PHIExpression &operator=(const PHIExpression &) = delete;
    488   ~PHIExpression() override;
    489 
    490   static bool classof(const Expression *EB) {
    491     return EB->getExpressionType() == ET_Phi;
    492   }
    493 
    494   bool equals(const Expression &Other) const override {
    495     if (!this->BasicExpression::equals(Other))
    496       return false;
    497     const PHIExpression &OE = cast<PHIExpression>(Other);
    498     return BB == OE.BB;
    499   }
    500 
    501   hash_code getHashValue() const override {
    502     return hash_combine(this->BasicExpression::getHashValue(), BB);
    503   }
    504 
    505   //
    506   // Debugging support
    507   //
    508   void printInternal(raw_ostream &OS, bool PrintEType) const override {
    509     if (PrintEType)
    510       OS << "ExpressionTypePhi, ";
    511     this->BasicExpression::printInternal(OS, false);
    512     OS << "bb = " << BB;
    513   }
    514 };
    515 
    516 class VariableExpression final : public Expression {
    517 private:
    518   Value *VariableValue;
    519 
    520 public:
    521   VariableExpression(Value *V) : Expression(ET_Variable), VariableValue(V) {}
    522   VariableExpression() = delete;
    523   VariableExpression(const VariableExpression &) = delete;
    524   VariableExpression &operator=(const VariableExpression &) = delete;
    525 
    526   static bool classof(const Expression *EB) {
    527     return EB->getExpressionType() == ET_Variable;
    528   }
    529 
    530   Value *getVariableValue() const { return VariableValue; }
    531   void setVariableValue(Value *V) { VariableValue = V; }
    532 
    533   bool equals(const Expression &Other) const override {
    534     const VariableExpression &OC = cast<VariableExpression>(Other);
    535     return VariableValue == OC.VariableValue;
    536   }
    537 
    538   hash_code getHashValue() const override {
    539     return hash_combine(this->Expression::getHashValue(),
    540                         VariableValue->getType(), VariableValue);
    541   }
    542 
    543   //
    544   // Debugging support
    545   //
    546   void printInternal(raw_ostream &OS, bool PrintEType) const override {
    547     if (PrintEType)
    548       OS << "ExpressionTypeVariable, ";
    549     this->Expression::printInternal(OS, false);
    550     OS << " variable = " << *VariableValue;
    551   }
    552 };
    553 
    554 class ConstantExpression final : public Expression {
    555 private:
    556   Constant *ConstantValue = nullptr;
    557 
    558 public:
    559   ConstantExpression() : Expression(ET_Constant) {}
    560   ConstantExpression(Constant *constantValue)
    561       : Expression(ET_Constant), ConstantValue(constantValue) {}
    562   ConstantExpression(const ConstantExpression &) = delete;
    563   ConstantExpression &operator=(const ConstantExpression &) = delete;
    564 
    565   static bool classof(const Expression *EB) {
    566     return EB->getExpressionType() == ET_Constant;
    567   }
    568 
    569   Constant *getConstantValue() const { return ConstantValue; }
    570   void setConstantValue(Constant *V) { ConstantValue = V; }
    571 
    572   bool equals(const Expression &Other) const override {
    573     const ConstantExpression &OC = cast<ConstantExpression>(Other);
    574     return ConstantValue == OC.ConstantValue;
    575   }
    576 
    577   hash_code getHashValue() const override {
    578     return hash_combine(this->Expression::getHashValue(),
    579                         ConstantValue->getType(), ConstantValue);
    580   }
    581 
    582   //
    583   // Debugging support
    584   //
    585   void printInternal(raw_ostream &OS, bool PrintEType) const override {
    586     if (PrintEType)
    587       OS << "ExpressionTypeConstant, ";
    588     this->Expression::printInternal(OS, false);
    589     OS << " constant = " << *ConstantValue;
    590   }
    591 };
    592 
    593 class UnknownExpression final : public Expression {
    594 private:
    595   Instruction *Inst;
    596 
    597 public:
    598   UnknownExpression(Instruction *I) : Expression(ET_Unknown), Inst(I) {}
    599   UnknownExpression() = delete;
    600   UnknownExpression(const UnknownExpression &) = delete;
    601   UnknownExpression &operator=(const UnknownExpression &) = delete;
    602 
    603   static bool classof(const Expression *EB) {
    604     return EB->getExpressionType() == ET_Unknown;
    605   }
    606 
    607   Instruction *getInstruction() const { return Inst; }
    608   void setInstruction(Instruction *I) { Inst = I; }
    609 
    610   bool equals(const Expression &Other) const override {
    611     const auto &OU = cast<UnknownExpression>(Other);
    612     return Inst == OU.Inst;
    613   }
    614 
    615   hash_code getHashValue() const override {
    616     return hash_combine(this->Expression::getHashValue(), Inst);
    617   }
    618 
    619   //
    620   // Debugging support
    621   //
    622   void printInternal(raw_ostream &OS, bool PrintEType) const override {
    623     if (PrintEType)
    624       OS << "ExpressionTypeUnknown, ";
    625     this->Expression::printInternal(OS, false);
    626     OS << " inst = " << *Inst;
    627   }
    628 };
    629 
    630 } // end namespace GVNExpression
    631 
    632 } // end namespace llvm
    633 
    634 #endif // LLVM_TRANSFORMS_SCALAR_GVNEXPRESSION_H
    635