1 //===- llvm/Analysis/ScalarEvolutionExpressions.h - SCEV Exprs --*- C++ -*-===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // This file defines the classes used to represent and build scalar expressions. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #ifndef LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H 15 #define LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H 16 17 #include "llvm/ADT/DenseMap.h" 18 #include "llvm/ADT/FoldingSet.h" 19 #include "llvm/ADT/SmallPtrSet.h" 20 #include "llvm/ADT/SmallVector.h" 21 #include "llvm/ADT/iterator_range.h" 22 #include "llvm/Analysis/ScalarEvolution.h" 23 #include "llvm/IR/Constants.h" 24 #include "llvm/IR/Value.h" 25 #include "llvm/IR/ValueHandle.h" 26 #include "llvm/Support/Casting.h" 27 #include "llvm/Support/ErrorHandling.h" 28 #include <cassert> 29 #include <cstddef> 30 31 namespace llvm { 32 33 class APInt; 34 class Constant; 35 class ConstantRange; 36 class Loop; 37 class Type; 38 39 enum SCEVTypes { 40 // These should be ordered in terms of increasing complexity to make the 41 // folders simpler. 42 scConstant, scTruncate, scZeroExtend, scSignExtend, scAddExpr, scMulExpr, 43 scUDivExpr, scAddRecExpr, scUMaxExpr, scSMaxExpr, 44 scUnknown, scCouldNotCompute 45 }; 46 47 /// This class represents a constant integer value. 48 class SCEVConstant : public SCEV { 49 friend class ScalarEvolution; 50 51 ConstantInt *V; 52 53 SCEVConstant(const FoldingSetNodeIDRef ID, ConstantInt *v) : 54 SCEV(ID, scConstant), V(v) {} 55 56 public: 57 ConstantInt *getValue() const { return V; } 58 const APInt &getAPInt() const { return getValue()->getValue(); } 59 60 Type *getType() const { return V->getType(); } 61 62 /// Methods for support type inquiry through isa, cast, and dyn_cast: 63 static bool classof(const SCEV *S) { 64 return S->getSCEVType() == scConstant; 65 } 66 }; 67 68 /// This is the base class for unary cast operator classes. 69 class SCEVCastExpr : public SCEV { 70 protected: 71 const SCEV *Op; 72 Type *Ty; 73 74 SCEVCastExpr(const FoldingSetNodeIDRef ID, 75 unsigned SCEVTy, const SCEV *op, Type *ty); 76 77 public: 78 const SCEV *getOperand() const { return Op; } 79 Type *getType() const { return Ty; } 80 81 /// Methods for support type inquiry through isa, cast, and dyn_cast: 82 static bool classof(const SCEV *S) { 83 return S->getSCEVType() == scTruncate || 84 S->getSCEVType() == scZeroExtend || 85 S->getSCEVType() == scSignExtend; 86 } 87 }; 88 89 /// This class represents a truncation of an integer value to a 90 /// smaller integer value. 91 class SCEVTruncateExpr : public SCEVCastExpr { 92 friend class ScalarEvolution; 93 94 SCEVTruncateExpr(const FoldingSetNodeIDRef ID, 95 const SCEV *op, Type *ty); 96 97 public: 98 /// Methods for support type inquiry through isa, cast, and dyn_cast: 99 static bool classof(const SCEV *S) { 100 return S->getSCEVType() == scTruncate; 101 } 102 }; 103 104 /// This class represents a zero extension of a small integer value 105 /// to a larger integer value. 106 class SCEVZeroExtendExpr : public SCEVCastExpr { 107 friend class ScalarEvolution; 108 109 SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID, 110 const SCEV *op, Type *ty); 111 112 public: 113 /// Methods for support type inquiry through isa, cast, and dyn_cast: 114 static bool classof(const SCEV *S) { 115 return S->getSCEVType() == scZeroExtend; 116 } 117 }; 118 119 /// This class represents a sign extension of a small integer value 120 /// to a larger integer value. 121 class SCEVSignExtendExpr : public SCEVCastExpr { 122 friend class ScalarEvolution; 123 124 SCEVSignExtendExpr(const FoldingSetNodeIDRef ID, 125 const SCEV *op, Type *ty); 126 127 public: 128 /// Methods for support type inquiry through isa, cast, and dyn_cast: 129 static bool classof(const SCEV *S) { 130 return S->getSCEVType() == scSignExtend; 131 } 132 }; 133 134 /// This node is a base class providing common functionality for 135 /// n'ary operators. 136 class SCEVNAryExpr : public SCEV { 137 protected: 138 // Since SCEVs are immutable, ScalarEvolution allocates operand 139 // arrays with its SCEVAllocator, so this class just needs a simple 140 // pointer rather than a more elaborate vector-like data structure. 141 // This also avoids the need for a non-trivial destructor. 142 const SCEV *const *Operands; 143 size_t NumOperands; 144 145 SCEVNAryExpr(const FoldingSetNodeIDRef ID, 146 enum SCEVTypes T, const SCEV *const *O, size_t N) 147 : SCEV(ID, T), Operands(O), NumOperands(N) {} 148 149 public: 150 size_t getNumOperands() const { return NumOperands; } 151 152 const SCEV *getOperand(unsigned i) const { 153 assert(i < NumOperands && "Operand index out of range!"); 154 return Operands[i]; 155 } 156 157 using op_iterator = const SCEV *const *; 158 using op_range = iterator_range<op_iterator>; 159 160 op_iterator op_begin() const { return Operands; } 161 op_iterator op_end() const { return Operands + NumOperands; } 162 op_range operands() const { 163 return make_range(op_begin(), op_end()); 164 } 165 166 Type *getType() const { return getOperand(0)->getType(); } 167 168 NoWrapFlags getNoWrapFlags(NoWrapFlags Mask = NoWrapMask) const { 169 return (NoWrapFlags)(SubclassData & Mask); 170 } 171 172 bool hasNoUnsignedWrap() const { 173 return getNoWrapFlags(FlagNUW) != FlagAnyWrap; 174 } 175 176 bool hasNoSignedWrap() const { 177 return getNoWrapFlags(FlagNSW) != FlagAnyWrap; 178 } 179 180 bool hasNoSelfWrap() const { 181 return getNoWrapFlags(FlagNW) != FlagAnyWrap; 182 } 183 184 /// Methods for support type inquiry through isa, cast, and dyn_cast: 185 static bool classof(const SCEV *S) { 186 return S->getSCEVType() == scAddExpr || 187 S->getSCEVType() == scMulExpr || 188 S->getSCEVType() == scSMaxExpr || 189 S->getSCEVType() == scUMaxExpr || 190 S->getSCEVType() == scAddRecExpr; 191 } 192 }; 193 194 /// This node is the base class for n'ary commutative operators. 195 class SCEVCommutativeExpr : public SCEVNAryExpr { 196 protected: 197 SCEVCommutativeExpr(const FoldingSetNodeIDRef ID, 198 enum SCEVTypes T, const SCEV *const *O, size_t N) 199 : SCEVNAryExpr(ID, T, O, N) {} 200 201 public: 202 /// Methods for support type inquiry through isa, cast, and dyn_cast: 203 static bool classof(const SCEV *S) { 204 return S->getSCEVType() == scAddExpr || 205 S->getSCEVType() == scMulExpr || 206 S->getSCEVType() == scSMaxExpr || 207 S->getSCEVType() == scUMaxExpr; 208 } 209 210 /// Set flags for a non-recurrence without clearing previously set flags. 211 void setNoWrapFlags(NoWrapFlags Flags) { 212 SubclassData |= Flags; 213 } 214 }; 215 216 /// This node represents an addition of some number of SCEVs. 217 class SCEVAddExpr : public SCEVCommutativeExpr { 218 friend class ScalarEvolution; 219 220 SCEVAddExpr(const FoldingSetNodeIDRef ID, 221 const SCEV *const *O, size_t N) 222 : SCEVCommutativeExpr(ID, scAddExpr, O, N) {} 223 224 public: 225 Type *getType() const { 226 // Use the type of the last operand, which is likely to be a pointer 227 // type, if there is one. This doesn't usually matter, but it can help 228 // reduce casts when the expressions are expanded. 229 return getOperand(getNumOperands() - 1)->getType(); 230 } 231 232 /// Methods for support type inquiry through isa, cast, and dyn_cast: 233 static bool classof(const SCEV *S) { 234 return S->getSCEVType() == scAddExpr; 235 } 236 }; 237 238 /// This node represents multiplication of some number of SCEVs. 239 class SCEVMulExpr : public SCEVCommutativeExpr { 240 friend class ScalarEvolution; 241 242 SCEVMulExpr(const FoldingSetNodeIDRef ID, 243 const SCEV *const *O, size_t N) 244 : SCEVCommutativeExpr(ID, scMulExpr, O, N) {} 245 246 public: 247 /// Methods for support type inquiry through isa, cast, and dyn_cast: 248 static bool classof(const SCEV *S) { 249 return S->getSCEVType() == scMulExpr; 250 } 251 }; 252 253 /// This class represents a binary unsigned division operation. 254 class SCEVUDivExpr : public SCEV { 255 friend class ScalarEvolution; 256 257 const SCEV *LHS; 258 const SCEV *RHS; 259 260 SCEVUDivExpr(const FoldingSetNodeIDRef ID, const SCEV *lhs, const SCEV *rhs) 261 : SCEV(ID, scUDivExpr), LHS(lhs), RHS(rhs) {} 262 263 public: 264 const SCEV *getLHS() const { return LHS; } 265 const SCEV *getRHS() const { return RHS; } 266 267 Type *getType() const { 268 // In most cases the types of LHS and RHS will be the same, but in some 269 // crazy cases one or the other may be a pointer. ScalarEvolution doesn't 270 // depend on the type for correctness, but handling types carefully can 271 // avoid extra casts in the SCEVExpander. The LHS is more likely to be 272 // a pointer type than the RHS, so use the RHS' type here. 273 return getRHS()->getType(); 274 } 275 276 /// Methods for support type inquiry through isa, cast, and dyn_cast: 277 static bool classof(const SCEV *S) { 278 return S->getSCEVType() == scUDivExpr; 279 } 280 }; 281 282 /// This node represents a polynomial recurrence on the trip count 283 /// of the specified loop. This is the primary focus of the 284 /// ScalarEvolution framework; all the other SCEV subclasses are 285 /// mostly just supporting infrastructure to allow SCEVAddRecExpr 286 /// expressions to be created and analyzed. 287 /// 288 /// All operands of an AddRec are required to be loop invariant. 289 /// 290 class SCEVAddRecExpr : public SCEVNAryExpr { 291 friend class ScalarEvolution; 292 293 const Loop *L; 294 295 SCEVAddRecExpr(const FoldingSetNodeIDRef ID, 296 const SCEV *const *O, size_t N, const Loop *l) 297 : SCEVNAryExpr(ID, scAddRecExpr, O, N), L(l) {} 298 299 public: 300 const SCEV *getStart() const { return Operands[0]; } 301 const Loop *getLoop() const { return L; } 302 303 /// Constructs and returns the recurrence indicating how much this 304 /// expression steps by. If this is a polynomial of degree N, it 305 /// returns a chrec of degree N-1. We cannot determine whether 306 /// the step recurrence has self-wraparound. 307 const SCEV *getStepRecurrence(ScalarEvolution &SE) const { 308 if (isAffine()) return getOperand(1); 309 return SE.getAddRecExpr(SmallVector<const SCEV *, 3>(op_begin()+1, 310 op_end()), 311 getLoop(), FlagAnyWrap); 312 } 313 314 /// Return true if this represents an expression A + B*x where A 315 /// and B are loop invariant values. 316 bool isAffine() const { 317 // We know that the start value is invariant. This expression is thus 318 // affine iff the step is also invariant. 319 return getNumOperands() == 2; 320 } 321 322 /// Return true if this represents an expression A + B*x + C*x^2 323 /// where A, B and C are loop invariant values. This corresponds 324 /// to an addrec of the form {L,+,M,+,N} 325 bool isQuadratic() const { 326 return getNumOperands() == 3; 327 } 328 329 /// Set flags for a recurrence without clearing any previously set flags. 330 /// For AddRec, either NUW or NSW implies NW. Keep track of this fact here 331 /// to make it easier to propagate flags. 332 void setNoWrapFlags(NoWrapFlags Flags) { 333 if (Flags & (FlagNUW | FlagNSW)) 334 Flags = ScalarEvolution::setFlags(Flags, FlagNW); 335 SubclassData |= Flags; 336 } 337 338 /// Return the value of this chain of recurrences at the specified 339 /// iteration number. 340 const SCEV *evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const; 341 342 /// Return the number of iterations of this loop that produce 343 /// values in the specified constant range. Another way of 344 /// looking at this is that it returns the first iteration number 345 /// where the value is not in the condition, thus computing the 346 /// exit count. If the iteration count can't be computed, an 347 /// instance of SCEVCouldNotCompute is returned. 348 const SCEV *getNumIterationsInRange(const ConstantRange &Range, 349 ScalarEvolution &SE) const; 350 351 /// Return an expression representing the value of this expression 352 /// one iteration of the loop ahead. 353 const SCEVAddRecExpr *getPostIncExpr(ScalarEvolution &SE) const { 354 return cast<SCEVAddRecExpr>(SE.getAddExpr(this, getStepRecurrence(SE))); 355 } 356 357 /// Methods for support type inquiry through isa, cast, and dyn_cast: 358 static bool classof(const SCEV *S) { 359 return S->getSCEVType() == scAddRecExpr; 360 } 361 }; 362 363 /// This class represents a signed maximum selection. 364 class SCEVSMaxExpr : public SCEVCommutativeExpr { 365 friend class ScalarEvolution; 366 367 SCEVSMaxExpr(const FoldingSetNodeIDRef ID, 368 const SCEV *const *O, size_t N) 369 : SCEVCommutativeExpr(ID, scSMaxExpr, O, N) { 370 // Max never overflows. 371 setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)); 372 } 373 374 public: 375 /// Methods for support type inquiry through isa, cast, and dyn_cast: 376 static bool classof(const SCEV *S) { 377 return S->getSCEVType() == scSMaxExpr; 378 } 379 }; 380 381 /// This class represents an unsigned maximum selection. 382 class SCEVUMaxExpr : public SCEVCommutativeExpr { 383 friend class ScalarEvolution; 384 385 SCEVUMaxExpr(const FoldingSetNodeIDRef ID, 386 const SCEV *const *O, size_t N) 387 : SCEVCommutativeExpr(ID, scUMaxExpr, O, N) { 388 // Max never overflows. 389 setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)); 390 } 391 392 public: 393 /// Methods for support type inquiry through isa, cast, and dyn_cast: 394 static bool classof(const SCEV *S) { 395 return S->getSCEVType() == scUMaxExpr; 396 } 397 }; 398 399 /// This means that we are dealing with an entirely unknown SCEV 400 /// value, and only represent it as its LLVM Value. This is the 401 /// "bottom" value for the analysis. 402 class SCEVUnknown final : public SCEV, private CallbackVH { 403 friend class ScalarEvolution; 404 405 /// The parent ScalarEvolution value. This is used to update the 406 /// parent's maps when the value associated with a SCEVUnknown is 407 /// deleted or RAUW'd. 408 ScalarEvolution *SE; 409 410 /// The next pointer in the linked list of all SCEVUnknown 411 /// instances owned by a ScalarEvolution. 412 SCEVUnknown *Next; 413 414 SCEVUnknown(const FoldingSetNodeIDRef ID, Value *V, 415 ScalarEvolution *se, SCEVUnknown *next) : 416 SCEV(ID, scUnknown), CallbackVH(V), SE(se), Next(next) {} 417 418 // Implement CallbackVH. 419 void deleted() override; 420 void allUsesReplacedWith(Value *New) override; 421 422 public: 423 Value *getValue() const { return getValPtr(); } 424 425 /// @{ 426 /// Test whether this is a special constant representing a type 427 /// size, alignment, or field offset in a target-independent 428 /// manner, and hasn't happened to have been folded with other 429 /// operations into something unrecognizable. This is mainly only 430 /// useful for pretty-printing and other situations where it isn't 431 /// absolutely required for these to succeed. 432 bool isSizeOf(Type *&AllocTy) const; 433 bool isAlignOf(Type *&AllocTy) const; 434 bool isOffsetOf(Type *&STy, Constant *&FieldNo) const; 435 /// @} 436 437 Type *getType() const { return getValPtr()->getType(); } 438 439 /// Methods for support type inquiry through isa, cast, and dyn_cast: 440 static bool classof(const SCEV *S) { 441 return S->getSCEVType() == scUnknown; 442 } 443 }; 444 445 /// This class defines a simple visitor class that may be used for 446 /// various SCEV analysis purposes. 447 template<typename SC, typename RetVal=void> 448 struct SCEVVisitor { 449 RetVal visit(const SCEV *S) { 450 switch (S->getSCEVType()) { 451 case scConstant: 452 return ((SC*)this)->visitConstant((const SCEVConstant*)S); 453 case scTruncate: 454 return ((SC*)this)->visitTruncateExpr((const SCEVTruncateExpr*)S); 455 case scZeroExtend: 456 return ((SC*)this)->visitZeroExtendExpr((const SCEVZeroExtendExpr*)S); 457 case scSignExtend: 458 return ((SC*)this)->visitSignExtendExpr((const SCEVSignExtendExpr*)S); 459 case scAddExpr: 460 return ((SC*)this)->visitAddExpr((const SCEVAddExpr*)S); 461 case scMulExpr: 462 return ((SC*)this)->visitMulExpr((const SCEVMulExpr*)S); 463 case scUDivExpr: 464 return ((SC*)this)->visitUDivExpr((const SCEVUDivExpr*)S); 465 case scAddRecExpr: 466 return ((SC*)this)->visitAddRecExpr((const SCEVAddRecExpr*)S); 467 case scSMaxExpr: 468 return ((SC*)this)->visitSMaxExpr((const SCEVSMaxExpr*)S); 469 case scUMaxExpr: 470 return ((SC*)this)->visitUMaxExpr((const SCEVUMaxExpr*)S); 471 case scUnknown: 472 return ((SC*)this)->visitUnknown((const SCEVUnknown*)S); 473 case scCouldNotCompute: 474 return ((SC*)this)->visitCouldNotCompute((const SCEVCouldNotCompute*)S); 475 default: 476 llvm_unreachable("Unknown SCEV type!"); 477 } 478 } 479 480 RetVal visitCouldNotCompute(const SCEVCouldNotCompute *S) { 481 llvm_unreachable("Invalid use of SCEVCouldNotCompute!"); 482 } 483 }; 484 485 /// Visit all nodes in the expression tree using worklist traversal. 486 /// 487 /// Visitor implements: 488 /// // return true to follow this node. 489 /// bool follow(const SCEV *S); 490 /// // return true to terminate the search. 491 /// bool isDone(); 492 template<typename SV> 493 class SCEVTraversal { 494 SV &Visitor; 495 SmallVector<const SCEV *, 8> Worklist; 496 SmallPtrSet<const SCEV *, 8> Visited; 497 498 void push(const SCEV *S) { 499 if (Visited.insert(S).second && Visitor.follow(S)) 500 Worklist.push_back(S); 501 } 502 503 public: 504 SCEVTraversal(SV& V): Visitor(V) {} 505 506 void visitAll(const SCEV *Root) { 507 push(Root); 508 while (!Worklist.empty() && !Visitor.isDone()) { 509 const SCEV *S = Worklist.pop_back_val(); 510 511 switch (S->getSCEVType()) { 512 case scConstant: 513 case scUnknown: 514 break; 515 case scTruncate: 516 case scZeroExtend: 517 case scSignExtend: 518 push(cast<SCEVCastExpr>(S)->getOperand()); 519 break; 520 case scAddExpr: 521 case scMulExpr: 522 case scSMaxExpr: 523 case scUMaxExpr: 524 case scAddRecExpr: 525 for (const auto *Op : cast<SCEVNAryExpr>(S)->operands()) 526 push(Op); 527 break; 528 case scUDivExpr: { 529 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S); 530 push(UDiv->getLHS()); 531 push(UDiv->getRHS()); 532 break; 533 } 534 case scCouldNotCompute: 535 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); 536 default: 537 llvm_unreachable("Unknown SCEV kind!"); 538 } 539 } 540 } 541 }; 542 543 /// Use SCEVTraversal to visit all nodes in the given expression tree. 544 template<typename SV> 545 void visitAll(const SCEV *Root, SV& Visitor) { 546 SCEVTraversal<SV> T(Visitor); 547 T.visitAll(Root); 548 } 549 550 /// Return true if any node in \p Root satisfies the predicate \p Pred. 551 template <typename PredTy> 552 bool SCEVExprContains(const SCEV *Root, PredTy Pred) { 553 struct FindClosure { 554 bool Found = false; 555 PredTy Pred; 556 557 FindClosure(PredTy Pred) : Pred(Pred) {} 558 559 bool follow(const SCEV *S) { 560 if (!Pred(S)) 561 return true; 562 563 Found = true; 564 return false; 565 } 566 567 bool isDone() const { return Found; } 568 }; 569 570 FindClosure FC(Pred); 571 visitAll(Root, FC); 572 return FC.Found; 573 } 574 575 /// This visitor recursively visits a SCEV expression and re-writes it. 576 /// The result from each visit is cached, so it will return the same 577 /// SCEV for the same input. 578 template<typename SC> 579 class SCEVRewriteVisitor : public SCEVVisitor<SC, const SCEV *> { 580 protected: 581 ScalarEvolution &SE; 582 // Memoize the result of each visit so that we only compute once for 583 // the same input SCEV. This is to avoid redundant computations when 584 // a SCEV is referenced by multiple SCEVs. Without memoization, this 585 // visit algorithm would have exponential time complexity in the worst 586 // case, causing the compiler to hang on certain tests. 587 DenseMap<const SCEV *, const SCEV *> RewriteResults; 588 589 public: 590 SCEVRewriteVisitor(ScalarEvolution &SE) : SE(SE) {} 591 592 const SCEV *visit(const SCEV *S) { 593 auto It = RewriteResults.find(S); 594 if (It != RewriteResults.end()) 595 return It->second; 596 auto* Visited = SCEVVisitor<SC, const SCEV *>::visit(S); 597 auto Result = RewriteResults.try_emplace(S, Visited); 598 assert(Result.second && "Should insert a new entry"); 599 return Result.first->second; 600 } 601 602 const SCEV *visitConstant(const SCEVConstant *Constant) { 603 return Constant; 604 } 605 606 const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) { 607 const SCEV *Operand = ((SC*)this)->visit(Expr->getOperand()); 608 return Operand == Expr->getOperand() 609 ? Expr 610 : SE.getTruncateExpr(Operand, Expr->getType()); 611 } 612 613 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { 614 const SCEV *Operand = ((SC*)this)->visit(Expr->getOperand()); 615 return Operand == Expr->getOperand() 616 ? Expr 617 : SE.getZeroExtendExpr(Operand, Expr->getType()); 618 } 619 620 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { 621 const SCEV *Operand = ((SC*)this)->visit(Expr->getOperand()); 622 return Operand == Expr->getOperand() 623 ? Expr 624 : SE.getSignExtendExpr(Operand, Expr->getType()); 625 } 626 627 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) { 628 SmallVector<const SCEV *, 2> Operands; 629 bool Changed = false; 630 for (auto *Op : Expr->operands()) { 631 Operands.push_back(((SC*)this)->visit(Op)); 632 Changed |= Op != Operands.back(); 633 } 634 return !Changed ? Expr : SE.getAddExpr(Operands); 635 } 636 637 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) { 638 SmallVector<const SCEV *, 2> Operands; 639 bool Changed = false; 640 for (auto *Op : Expr->operands()) { 641 Operands.push_back(((SC*)this)->visit(Op)); 642 Changed |= Op != Operands.back(); 643 } 644 return !Changed ? Expr : SE.getMulExpr(Operands); 645 } 646 647 const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) { 648 auto *LHS = ((SC *)this)->visit(Expr->getLHS()); 649 auto *RHS = ((SC *)this)->visit(Expr->getRHS()); 650 bool Changed = LHS != Expr->getLHS() || RHS != Expr->getRHS(); 651 return !Changed ? Expr : SE.getUDivExpr(LHS, RHS); 652 } 653 654 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { 655 SmallVector<const SCEV *, 2> Operands; 656 bool Changed = false; 657 for (auto *Op : Expr->operands()) { 658 Operands.push_back(((SC*)this)->visit(Op)); 659 Changed |= Op != Operands.back(); 660 } 661 return !Changed ? Expr 662 : SE.getAddRecExpr(Operands, Expr->getLoop(), 663 Expr->getNoWrapFlags()); 664 } 665 666 const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr) { 667 SmallVector<const SCEV *, 2> Operands; 668 bool Changed = false; 669 for (auto *Op : Expr->operands()) { 670 Operands.push_back(((SC *)this)->visit(Op)); 671 Changed |= Op != Operands.back(); 672 } 673 return !Changed ? Expr : SE.getSMaxExpr(Operands); 674 } 675 676 const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr) { 677 SmallVector<const SCEV *, 2> Operands; 678 bool Changed = false; 679 for (auto *Op : Expr->operands()) { 680 Operands.push_back(((SC*)this)->visit(Op)); 681 Changed |= Op != Operands.back(); 682 } 683 return !Changed ? Expr : SE.getUMaxExpr(Operands); 684 } 685 686 const SCEV *visitUnknown(const SCEVUnknown *Expr) { 687 return Expr; 688 } 689 690 const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { 691 return Expr; 692 } 693 }; 694 695 using ValueToValueMap = DenseMap<const Value *, Value *>; 696 697 /// The SCEVParameterRewriter takes a scalar evolution expression and updates 698 /// the SCEVUnknown components following the Map (Value -> Value). 699 class SCEVParameterRewriter : public SCEVRewriteVisitor<SCEVParameterRewriter> { 700 public: 701 static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE, 702 ValueToValueMap &Map, 703 bool InterpretConsts = false) { 704 SCEVParameterRewriter Rewriter(SE, Map, InterpretConsts); 705 return Rewriter.visit(Scev); 706 } 707 708 SCEVParameterRewriter(ScalarEvolution &SE, ValueToValueMap &M, bool C) 709 : SCEVRewriteVisitor(SE), Map(M), InterpretConsts(C) {} 710 711 const SCEV *visitUnknown(const SCEVUnknown *Expr) { 712 Value *V = Expr->getValue(); 713 if (Map.count(V)) { 714 Value *NV = Map[V]; 715 if (InterpretConsts && isa<ConstantInt>(NV)) 716 return SE.getConstant(cast<ConstantInt>(NV)); 717 return SE.getUnknown(NV); 718 } 719 return Expr; 720 } 721 722 private: 723 ValueToValueMap ⤅ 724 bool InterpretConsts; 725 }; 726 727 using LoopToScevMapT = DenseMap<const Loop *, const SCEV *>; 728 729 /// The SCEVLoopAddRecRewriter takes a scalar evolution expression and applies 730 /// the Map (Loop -> SCEV) to all AddRecExprs. 731 class SCEVLoopAddRecRewriter 732 : public SCEVRewriteVisitor<SCEVLoopAddRecRewriter> { 733 public: 734 SCEVLoopAddRecRewriter(ScalarEvolution &SE, LoopToScevMapT &M) 735 : SCEVRewriteVisitor(SE), Map(M) {} 736 737 static const SCEV *rewrite(const SCEV *Scev, LoopToScevMapT &Map, 738 ScalarEvolution &SE) { 739 SCEVLoopAddRecRewriter Rewriter(SE, Map); 740 return Rewriter.visit(Scev); 741 } 742 743 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { 744 SmallVector<const SCEV *, 2> Operands; 745 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) 746 Operands.push_back(visit(Expr->getOperand(i))); 747 748 const Loop *L = Expr->getLoop(); 749 const SCEV *Res = SE.getAddRecExpr(Operands, L, Expr->getNoWrapFlags()); 750 751 if (0 == Map.count(L)) 752 return Res; 753 754 const SCEVAddRecExpr *Rec = cast<SCEVAddRecExpr>(Res); 755 return Rec->evaluateAtIteration(Map[L], SE); 756 } 757 758 private: 759 LoopToScevMapT ⤅ 760 }; 761 762 } // end namespace llvm 763 764 #endif // LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H 765