1 //===- llvm/Analysis/ScalarEvolutionExpressions.h - SCEV Exprs --*- C++ -*-===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // This file defines the classes used to represent and build scalar expressions. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #ifndef LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H 15 #define LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H 16 17 #include "llvm/ADT/SmallPtrSet.h" 18 #include "llvm/ADT/iterator_range.h" 19 #include "llvm/Analysis/ScalarEvolution.h" 20 #include "llvm/Support/ErrorHandling.h" 21 22 namespace llvm { 23 class ConstantInt; 24 class ConstantRange; 25 class DominatorTree; 26 27 enum SCEVTypes { 28 // These should be ordered in terms of increasing complexity to make the 29 // folders simpler. 30 scConstant, scTruncate, scZeroExtend, scSignExtend, scAddExpr, scMulExpr, 31 scUDivExpr, scAddRecExpr, scUMaxExpr, scSMaxExpr, 32 scUnknown, scCouldNotCompute 33 }; 34 35 //===--------------------------------------------------------------------===// 36 /// SCEVConstant - This class represents a constant integer value. 37 /// 38 class SCEVConstant : public SCEV { 39 friend class ScalarEvolution; 40 41 ConstantInt *V; 42 SCEVConstant(const FoldingSetNodeIDRef ID, ConstantInt *v) : 43 SCEV(ID, scConstant), V(v) {} 44 public: 45 ConstantInt *getValue() const { return V; } 46 const APInt &getAPInt() const { return getValue()->getValue(); } 47 48 Type *getType() const { return V->getType(); } 49 50 /// Methods for support type inquiry through isa, cast, and dyn_cast: 51 static inline bool classof(const SCEV *S) { 52 return S->getSCEVType() == scConstant; 53 } 54 }; 55 56 //===--------------------------------------------------------------------===// 57 /// SCEVCastExpr - This is the base class for unary cast operator classes. 58 /// 59 class SCEVCastExpr : public SCEV { 60 protected: 61 const SCEV *Op; 62 Type *Ty; 63 64 SCEVCastExpr(const FoldingSetNodeIDRef ID, 65 unsigned SCEVTy, const SCEV *op, Type *ty); 66 67 public: 68 const SCEV *getOperand() const { return Op; } 69 Type *getType() const { return Ty; } 70 71 /// Methods for support type inquiry through isa, cast, and dyn_cast: 72 static inline bool classof(const SCEV *S) { 73 return S->getSCEVType() == scTruncate || 74 S->getSCEVType() == scZeroExtend || 75 S->getSCEVType() == scSignExtend; 76 } 77 }; 78 79 //===--------------------------------------------------------------------===// 80 /// SCEVTruncateExpr - This class represents a truncation of an integer value 81 /// to a smaller integer value. 82 /// 83 class SCEVTruncateExpr : public SCEVCastExpr { 84 friend class ScalarEvolution; 85 86 SCEVTruncateExpr(const FoldingSetNodeIDRef ID, 87 const SCEV *op, Type *ty); 88 89 public: 90 /// Methods for support type inquiry through isa, cast, and dyn_cast: 91 static inline bool classof(const SCEV *S) { 92 return S->getSCEVType() == scTruncate; 93 } 94 }; 95 96 //===--------------------------------------------------------------------===// 97 /// SCEVZeroExtendExpr - This class represents a zero extension of a small 98 /// integer value to a larger integer value. 99 /// 100 class SCEVZeroExtendExpr : public SCEVCastExpr { 101 friend class ScalarEvolution; 102 103 SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID, 104 const SCEV *op, Type *ty); 105 106 public: 107 /// Methods for support type inquiry through isa, cast, and dyn_cast: 108 static inline bool classof(const SCEV *S) { 109 return S->getSCEVType() == scZeroExtend; 110 } 111 }; 112 113 //===--------------------------------------------------------------------===// 114 /// SCEVSignExtendExpr - This class represents a sign extension of a small 115 /// integer value to a larger integer value. 116 /// 117 class SCEVSignExtendExpr : public SCEVCastExpr { 118 friend class ScalarEvolution; 119 120 SCEVSignExtendExpr(const FoldingSetNodeIDRef ID, 121 const SCEV *op, Type *ty); 122 123 public: 124 /// Methods for support type inquiry through isa, cast, and dyn_cast: 125 static inline bool classof(const SCEV *S) { 126 return S->getSCEVType() == scSignExtend; 127 } 128 }; 129 130 131 //===--------------------------------------------------------------------===// 132 /// SCEVNAryExpr - This node is a base class providing common 133 /// functionality for n'ary operators. 134 /// 135 class SCEVNAryExpr : public SCEV { 136 protected: 137 // Since SCEVs are immutable, ScalarEvolution allocates operand 138 // arrays with its SCEVAllocator, so this class just needs a simple 139 // pointer rather than a more elaborate vector-like data structure. 140 // This also avoids the need for a non-trivial destructor. 141 const SCEV *const *Operands; 142 size_t NumOperands; 143 144 SCEVNAryExpr(const FoldingSetNodeIDRef ID, 145 enum SCEVTypes T, const SCEV *const *O, size_t N) 146 : SCEV(ID, T), Operands(O), NumOperands(N) {} 147 148 public: 149 size_t getNumOperands() const { return NumOperands; } 150 const SCEV *getOperand(unsigned i) const { 151 assert(i < NumOperands && "Operand index out of range!"); 152 return Operands[i]; 153 } 154 155 typedef const SCEV *const *op_iterator; 156 typedef iterator_range<op_iterator> op_range; 157 op_iterator op_begin() const { return Operands; } 158 op_iterator op_end() const { return Operands + NumOperands; } 159 op_range operands() const { 160 return make_range(op_begin(), op_end()); 161 } 162 163 Type *getType() const { return getOperand(0)->getType(); } 164 165 NoWrapFlags getNoWrapFlags(NoWrapFlags Mask = NoWrapMask) const { 166 return (NoWrapFlags)(SubclassData & Mask); 167 } 168 169 /// Methods for support type inquiry through isa, cast, and dyn_cast: 170 static inline bool classof(const SCEV *S) { 171 return S->getSCEVType() == scAddExpr || 172 S->getSCEVType() == scMulExpr || 173 S->getSCEVType() == scSMaxExpr || 174 S->getSCEVType() == scUMaxExpr || 175 S->getSCEVType() == scAddRecExpr; 176 } 177 }; 178 179 //===--------------------------------------------------------------------===// 180 /// SCEVCommutativeExpr - This node is the base class for n'ary commutative 181 /// operators. 182 /// 183 class SCEVCommutativeExpr : public SCEVNAryExpr { 184 protected: 185 SCEVCommutativeExpr(const FoldingSetNodeIDRef ID, 186 enum SCEVTypes T, const SCEV *const *O, size_t N) 187 : SCEVNAryExpr(ID, T, O, N) {} 188 189 public: 190 /// Methods for support type inquiry through isa, cast, and dyn_cast: 191 static inline bool classof(const SCEV *S) { 192 return S->getSCEVType() == scAddExpr || 193 S->getSCEVType() == scMulExpr || 194 S->getSCEVType() == scSMaxExpr || 195 S->getSCEVType() == scUMaxExpr; 196 } 197 198 /// Set flags for a non-recurrence without clearing previously set flags. 199 void setNoWrapFlags(NoWrapFlags Flags) { 200 SubclassData |= Flags; 201 } 202 }; 203 204 205 //===--------------------------------------------------------------------===// 206 /// SCEVAddExpr - This node represents an addition of some number of SCEVs. 207 /// 208 class SCEVAddExpr : public SCEVCommutativeExpr { 209 friend class ScalarEvolution; 210 211 SCEVAddExpr(const FoldingSetNodeIDRef ID, 212 const SCEV *const *O, size_t N) 213 : SCEVCommutativeExpr(ID, scAddExpr, O, N) { 214 } 215 216 public: 217 Type *getType() const { 218 // Use the type of the last operand, which is likely to be a pointer 219 // type, if there is one. This doesn't usually matter, but it can help 220 // reduce casts when the expressions are expanded. 221 return getOperand(getNumOperands() - 1)->getType(); 222 } 223 224 /// Methods for support type inquiry through isa, cast, and dyn_cast: 225 static inline bool classof(const SCEV *S) { 226 return S->getSCEVType() == scAddExpr; 227 } 228 }; 229 230 //===--------------------------------------------------------------------===// 231 /// SCEVMulExpr - This node represents multiplication of some number of SCEVs. 232 /// 233 class SCEVMulExpr : public SCEVCommutativeExpr { 234 friend class ScalarEvolution; 235 236 SCEVMulExpr(const FoldingSetNodeIDRef ID, 237 const SCEV *const *O, size_t N) 238 : SCEVCommutativeExpr(ID, scMulExpr, O, N) { 239 } 240 241 public: 242 /// Methods for support type inquiry through isa, cast, and dyn_cast: 243 static inline bool classof(const SCEV *S) { 244 return S->getSCEVType() == scMulExpr; 245 } 246 }; 247 248 249 //===--------------------------------------------------------------------===// 250 /// SCEVUDivExpr - This class represents a binary unsigned division operation. 251 /// 252 class SCEVUDivExpr : public SCEV { 253 friend class ScalarEvolution; 254 255 const SCEV *LHS; 256 const SCEV *RHS; 257 SCEVUDivExpr(const FoldingSetNodeIDRef ID, const SCEV *lhs, const SCEV *rhs) 258 : SCEV(ID, scUDivExpr), LHS(lhs), RHS(rhs) {} 259 260 public: 261 const SCEV *getLHS() const { return LHS; } 262 const SCEV *getRHS() const { return RHS; } 263 264 Type *getType() const { 265 // In most cases the types of LHS and RHS will be the same, but in some 266 // crazy cases one or the other may be a pointer. ScalarEvolution doesn't 267 // depend on the type for correctness, but handling types carefully can 268 // avoid extra casts in the SCEVExpander. The LHS is more likely to be 269 // a pointer type than the RHS, so use the RHS' type here. 270 return getRHS()->getType(); 271 } 272 273 /// Methods for support type inquiry through isa, cast, and dyn_cast: 274 static inline bool classof(const SCEV *S) { 275 return S->getSCEVType() == scUDivExpr; 276 } 277 }; 278 279 280 //===--------------------------------------------------------------------===// 281 /// SCEVAddRecExpr - This node represents a polynomial recurrence on the trip 282 /// count of the specified loop. This is the primary focus of the 283 /// ScalarEvolution framework; all the other SCEV subclasses are mostly just 284 /// supporting infrastructure to allow SCEVAddRecExpr expressions to be 285 /// created and analyzed. 286 /// 287 /// All operands of an AddRec are required to be loop invariant. 288 /// 289 class SCEVAddRecExpr : public SCEVNAryExpr { 290 friend class ScalarEvolution; 291 292 const Loop *L; 293 294 SCEVAddRecExpr(const FoldingSetNodeIDRef ID, 295 const SCEV *const *O, size_t N, const Loop *l) 296 : SCEVNAryExpr(ID, scAddRecExpr, O, N), L(l) {} 297 298 public: 299 const SCEV *getStart() const { return Operands[0]; } 300 const Loop *getLoop() const { return L; } 301 302 /// getStepRecurrence - This method constructs and returns the recurrence 303 /// indicating how much this expression steps by. If this is a polynomial 304 /// of degree N, it returns a chrec of degree N-1. 305 /// We cannot determine whether the step recurrence has self-wraparound. 306 const SCEV *getStepRecurrence(ScalarEvolution &SE) const { 307 if (isAffine()) return getOperand(1); 308 return SE.getAddRecExpr(SmallVector<const SCEV *, 3>(op_begin()+1, 309 op_end()), 310 getLoop(), FlagAnyWrap); 311 } 312 313 /// isAffine - Return true if this represents an expression 314 /// A + B*x where A and B are loop invariant values. 315 bool isAffine() const { 316 // We know that the start value is invariant. This expression is thus 317 // affine iff the step is also invariant. 318 return getNumOperands() == 2; 319 } 320 321 /// isQuadratic - Return true if this represents an expression 322 /// A + B*x + C*x^2 where A, B and C are loop invariant values. 323 /// This corresponds to an addrec of the form {L,+,M,+,N} 324 bool isQuadratic() const { 325 return getNumOperands() == 3; 326 } 327 328 /// Set flags for a recurrence without clearing any previously set flags. 329 /// For AddRec, either NUW or NSW implies NW. Keep track of this fact here 330 /// to make it easier to propagate flags. 331 void setNoWrapFlags(NoWrapFlags Flags) { 332 if (Flags & (FlagNUW | FlagNSW)) 333 Flags = ScalarEvolution::setFlags(Flags, FlagNW); 334 SubclassData |= Flags; 335 } 336 337 /// evaluateAtIteration - Return the value of this chain of recurrences at 338 /// the specified iteration number. 339 const SCEV *evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const; 340 341 /// getNumIterationsInRange - Return the number of iterations of this loop 342 /// that produce values in the specified constant range. Another way of 343 /// looking at this is that it returns the first iteration number where the 344 /// value is not in the condition, thus computing the exit count. If the 345 /// iteration count can't be computed, an instance of SCEVCouldNotCompute is 346 /// returned. 347 const SCEV *getNumIterationsInRange(ConstantRange Range, 348 ScalarEvolution &SE) const; 349 350 /// getPostIncExpr - Return an expression representing the value of 351 /// this expression one iteration of the loop ahead. 352 const SCEVAddRecExpr *getPostIncExpr(ScalarEvolution &SE) const { 353 return cast<SCEVAddRecExpr>(SE.getAddExpr(this, getStepRecurrence(SE))); 354 } 355 356 /// Methods for support type inquiry through isa, cast, and dyn_cast: 357 static inline bool classof(const SCEV *S) { 358 return S->getSCEVType() == scAddRecExpr; 359 } 360 }; 361 362 //===--------------------------------------------------------------------===// 363 /// SCEVSMaxExpr - This class represents a signed maximum selection. 364 /// 365 class SCEVSMaxExpr : public SCEVCommutativeExpr { 366 friend class ScalarEvolution; 367 368 SCEVSMaxExpr(const FoldingSetNodeIDRef ID, 369 const SCEV *const *O, size_t N) 370 : SCEVCommutativeExpr(ID, scSMaxExpr, O, N) { 371 // Max never overflows. 372 setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)); 373 } 374 375 public: 376 /// Methods for support type inquiry through isa, cast, and dyn_cast: 377 static inline bool classof(const SCEV *S) { 378 return S->getSCEVType() == scSMaxExpr; 379 } 380 }; 381 382 383 //===--------------------------------------------------------------------===// 384 /// SCEVUMaxExpr - This class represents an unsigned maximum selection. 385 /// 386 class SCEVUMaxExpr : public SCEVCommutativeExpr { 387 friend class ScalarEvolution; 388 389 SCEVUMaxExpr(const FoldingSetNodeIDRef ID, 390 const SCEV *const *O, size_t N) 391 : SCEVCommutativeExpr(ID, scUMaxExpr, O, N) { 392 // Max never overflows. 393 setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)); 394 } 395 396 public: 397 /// Methods for support type inquiry through isa, cast, and dyn_cast: 398 static inline bool classof(const SCEV *S) { 399 return S->getSCEVType() == scUMaxExpr; 400 } 401 }; 402 403 //===--------------------------------------------------------------------===// 404 /// SCEVUnknown - This means that we are dealing with an entirely unknown SCEV 405 /// value, and only represent it as its LLVM Value. This is the "bottom" 406 /// value for the analysis. 407 /// 408 class SCEVUnknown final : public SCEV, private CallbackVH { 409 friend class ScalarEvolution; 410 411 // Implement CallbackVH. 412 void deleted() override; 413 void allUsesReplacedWith(Value *New) override; 414 415 /// SE - The parent ScalarEvolution value. This is used to update 416 /// the parent's maps when the value associated with a SCEVUnknown 417 /// is deleted or RAUW'd. 418 ScalarEvolution *SE; 419 420 /// Next - The next pointer in the linked list of all 421 /// SCEVUnknown instances owned by a ScalarEvolution. 422 SCEVUnknown *Next; 423 424 SCEVUnknown(const FoldingSetNodeIDRef ID, Value *V, 425 ScalarEvolution *se, SCEVUnknown *next) : 426 SCEV(ID, scUnknown), CallbackVH(V), SE(se), Next(next) {} 427 428 public: 429 Value *getValue() const { return getValPtr(); } 430 431 /// isSizeOf, isAlignOf, isOffsetOf - Test whether this is a special 432 /// constant representing a type size, alignment, or field offset in 433 /// a target-independent manner, and hasn't happened to have been 434 /// folded with other operations into something unrecognizable. This 435 /// is mainly only useful for pretty-printing and other situations 436 /// where it isn't absolutely required for these to succeed. 437 bool isSizeOf(Type *&AllocTy) const; 438 bool isAlignOf(Type *&AllocTy) const; 439 bool isOffsetOf(Type *&STy, Constant *&FieldNo) const; 440 441 Type *getType() const { return getValPtr()->getType(); } 442 443 /// Methods for support type inquiry through isa, cast, and dyn_cast: 444 static inline bool classof(const SCEV *S) { 445 return S->getSCEVType() == scUnknown; 446 } 447 }; 448 449 /// SCEVVisitor - This class defines a simple visitor class that may be used 450 /// for various SCEV analysis purposes. 451 template<typename SC, typename RetVal=void> 452 struct SCEVVisitor { 453 RetVal visit(const SCEV *S) { 454 switch (S->getSCEVType()) { 455 case scConstant: 456 return ((SC*)this)->visitConstant((const SCEVConstant*)S); 457 case scTruncate: 458 return ((SC*)this)->visitTruncateExpr((const SCEVTruncateExpr*)S); 459 case scZeroExtend: 460 return ((SC*)this)->visitZeroExtendExpr((const SCEVZeroExtendExpr*)S); 461 case scSignExtend: 462 return ((SC*)this)->visitSignExtendExpr((const SCEVSignExtendExpr*)S); 463 case scAddExpr: 464 return ((SC*)this)->visitAddExpr((const SCEVAddExpr*)S); 465 case scMulExpr: 466 return ((SC*)this)->visitMulExpr((const SCEVMulExpr*)S); 467 case scUDivExpr: 468 return ((SC*)this)->visitUDivExpr((const SCEVUDivExpr*)S); 469 case scAddRecExpr: 470 return ((SC*)this)->visitAddRecExpr((const SCEVAddRecExpr*)S); 471 case scSMaxExpr: 472 return ((SC*)this)->visitSMaxExpr((const SCEVSMaxExpr*)S); 473 case scUMaxExpr: 474 return ((SC*)this)->visitUMaxExpr((const SCEVUMaxExpr*)S); 475 case scUnknown: 476 return ((SC*)this)->visitUnknown((const SCEVUnknown*)S); 477 case scCouldNotCompute: 478 return ((SC*)this)->visitCouldNotCompute((const SCEVCouldNotCompute*)S); 479 default: 480 llvm_unreachable("Unknown SCEV type!"); 481 } 482 } 483 484 RetVal visitCouldNotCompute(const SCEVCouldNotCompute *S) { 485 llvm_unreachable("Invalid use of SCEVCouldNotCompute!"); 486 } 487 }; 488 489 /// Visit all nodes in the expression tree using worklist traversal. 490 /// 491 /// Visitor implements: 492 /// // return true to follow this node. 493 /// bool follow(const SCEV *S); 494 /// // return true to terminate the search. 495 /// bool isDone(); 496 template<typename SV> 497 class SCEVTraversal { 498 SV &Visitor; 499 SmallVector<const SCEV *, 8> Worklist; 500 SmallPtrSet<const SCEV *, 8> Visited; 501 502 void push(const SCEV *S) { 503 if (Visited.insert(S).second && Visitor.follow(S)) 504 Worklist.push_back(S); 505 } 506 public: 507 SCEVTraversal(SV& V): Visitor(V) {} 508 509 void visitAll(const SCEV *Root) { 510 push(Root); 511 while (!Worklist.empty() && !Visitor.isDone()) { 512 const SCEV *S = Worklist.pop_back_val(); 513 514 switch (S->getSCEVType()) { 515 case scConstant: 516 case scUnknown: 517 break; 518 case scTruncate: 519 case scZeroExtend: 520 case scSignExtend: 521 push(cast<SCEVCastExpr>(S)->getOperand()); 522 break; 523 case scAddExpr: 524 case scMulExpr: 525 case scSMaxExpr: 526 case scUMaxExpr: 527 case scAddRecExpr: { 528 const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(S); 529 for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), 530 E = NAry->op_end(); I != E; ++I) { 531 push(*I); 532 } 533 break; 534 } 535 case scUDivExpr: { 536 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S); 537 push(UDiv->getLHS()); 538 push(UDiv->getRHS()); 539 break; 540 } 541 case scCouldNotCompute: 542 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); 543 default: 544 llvm_unreachable("Unknown SCEV kind!"); 545 } 546 } 547 } 548 }; 549 550 /// Use SCEVTraversal to visit all nodes in the given expression tree. 551 template<typename SV> 552 void visitAll(const SCEV *Root, SV& Visitor) { 553 SCEVTraversal<SV> T(Visitor); 554 T.visitAll(Root); 555 } 556 557 /// Recursively visits a SCEV expression and re-writes it. 558 template<typename SC> 559 class SCEVRewriteVisitor : public SCEVVisitor<SC, const SCEV *> { 560 protected: 561 ScalarEvolution &SE; 562 public: 563 SCEVRewriteVisitor(ScalarEvolution &SE) : SE(SE) {} 564 565 const SCEV *visitConstant(const SCEVConstant *Constant) { 566 return Constant; 567 } 568 569 const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) { 570 const SCEV *Operand = ((SC*)this)->visit(Expr->getOperand()); 571 return SE.getTruncateExpr(Operand, Expr->getType()); 572 } 573 574 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { 575 const SCEV *Operand = ((SC*)this)->visit(Expr->getOperand()); 576 return SE.getZeroExtendExpr(Operand, Expr->getType()); 577 } 578 579 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { 580 const SCEV *Operand = ((SC*)this)->visit(Expr->getOperand()); 581 return SE.getSignExtendExpr(Operand, Expr->getType()); 582 } 583 584 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) { 585 SmallVector<const SCEV *, 2> Operands; 586 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) 587 Operands.push_back(((SC*)this)->visit(Expr->getOperand(i))); 588 return SE.getAddExpr(Operands); 589 } 590 591 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) { 592 SmallVector<const SCEV *, 2> Operands; 593 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) 594 Operands.push_back(((SC*)this)->visit(Expr->getOperand(i))); 595 return SE.getMulExpr(Operands); 596 } 597 598 const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) { 599 return SE.getUDivExpr(((SC*)this)->visit(Expr->getLHS()), 600 ((SC*)this)->visit(Expr->getRHS())); 601 } 602 603 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { 604 SmallVector<const SCEV *, 2> Operands; 605 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) 606 Operands.push_back(((SC*)this)->visit(Expr->getOperand(i))); 607 return SE.getAddRecExpr(Operands, Expr->getLoop(), 608 Expr->getNoWrapFlags()); 609 } 610 611 const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr) { 612 SmallVector<const SCEV *, 2> Operands; 613 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) 614 Operands.push_back(((SC*)this)->visit(Expr->getOperand(i))); 615 return SE.getSMaxExpr(Operands); 616 } 617 618 const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr) { 619 SmallVector<const SCEV *, 2> Operands; 620 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) 621 Operands.push_back(((SC*)this)->visit(Expr->getOperand(i))); 622 return SE.getUMaxExpr(Operands); 623 } 624 625 const SCEV *visitUnknown(const SCEVUnknown *Expr) { 626 return Expr; 627 } 628 629 const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { 630 return Expr; 631 } 632 }; 633 634 typedef DenseMap<const Value*, Value*> ValueToValueMap; 635 636 /// The SCEVParameterRewriter takes a scalar evolution expression and updates 637 /// the SCEVUnknown components following the Map (Value -> Value). 638 class SCEVParameterRewriter : public SCEVRewriteVisitor<SCEVParameterRewriter> { 639 public: 640 static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE, 641 ValueToValueMap &Map, 642 bool InterpretConsts = false) { 643 SCEVParameterRewriter Rewriter(SE, Map, InterpretConsts); 644 return Rewriter.visit(Scev); 645 } 646 647 SCEVParameterRewriter(ScalarEvolution &SE, ValueToValueMap &M, bool C) 648 : SCEVRewriteVisitor(SE), Map(M), InterpretConsts(C) {} 649 650 const SCEV *visitUnknown(const SCEVUnknown *Expr) { 651 Value *V = Expr->getValue(); 652 if (Map.count(V)) { 653 Value *NV = Map[V]; 654 if (InterpretConsts && isa<ConstantInt>(NV)) 655 return SE.getConstant(cast<ConstantInt>(NV)); 656 return SE.getUnknown(NV); 657 } 658 return Expr; 659 } 660 661 private: 662 ValueToValueMap ⤅ 663 bool InterpretConsts; 664 }; 665 666 typedef DenseMap<const Loop*, const SCEV*> LoopToScevMapT; 667 668 /// The SCEVLoopAddRecRewriter takes a scalar evolution expression and applies 669 /// the Map (Loop -> SCEV) to all AddRecExprs. 670 class SCEVLoopAddRecRewriter 671 : public SCEVRewriteVisitor<SCEVLoopAddRecRewriter> { 672 public: 673 static const SCEV *rewrite(const SCEV *Scev, LoopToScevMapT &Map, 674 ScalarEvolution &SE) { 675 SCEVLoopAddRecRewriter Rewriter(SE, Map); 676 return Rewriter.visit(Scev); 677 } 678 679 SCEVLoopAddRecRewriter(ScalarEvolution &SE, LoopToScevMapT &M) 680 : SCEVRewriteVisitor(SE), Map(M) {} 681 682 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { 683 SmallVector<const SCEV *, 2> Operands; 684 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) 685 Operands.push_back(visit(Expr->getOperand(i))); 686 687 const Loop *L = Expr->getLoop(); 688 const SCEV *Res = SE.getAddRecExpr(Operands, L, Expr->getNoWrapFlags()); 689 690 if (0 == Map.count(L)) 691 return Res; 692 693 const SCEVAddRecExpr *Rec = cast<SCEVAddRecExpr>(Res); 694 return Rec->evaluateAtIteration(Map[L], SE); 695 } 696 697 private: 698 LoopToScevMapT ⤅ 699 }; 700 701 /// Applies the Map (Loop -> SCEV) to the given Scev. 702 static inline const SCEV *apply(const SCEV *Scev, LoopToScevMapT &Map, 703 ScalarEvolution &SE) { 704 return SCEVLoopAddRecRewriter::rewrite(Scev, Map, SE); 705 } 706 707 } 708 709 #endif 710