1 // 2 // Copyright (c) 2002-2013 The ANGLE Project Authors. All rights reserved. 3 // Use of this source code is governed by a BSD-style license that can be 4 // found in the LICENSE file. 5 // 6 7 #include "compiler/ForLoopUnroll.h" 8 9 namespace { 10 11 class IntegerForLoopUnrollMarker : public TIntermTraverser { 12 public: 13 14 virtual bool visitLoop(Visit, TIntermLoop* node) 15 { 16 // This is called after ValidateLimitations pass, so all the ASSERT 17 // should never fail. 18 // See ValidateLimitations::validateForLoopInit(). 19 ASSERT(node); 20 ASSERT(node->getType() == ELoopFor); 21 ASSERT(node->getInit()); 22 TIntermAggregate* decl = node->getInit()->getAsAggregate(); 23 ASSERT(decl && decl->getOp() == EOpDeclaration); 24 TIntermSequence& declSeq = decl->getSequence(); 25 ASSERT(declSeq.size() == 1); 26 TIntermBinary* declInit = declSeq[0]->getAsBinaryNode(); 27 ASSERT(declInit && declInit->getOp() == EOpInitialize); 28 ASSERT(declInit->getLeft()); 29 TIntermSymbol* symbol = declInit->getLeft()->getAsSymbolNode(); 30 ASSERT(symbol); 31 TBasicType type = symbol->getBasicType(); 32 ASSERT(type == EbtInt || type == EbtFloat); 33 if (type == EbtInt) 34 node->setUnrollFlag(true); 35 return true; 36 } 37 38 }; 39 40 } // anonymous namepsace 41 42 void ForLoopUnroll::FillLoopIndexInfo(TIntermLoop* node, TLoopIndexInfo& info) 43 { 44 ASSERT(node->getType() == ELoopFor); 45 ASSERT(node->getUnrollFlag()); 46 47 TIntermNode* init = node->getInit(); 48 ASSERT(init != NULL); 49 TIntermAggregate* decl = init->getAsAggregate(); 50 ASSERT((decl != NULL) && (decl->getOp() == EOpDeclaration)); 51 TIntermSequence& declSeq = decl->getSequence(); 52 ASSERT(declSeq.size() == 1); 53 TIntermBinary* declInit = declSeq[0]->getAsBinaryNode(); 54 ASSERT((declInit != NULL) && (declInit->getOp() == EOpInitialize)); 55 TIntermSymbol* symbol = declInit->getLeft()->getAsSymbolNode(); 56 ASSERT(symbol != NULL); 57 ASSERT(symbol->getBasicType() == EbtInt); 58 59 info.id = symbol->getId(); 60 61 ASSERT(declInit->getRight() != NULL); 62 TIntermConstantUnion* initNode = declInit->getRight()->getAsConstantUnion(); 63 ASSERT(initNode != NULL); 64 65 info.initValue = evaluateIntConstant(initNode); 66 info.currentValue = info.initValue; 67 68 TIntermNode* cond = node->getCondition(); 69 ASSERT(cond != NULL); 70 TIntermBinary* binOp = cond->getAsBinaryNode(); 71 ASSERT(binOp != NULL); 72 ASSERT(binOp->getRight() != NULL); 73 ASSERT(binOp->getRight()->getAsConstantUnion() != NULL); 74 75 info.incrementValue = getLoopIncrement(node); 76 info.stopValue = evaluateIntConstant( 77 binOp->getRight()->getAsConstantUnion()); 78 info.op = binOp->getOp(); 79 } 80 81 void ForLoopUnroll::Step() 82 { 83 ASSERT(mLoopIndexStack.size() > 0); 84 TLoopIndexInfo& info = mLoopIndexStack[mLoopIndexStack.size() - 1]; 85 info.currentValue += info.incrementValue; 86 } 87 88 bool ForLoopUnroll::SatisfiesLoopCondition() 89 { 90 ASSERT(mLoopIndexStack.size() > 0); 91 TLoopIndexInfo& info = mLoopIndexStack[mLoopIndexStack.size() - 1]; 92 // Relational operator is one of: > >= < <= == or !=. 93 switch (info.op) { 94 case EOpEqual: 95 return (info.currentValue == info.stopValue); 96 case EOpNotEqual: 97 return (info.currentValue != info.stopValue); 98 case EOpLessThan: 99 return (info.currentValue < info.stopValue); 100 case EOpGreaterThan: 101 return (info.currentValue > info.stopValue); 102 case EOpLessThanEqual: 103 return (info.currentValue <= info.stopValue); 104 case EOpGreaterThanEqual: 105 return (info.currentValue >= info.stopValue); 106 default: 107 UNREACHABLE(); 108 } 109 return false; 110 } 111 112 bool ForLoopUnroll::NeedsToReplaceSymbolWithValue(TIntermSymbol* symbol) 113 { 114 for (TVector<TLoopIndexInfo>::iterator i = mLoopIndexStack.begin(); 115 i != mLoopIndexStack.end(); 116 ++i) { 117 if (i->id == symbol->getId()) 118 return true; 119 } 120 return false; 121 } 122 123 int ForLoopUnroll::GetLoopIndexValue(TIntermSymbol* symbol) 124 { 125 for (TVector<TLoopIndexInfo>::iterator i = mLoopIndexStack.begin(); 126 i != mLoopIndexStack.end(); 127 ++i) { 128 if (i->id == symbol->getId()) 129 return i->currentValue; 130 } 131 UNREACHABLE(); 132 return false; 133 } 134 135 void ForLoopUnroll::Push(TLoopIndexInfo& info) 136 { 137 mLoopIndexStack.push_back(info); 138 } 139 140 void ForLoopUnroll::Pop() 141 { 142 mLoopIndexStack.pop_back(); 143 } 144 145 // static 146 void ForLoopUnroll::MarkForLoopsWithIntegerIndicesForUnrolling( 147 TIntermNode* root) 148 { 149 ASSERT(root); 150 151 IntegerForLoopUnrollMarker marker; 152 root->traverse(&marker); 153 } 154 155 int ForLoopUnroll::getLoopIncrement(TIntermLoop* node) 156 { 157 TIntermNode* expr = node->getExpression(); 158 ASSERT(expr != NULL); 159 // for expression has one of the following forms: 160 // loop_index++ 161 // loop_index-- 162 // loop_index += constant_expression 163 // loop_index -= constant_expression 164 // ++loop_index 165 // --loop_index 166 // The last two forms are not specified in the spec, but I am assuming 167 // its an oversight. 168 TIntermUnary* unOp = expr->getAsUnaryNode(); 169 TIntermBinary* binOp = unOp ? NULL : expr->getAsBinaryNode(); 170 171 TOperator op = EOpNull; 172 TIntermConstantUnion* incrementNode = NULL; 173 if (unOp != NULL) { 174 op = unOp->getOp(); 175 } else if (binOp != NULL) { 176 op = binOp->getOp(); 177 ASSERT(binOp->getRight() != NULL); 178 incrementNode = binOp->getRight()->getAsConstantUnion(); 179 ASSERT(incrementNode != NULL); 180 } 181 182 int increment = 0; 183 // The operator is one of: ++ -- += -=. 184 switch (op) { 185 case EOpPostIncrement: 186 case EOpPreIncrement: 187 ASSERT((unOp != NULL) && (binOp == NULL)); 188 increment = 1; 189 break; 190 case EOpPostDecrement: 191 case EOpPreDecrement: 192 ASSERT((unOp != NULL) && (binOp == NULL)); 193 increment = -1; 194 break; 195 case EOpAddAssign: 196 ASSERT((unOp == NULL) && (binOp != NULL)); 197 increment = evaluateIntConstant(incrementNode); 198 break; 199 case EOpSubAssign: 200 ASSERT((unOp == NULL) && (binOp != NULL)); 201 increment = - evaluateIntConstant(incrementNode); 202 break; 203 default: 204 ASSERT(false); 205 } 206 207 return increment; 208 } 209 210 int ForLoopUnroll::evaluateIntConstant(TIntermConstantUnion* node) 211 { 212 ASSERT((node != NULL) && (node->getUnionArrayPointer() != NULL)); 213 return node->getIConst(0); 214 } 215 216