Home | History | Annotate | Download | only in Analysis
      1 //===- llvm/Analysis/ScalarEvolutionExpressions.h - SCEV Exprs --*- C++ -*-===//
      2 //
      3 //                     The LLVM Compiler Infrastructure
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 //===----------------------------------------------------------------------===//
      9 //
     10 // This file defines the classes used to represent and build scalar expressions.
     11 //
     12 //===----------------------------------------------------------------------===//
     13 
     14 #ifndef LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
     15 #define LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
     16 
     17 #include "llvm/ADT/SmallPtrSet.h"
     18 #include "llvm/ADT/iterator_range.h"
     19 #include "llvm/Analysis/ScalarEvolution.h"
     20 #include "llvm/Support/ErrorHandling.h"
     21 
     22 namespace llvm {
     23   class ConstantInt;
     24   class ConstantRange;
     25   class DominatorTree;
     26 
     27   enum SCEVTypes {
     28     // These should be ordered in terms of increasing complexity to make the
     29     // folders simpler.
     30     scConstant, scTruncate, scZeroExtend, scSignExtend, scAddExpr, scMulExpr,
     31     scUDivExpr, scAddRecExpr, scUMaxExpr, scSMaxExpr,
     32     scUnknown, scCouldNotCompute
     33   };
     34 
     35   /// This class represents a constant integer value.
     36   class SCEVConstant : public SCEV {
     37     friend class ScalarEvolution;
     38 
     39     ConstantInt *V;
     40     SCEVConstant(const FoldingSetNodeIDRef ID, ConstantInt *v) :
     41       SCEV(ID, scConstant), V(v) {}
     42   public:
     43     ConstantInt *getValue() const { return V; }
     44     const APInt &getAPInt() const { return getValue()->getValue(); }
     45 
     46     Type *getType() const { return V->getType(); }
     47 
     48     /// Methods for support type inquiry through isa, cast, and dyn_cast:
     49     static inline bool classof(const SCEV *S) {
     50       return S->getSCEVType() == scConstant;
     51     }
     52   };
     53 
     54   /// This is the base class for unary cast operator classes.
     55   class SCEVCastExpr : public SCEV {
     56   protected:
     57     const SCEV *Op;
     58     Type *Ty;
     59 
     60     SCEVCastExpr(const FoldingSetNodeIDRef ID,
     61                  unsigned SCEVTy, const SCEV *op, Type *ty);
     62 
     63   public:
     64     const SCEV *getOperand() const { return Op; }
     65     Type *getType() const { return Ty; }
     66 
     67     /// Methods for support type inquiry through isa, cast, and dyn_cast:
     68     static inline bool classof(const SCEV *S) {
     69       return S->getSCEVType() == scTruncate ||
     70              S->getSCEVType() == scZeroExtend ||
     71              S->getSCEVType() == scSignExtend;
     72     }
     73   };
     74 
     75   /// This class represents a truncation of an integer value to a
     76   /// smaller integer value.
     77   class SCEVTruncateExpr : public SCEVCastExpr {
     78     friend class ScalarEvolution;
     79 
     80     SCEVTruncateExpr(const FoldingSetNodeIDRef ID,
     81                      const SCEV *op, Type *ty);
     82 
     83   public:
     84     /// Methods for support type inquiry through isa, cast, and dyn_cast:
     85     static inline bool classof(const SCEV *S) {
     86       return S->getSCEVType() == scTruncate;
     87     }
     88   };
     89 
     90   /// This class represents a zero extension of a small integer value
     91   /// to a larger integer value.
     92   class SCEVZeroExtendExpr : public SCEVCastExpr {
     93     friend class ScalarEvolution;
     94 
     95     SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
     96                        const SCEV *op, Type *ty);
     97 
     98   public:
     99     /// Methods for support type inquiry through isa, cast, and dyn_cast:
    100     static inline bool classof(const SCEV *S) {
    101       return S->getSCEVType() == scZeroExtend;
    102     }
    103   };
    104 
    105   /// This class represents a sign extension of a small integer value
    106   /// to a larger integer value.
    107   class SCEVSignExtendExpr : public SCEVCastExpr {
    108     friend class ScalarEvolution;
    109 
    110     SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
    111                        const SCEV *op, Type *ty);
    112 
    113   public:
    114     /// Methods for support type inquiry through isa, cast, and dyn_cast:
    115     static inline bool classof(const SCEV *S) {
    116       return S->getSCEVType() == scSignExtend;
    117     }
    118   };
    119 
    120 
    121   /// This node is a base class providing common functionality for
    122   /// n'ary operators.
    123   class SCEVNAryExpr : public SCEV {
    124   protected:
    125     // Since SCEVs are immutable, ScalarEvolution allocates operand
    126     // arrays with its SCEVAllocator, so this class just needs a simple
    127     // pointer rather than a more elaborate vector-like data structure.
    128     // This also avoids the need for a non-trivial destructor.
    129     const SCEV *const *Operands;
    130     size_t NumOperands;
    131 
    132     SCEVNAryExpr(const FoldingSetNodeIDRef ID,
    133                  enum SCEVTypes T, const SCEV *const *O, size_t N)
    134       : SCEV(ID, T), Operands(O), NumOperands(N) {}
    135 
    136   public:
    137     size_t getNumOperands() const { return NumOperands; }
    138     const SCEV *getOperand(unsigned i) const {
    139       assert(i < NumOperands && "Operand index out of range!");
    140       return Operands[i];
    141     }
    142 
    143     typedef const SCEV *const *op_iterator;
    144     typedef iterator_range<op_iterator> op_range;
    145     op_iterator op_begin() const { return Operands; }
    146     op_iterator op_end() const { return Operands + NumOperands; }
    147     op_range operands() const {
    148       return make_range(op_begin(), op_end());
    149     }
    150 
    151     Type *getType() const { return getOperand(0)->getType(); }
    152 
    153     NoWrapFlags getNoWrapFlags(NoWrapFlags Mask = NoWrapMask) const {
    154       return (NoWrapFlags)(SubclassData & Mask);
    155     }
    156 
    157     bool hasNoUnsignedWrap() const {
    158       return getNoWrapFlags(FlagNUW) != FlagAnyWrap;
    159     }
    160 
    161     bool hasNoSignedWrap() const {
    162       return getNoWrapFlags(FlagNSW) != FlagAnyWrap;
    163     }
    164 
    165     bool hasNoSelfWrap() const {
    166       return getNoWrapFlags(FlagNW) != FlagAnyWrap;
    167     }
    168 
    169     /// Methods for support type inquiry through isa, cast, and dyn_cast:
    170     static inline bool classof(const SCEV *S) {
    171       return S->getSCEVType() == scAddExpr ||
    172              S->getSCEVType() == scMulExpr ||
    173              S->getSCEVType() == scSMaxExpr ||
    174              S->getSCEVType() == scUMaxExpr ||
    175              S->getSCEVType() == scAddRecExpr;
    176     }
    177   };
    178 
    179   /// This node is the base class for n'ary commutative operators.
    180   class SCEVCommutativeExpr : public SCEVNAryExpr {
    181   protected:
    182     SCEVCommutativeExpr(const FoldingSetNodeIDRef ID,
    183                         enum SCEVTypes T, const SCEV *const *O, size_t N)
    184       : SCEVNAryExpr(ID, T, O, N) {}
    185 
    186   public:
    187     /// Methods for support type inquiry through isa, cast, and dyn_cast:
    188     static inline bool classof(const SCEV *S) {
    189       return S->getSCEVType() == scAddExpr ||
    190              S->getSCEVType() == scMulExpr ||
    191              S->getSCEVType() == scSMaxExpr ||
    192              S->getSCEVType() == scUMaxExpr;
    193     }
    194 
    195     /// Set flags for a non-recurrence without clearing previously set flags.
    196     void setNoWrapFlags(NoWrapFlags Flags) {
    197       SubclassData |= Flags;
    198     }
    199   };
    200 
    201 
    202   /// This node represents an addition of some number of SCEVs.
    203   class SCEVAddExpr : public SCEVCommutativeExpr {
    204     friend class ScalarEvolution;
    205 
    206     SCEVAddExpr(const FoldingSetNodeIDRef ID,
    207                 const SCEV *const *O, size_t N)
    208       : SCEVCommutativeExpr(ID, scAddExpr, O, N) {
    209     }
    210 
    211   public:
    212     Type *getType() const {
    213       // Use the type of the last operand, which is likely to be a pointer
    214       // type, if there is one. This doesn't usually matter, but it can help
    215       // reduce casts when the expressions are expanded.
    216       return getOperand(getNumOperands() - 1)->getType();
    217     }
    218 
    219     /// Methods for support type inquiry through isa, cast, and dyn_cast:
    220     static inline bool classof(const SCEV *S) {
    221       return S->getSCEVType() == scAddExpr;
    222     }
    223   };
    224 
    225 
    226   /// This node represents multiplication of some number of SCEVs.
    227   class SCEVMulExpr : public SCEVCommutativeExpr {
    228     friend class ScalarEvolution;
    229 
    230     SCEVMulExpr(const FoldingSetNodeIDRef ID,
    231                 const SCEV *const *O, size_t N)
    232       : SCEVCommutativeExpr(ID, scMulExpr, O, N) {
    233     }
    234 
    235   public:
    236     /// Methods for support type inquiry through isa, cast, and dyn_cast:
    237     static inline bool classof(const SCEV *S) {
    238       return S->getSCEVType() == scMulExpr;
    239     }
    240   };
    241 
    242 
    243   /// This class represents a binary unsigned division operation.
    244   class SCEVUDivExpr : public SCEV {
    245     friend class ScalarEvolution;
    246 
    247     const SCEV *LHS;
    248     const SCEV *RHS;
    249     SCEVUDivExpr(const FoldingSetNodeIDRef ID, const SCEV *lhs, const SCEV *rhs)
    250       : SCEV(ID, scUDivExpr), LHS(lhs), RHS(rhs) {}
    251 
    252   public:
    253     const SCEV *getLHS() const { return LHS; }
    254     const SCEV *getRHS() const { return RHS; }
    255 
    256     Type *getType() const {
    257       // In most cases the types of LHS and RHS will be the same, but in some
    258       // crazy cases one or the other may be a pointer. ScalarEvolution doesn't
    259       // depend on the type for correctness, but handling types carefully can
    260       // avoid extra casts in the SCEVExpander. The LHS is more likely to be
    261       // a pointer type than the RHS, so use the RHS' type here.
    262       return getRHS()->getType();
    263     }
    264 
    265     /// Methods for support type inquiry through isa, cast, and dyn_cast:
    266     static inline bool classof(const SCEV *S) {
    267       return S->getSCEVType() == scUDivExpr;
    268     }
    269   };
    270 
    271 
    272   /// This node represents a polynomial recurrence on the trip count
    273   /// of the specified loop.  This is the primary focus of the
    274   /// ScalarEvolution framework; all the other SCEV subclasses are
    275   /// mostly just supporting infrastructure to allow SCEVAddRecExpr
    276   /// expressions to be created and analyzed.
    277   ///
    278   /// All operands of an AddRec are required to be loop invariant.
    279   ///
    280   class SCEVAddRecExpr : public SCEVNAryExpr {
    281     friend class ScalarEvolution;
    282 
    283     const Loop *L;
    284 
    285     SCEVAddRecExpr(const FoldingSetNodeIDRef ID,
    286                    const SCEV *const *O, size_t N, const Loop *l)
    287       : SCEVNAryExpr(ID, scAddRecExpr, O, N), L(l) {}
    288 
    289   public:
    290     const SCEV *getStart() const { return Operands[0]; }
    291     const Loop *getLoop() const { return L; }
    292 
    293     /// Constructs and returns the recurrence indicating how much this
    294     /// expression steps by.  If this is a polynomial of degree N, it
    295     /// returns a chrec of degree N-1.  We cannot determine whether
    296     /// the step recurrence has self-wraparound.
    297     const SCEV *getStepRecurrence(ScalarEvolution &SE) const {
    298       if (isAffine()) return getOperand(1);
    299       return SE.getAddRecExpr(SmallVector<const SCEV *, 3>(op_begin()+1,
    300                                                            op_end()),
    301                               getLoop(), FlagAnyWrap);
    302     }
    303 
    304     /// Return true if this represents an expression A + B*x where A
    305     /// and B are loop invariant values.
    306     bool isAffine() const {
    307       // We know that the start value is invariant.  This expression is thus
    308       // affine iff the step is also invariant.
    309       return getNumOperands() == 2;
    310     }
    311 
    312     /// Return true if this represents an expression A + B*x + C*x^2
    313     /// where A, B and C are loop invariant values.  This corresponds
    314     /// to an addrec of the form {L,+,M,+,N}
    315     bool isQuadratic() const {
    316       return getNumOperands() == 3;
    317     }
    318 
    319     /// Set flags for a recurrence without clearing any previously set flags.
    320     /// For AddRec, either NUW or NSW implies NW. Keep track of this fact here
    321     /// to make it easier to propagate flags.
    322     void setNoWrapFlags(NoWrapFlags Flags) {
    323       if (Flags & (FlagNUW | FlagNSW))
    324         Flags = ScalarEvolution::setFlags(Flags, FlagNW);
    325       SubclassData |= Flags;
    326     }
    327 
    328     /// Return the value of this chain of recurrences at the specified
    329     /// iteration number.
    330     const SCEV *evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const;
    331 
    332     /// Return the number of iterations of this loop that produce
    333     /// values in the specified constant range.  Another way of
    334     /// looking at this is that it returns the first iteration number
    335     /// where the value is not in the condition, thus computing the
    336     /// exit count.  If the iteration count can't be computed, an
    337     /// instance of SCEVCouldNotCompute is returned.
    338     const SCEV *getNumIterationsInRange(const ConstantRange &Range,
    339                                         ScalarEvolution &SE) const;
    340 
    341     /// Return an expression representing the value of this expression
    342     /// one iteration of the loop ahead.
    343     const SCEVAddRecExpr *getPostIncExpr(ScalarEvolution &SE) const {
    344       return cast<SCEVAddRecExpr>(SE.getAddExpr(this, getStepRecurrence(SE)));
    345     }
    346 
    347     /// Methods for support type inquiry through isa, cast, and dyn_cast:
    348     static inline bool classof(const SCEV *S) {
    349       return S->getSCEVType() == scAddRecExpr;
    350     }
    351   };
    352 
    353   /// This class represents a signed maximum selection.
    354   class SCEVSMaxExpr : public SCEVCommutativeExpr {
    355     friend class ScalarEvolution;
    356 
    357     SCEVSMaxExpr(const FoldingSetNodeIDRef ID,
    358                  const SCEV *const *O, size_t N)
    359       : SCEVCommutativeExpr(ID, scSMaxExpr, O, N) {
    360       // Max never overflows.
    361       setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW));
    362     }
    363 
    364   public:
    365     /// Methods for support type inquiry through isa, cast, and dyn_cast:
    366     static inline bool classof(const SCEV *S) {
    367       return S->getSCEVType() == scSMaxExpr;
    368     }
    369   };
    370 
    371 
    372   /// This class represents an unsigned maximum selection.
    373   class SCEVUMaxExpr : public SCEVCommutativeExpr {
    374     friend class ScalarEvolution;
    375 
    376     SCEVUMaxExpr(const FoldingSetNodeIDRef ID,
    377                  const SCEV *const *O, size_t N)
    378       : SCEVCommutativeExpr(ID, scUMaxExpr, O, N) {
    379       // Max never overflows.
    380       setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW));
    381     }
    382 
    383   public:
    384     /// Methods for support type inquiry through isa, cast, and dyn_cast:
    385     static inline bool classof(const SCEV *S) {
    386       return S->getSCEVType() == scUMaxExpr;
    387     }
    388   };
    389 
    390   /// This means that we are dealing with an entirely unknown SCEV
    391   /// value, and only represent it as its LLVM Value.  This is the
    392   /// "bottom" value for the analysis.
    393   class SCEVUnknown final : public SCEV, private CallbackVH {
    394     friend class ScalarEvolution;
    395 
    396     // Implement CallbackVH.
    397     void deleted() override;
    398     void allUsesReplacedWith(Value *New) override;
    399 
    400     /// The parent ScalarEvolution value. This is used to update the
    401     /// parent's maps when the value associated with a SCEVUnknown is
    402     /// deleted or RAUW'd.
    403     ScalarEvolution *SE;
    404 
    405     /// The next pointer in the linked list of all SCEVUnknown
    406     /// instances owned by a ScalarEvolution.
    407     SCEVUnknown *Next;
    408 
    409     SCEVUnknown(const FoldingSetNodeIDRef ID, Value *V,
    410                 ScalarEvolution *se, SCEVUnknown *next) :
    411       SCEV(ID, scUnknown), CallbackVH(V), SE(se), Next(next) {}
    412 
    413   public:
    414     Value *getValue() const { return getValPtr(); }
    415 
    416     /// @{
    417     /// Test whether this is a special constant representing a type
    418     /// size, alignment, or field offset in a target-independent
    419     /// manner, and hasn't happened to have been folded with other
    420     /// operations into something unrecognizable. This is mainly only
    421     /// useful for pretty-printing and other situations where it isn't
    422     /// absolutely required for these to succeed.
    423     bool isSizeOf(Type *&AllocTy) const;
    424     bool isAlignOf(Type *&AllocTy) const;
    425     bool isOffsetOf(Type *&STy, Constant *&FieldNo) const;
    426     /// @}
    427 
    428     Type *getType() const { return getValPtr()->getType(); }
    429 
    430     /// Methods for support type inquiry through isa, cast, and dyn_cast:
    431     static inline bool classof(const SCEV *S) {
    432       return S->getSCEVType() == scUnknown;
    433     }
    434   };
    435 
    436   /// This class defines a simple visitor class that may be used for
    437   /// various SCEV analysis purposes.
    438   template<typename SC, typename RetVal=void>
    439   struct SCEVVisitor {
    440     RetVal visit(const SCEV *S) {
    441       switch (S->getSCEVType()) {
    442       case scConstant:
    443         return ((SC*)this)->visitConstant((const SCEVConstant*)S);
    444       case scTruncate:
    445         return ((SC*)this)->visitTruncateExpr((const SCEVTruncateExpr*)S);
    446       case scZeroExtend:
    447         return ((SC*)this)->visitZeroExtendExpr((const SCEVZeroExtendExpr*)S);
    448       case scSignExtend:
    449         return ((SC*)this)->visitSignExtendExpr((const SCEVSignExtendExpr*)S);
    450       case scAddExpr:
    451         return ((SC*)this)->visitAddExpr((const SCEVAddExpr*)S);
    452       case scMulExpr:
    453         return ((SC*)this)->visitMulExpr((const SCEVMulExpr*)S);
    454       case scUDivExpr:
    455         return ((SC*)this)->visitUDivExpr((const SCEVUDivExpr*)S);
    456       case scAddRecExpr:
    457         return ((SC*)this)->visitAddRecExpr((const SCEVAddRecExpr*)S);
    458       case scSMaxExpr:
    459         return ((SC*)this)->visitSMaxExpr((const SCEVSMaxExpr*)S);
    460       case scUMaxExpr:
    461         return ((SC*)this)->visitUMaxExpr((const SCEVUMaxExpr*)S);
    462       case scUnknown:
    463         return ((SC*)this)->visitUnknown((const SCEVUnknown*)S);
    464       case scCouldNotCompute:
    465         return ((SC*)this)->visitCouldNotCompute((const SCEVCouldNotCompute*)S);
    466       default:
    467         llvm_unreachable("Unknown SCEV type!");
    468       }
    469     }
    470 
    471     RetVal visitCouldNotCompute(const SCEVCouldNotCompute *S) {
    472       llvm_unreachable("Invalid use of SCEVCouldNotCompute!");
    473     }
    474   };
    475 
    476   /// Visit all nodes in the expression tree using worklist traversal.
    477   ///
    478   /// Visitor implements:
    479   ///   // return true to follow this node.
    480   ///   bool follow(const SCEV *S);
    481   ///   // return true to terminate the search.
    482   ///   bool isDone();
    483   template<typename SV>
    484   class SCEVTraversal {
    485     SV &Visitor;
    486     SmallVector<const SCEV *, 8> Worklist;
    487     SmallPtrSet<const SCEV *, 8> Visited;
    488 
    489     void push(const SCEV *S) {
    490       if (Visited.insert(S).second && Visitor.follow(S))
    491         Worklist.push_back(S);
    492     }
    493   public:
    494     SCEVTraversal(SV& V): Visitor(V) {}
    495 
    496     void visitAll(const SCEV *Root) {
    497       push(Root);
    498       while (!Worklist.empty() && !Visitor.isDone()) {
    499         const SCEV *S = Worklist.pop_back_val();
    500 
    501         switch (S->getSCEVType()) {
    502         case scConstant:
    503         case scUnknown:
    504           break;
    505         case scTruncate:
    506         case scZeroExtend:
    507         case scSignExtend:
    508           push(cast<SCEVCastExpr>(S)->getOperand());
    509           break;
    510         case scAddExpr:
    511         case scMulExpr:
    512         case scSMaxExpr:
    513         case scUMaxExpr:
    514         case scAddRecExpr:
    515           for (const auto *Op : cast<SCEVNAryExpr>(S)->operands())
    516             push(Op);
    517           break;
    518         case scUDivExpr: {
    519           const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
    520           push(UDiv->getLHS());
    521           push(UDiv->getRHS());
    522           break;
    523         }
    524         case scCouldNotCompute:
    525           llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
    526         default:
    527           llvm_unreachable("Unknown SCEV kind!");
    528         }
    529       }
    530     }
    531   };
    532 
    533   /// Use SCEVTraversal to visit all nodes in the given expression tree.
    534   template<typename SV>
    535   void visitAll(const SCEV *Root, SV& Visitor) {
    536     SCEVTraversal<SV> T(Visitor);
    537     T.visitAll(Root);
    538   }
    539 
    540   /// Return true if any node in \p Root satisfies the predicate \p Pred.
    541   template <typename PredTy>
    542   bool SCEVExprContains(const SCEV *Root, PredTy Pred) {
    543     struct FindClosure {
    544       bool Found = false;
    545       PredTy Pred;
    546 
    547       FindClosure(PredTy Pred) : Pred(Pred) {}
    548 
    549       bool follow(const SCEV *S) {
    550         if (!Pred(S))
    551           return true;
    552 
    553         Found = true;
    554         return false;
    555       }
    556 
    557       bool isDone() const { return Found; }
    558     };
    559 
    560     FindClosure FC(Pred);
    561     visitAll(Root, FC);
    562     return FC.Found;
    563   }
    564 
    565   /// This visitor recursively visits a SCEV expression and re-writes it.
    566   /// The result from each visit is cached, so it will return the same
    567   /// SCEV for the same input.
    568   template<typename SC>
    569   class SCEVRewriteVisitor : public SCEVVisitor<SC, const SCEV *> {
    570   protected:
    571     ScalarEvolution &SE;
    572     // Memoize the result of each visit so that we only compute once for
    573     // the same input SCEV. This is to avoid redundant computations when
    574     // a SCEV is referenced by multiple SCEVs. Without memoization, this
    575     // visit algorithm would have exponential time complexity in the worst
    576     // case, causing the compiler to hang on certain tests.
    577     DenseMap<const SCEV *, const SCEV *> RewriteResults;
    578 
    579   public:
    580     SCEVRewriteVisitor(ScalarEvolution &SE) : SE(SE) {}
    581 
    582     const SCEV *visit(const SCEV *S) {
    583       auto It = RewriteResults.find(S);
    584       if (It != RewriteResults.end())
    585         return It->second;
    586       auto* Visited = SCEVVisitor<SC, const SCEV *>::visit(S);
    587       auto Result = RewriteResults.try_emplace(S, Visited);
    588       assert(Result.second && "Should insert a new entry");
    589       return Result.first->second;
    590     }
    591 
    592     const SCEV *visitConstant(const SCEVConstant *Constant) {
    593       return Constant;
    594     }
    595 
    596     const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) {
    597       const SCEV *Operand = ((SC*)this)->visit(Expr->getOperand());
    598       return Operand == Expr->getOperand()
    599                  ? Expr
    600                  : SE.getTruncateExpr(Operand, Expr->getType());
    601     }
    602 
    603     const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
    604       const SCEV *Operand = ((SC*)this)->visit(Expr->getOperand());
    605       return Operand == Expr->getOperand()
    606                  ? Expr
    607                  : SE.getZeroExtendExpr(Operand, Expr->getType());
    608     }
    609 
    610     const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
    611       const SCEV *Operand = ((SC*)this)->visit(Expr->getOperand());
    612       return Operand == Expr->getOperand()
    613                  ? Expr
    614                  : SE.getSignExtendExpr(Operand, Expr->getType());
    615     }
    616 
    617     const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
    618       SmallVector<const SCEV *, 2> Operands;
    619       bool Changed = false;
    620       for (auto *Op : Expr->operands()) {
    621         Operands.push_back(((SC*)this)->visit(Op));
    622         Changed |= Op != Operands.back();
    623       }
    624       return !Changed ? Expr : SE.getAddExpr(Operands);
    625     }
    626 
    627     const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
    628       SmallVector<const SCEV *, 2> Operands;
    629       bool Changed = false;
    630       for (auto *Op : Expr->operands()) {
    631         Operands.push_back(((SC*)this)->visit(Op));
    632         Changed |= Op != Operands.back();
    633       }
    634       return !Changed ? Expr : SE.getMulExpr(Operands);
    635     }
    636 
    637     const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) {
    638       auto *LHS = ((SC *)this)->visit(Expr->getLHS());
    639       auto *RHS = ((SC *)this)->visit(Expr->getRHS());
    640       bool Changed = LHS != Expr->getLHS() || RHS != Expr->getRHS();
    641       return !Changed ? Expr : SE.getUDivExpr(LHS, RHS);
    642     }
    643 
    644     const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
    645       SmallVector<const SCEV *, 2> Operands;
    646       bool Changed = false;
    647       for (auto *Op : Expr->operands()) {
    648         Operands.push_back(((SC*)this)->visit(Op));
    649         Changed |= Op != Operands.back();
    650       }
    651       return !Changed ? Expr
    652                       : SE.getAddRecExpr(Operands, Expr->getLoop(),
    653                                          Expr->getNoWrapFlags());
    654     }
    655 
    656     const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr) {
    657       SmallVector<const SCEV *, 2> Operands;
    658       bool Changed = false;
    659       for (auto *Op : Expr->operands()) {
    660         Operands.push_back(((SC *)this)->visit(Op));
    661         Changed |= Op != Operands.back();
    662       }
    663       return !Changed ? Expr : SE.getSMaxExpr(Operands);
    664     }
    665 
    666     const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr) {
    667       SmallVector<const SCEV *, 2> Operands;
    668       bool Changed = false;
    669       for (auto *Op : Expr->operands()) {
    670         Operands.push_back(((SC*)this)->visit(Op));
    671         Changed |= Op != Operands.back();
    672       }
    673       return !Changed ? Expr : SE.getUMaxExpr(Operands);
    674     }
    675 
    676     const SCEV *visitUnknown(const SCEVUnknown *Expr) {
    677       return Expr;
    678     }
    679 
    680     const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
    681       return Expr;
    682     }
    683   };
    684 
    685   typedef DenseMap<const Value*, Value*> ValueToValueMap;
    686 
    687   /// The SCEVParameterRewriter takes a scalar evolution expression and updates
    688   /// the SCEVUnknown components following the Map (Value -> Value).
    689   class SCEVParameterRewriter : public SCEVRewriteVisitor<SCEVParameterRewriter> {
    690   public:
    691     static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
    692                                ValueToValueMap &Map,
    693                                bool InterpretConsts = false) {
    694       SCEVParameterRewriter Rewriter(SE, Map, InterpretConsts);
    695       return Rewriter.visit(Scev);
    696     }
    697 
    698     SCEVParameterRewriter(ScalarEvolution &SE, ValueToValueMap &M, bool C)
    699       : SCEVRewriteVisitor(SE), Map(M), InterpretConsts(C) {}
    700 
    701     const SCEV *visitUnknown(const SCEVUnknown *Expr) {
    702       Value *V = Expr->getValue();
    703       if (Map.count(V)) {
    704         Value *NV = Map[V];
    705         if (InterpretConsts && isa<ConstantInt>(NV))
    706           return SE.getConstant(cast<ConstantInt>(NV));
    707         return SE.getUnknown(NV);
    708       }
    709       return Expr;
    710     }
    711 
    712   private:
    713     ValueToValueMap &Map;
    714     bool InterpretConsts;
    715   };
    716 
    717   typedef DenseMap<const Loop*, const SCEV*> LoopToScevMapT;
    718 
    719   /// The SCEVLoopAddRecRewriter takes a scalar evolution expression and applies
    720   /// the Map (Loop -> SCEV) to all AddRecExprs.
    721   class SCEVLoopAddRecRewriter
    722       : public SCEVRewriteVisitor<SCEVLoopAddRecRewriter> {
    723   public:
    724     static const SCEV *rewrite(const SCEV *Scev, LoopToScevMapT &Map,
    725                                ScalarEvolution &SE) {
    726       SCEVLoopAddRecRewriter Rewriter(SE, Map);
    727       return Rewriter.visit(Scev);
    728     }
    729 
    730     SCEVLoopAddRecRewriter(ScalarEvolution &SE, LoopToScevMapT &M)
    731         : SCEVRewriteVisitor(SE), Map(M) {}
    732 
    733     const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
    734       SmallVector<const SCEV *, 2> Operands;
    735       for (int i = 0, e = Expr->getNumOperands(); i < e; ++i)
    736         Operands.push_back(visit(Expr->getOperand(i)));
    737 
    738       const Loop *L = Expr->getLoop();
    739       const SCEV *Res = SE.getAddRecExpr(Operands, L, Expr->getNoWrapFlags());
    740 
    741       if (0 == Map.count(L))
    742         return Res;
    743 
    744       const SCEVAddRecExpr *Rec = cast<SCEVAddRecExpr>(Res);
    745       return Rec->evaluateAtIteration(Map[L], SE);
    746     }
    747 
    748   private:
    749     LoopToScevMapT &Map;
    750   };
    751 }
    752 
    753 #endif
    754