Home | History | Annotate | Download | only in compiler
      1 //
      2 // Copyright (c) 2002-2013 The ANGLE Project Authors. All rights reserved.
      3 // Use of this source code is governed by a BSD-style license that can be
      4 // found in the LICENSE file.
      5 //
      6 
      7 #include "compiler/ForLoopUnroll.h"
      8 
      9 namespace {
     10 
     11 class IntegerForLoopUnrollMarker : public TIntermTraverser {
     12 public:
     13 
     14     virtual bool visitLoop(Visit, TIntermLoop* node)
     15     {
     16         // This is called after ValidateLimitations pass, so all the ASSERT
     17         // should never fail.
     18         // See ValidateLimitations::validateForLoopInit().
     19         ASSERT(node);
     20         ASSERT(node->getType() == ELoopFor);
     21         ASSERT(node->getInit());
     22         TIntermAggregate* decl = node->getInit()->getAsAggregate();
     23         ASSERT(decl && decl->getOp() == EOpDeclaration);
     24         TIntermSequence& declSeq = decl->getSequence();
     25         ASSERT(declSeq.size() == 1);
     26         TIntermBinary* declInit = declSeq[0]->getAsBinaryNode();
     27         ASSERT(declInit && declInit->getOp() == EOpInitialize);
     28         ASSERT(declInit->getLeft());
     29         TIntermSymbol* symbol = declInit->getLeft()->getAsSymbolNode();
     30         ASSERT(symbol);
     31         TBasicType type = symbol->getBasicType();
     32         ASSERT(type == EbtInt || type == EbtFloat);
     33         if (type == EbtInt)
     34             node->setUnrollFlag(true);
     35         return true;
     36     }
     37 
     38 };
     39 
     40 }  // anonymous namepsace
     41 
     42 void ForLoopUnroll::FillLoopIndexInfo(TIntermLoop* node, TLoopIndexInfo& info)
     43 {
     44     ASSERT(node->getType() == ELoopFor);
     45     ASSERT(node->getUnrollFlag());
     46 
     47     TIntermNode* init = node->getInit();
     48     ASSERT(init != NULL);
     49     TIntermAggregate* decl = init->getAsAggregate();
     50     ASSERT((decl != NULL) && (decl->getOp() == EOpDeclaration));
     51     TIntermSequence& declSeq = decl->getSequence();
     52     ASSERT(declSeq.size() == 1);
     53     TIntermBinary* declInit = declSeq[0]->getAsBinaryNode();
     54     ASSERT((declInit != NULL) && (declInit->getOp() == EOpInitialize));
     55     TIntermSymbol* symbol = declInit->getLeft()->getAsSymbolNode();
     56     ASSERT(symbol != NULL);
     57     ASSERT(symbol->getBasicType() == EbtInt);
     58 
     59     info.id = symbol->getId();
     60 
     61     ASSERT(declInit->getRight() != NULL);
     62     TIntermConstantUnion* initNode = declInit->getRight()->getAsConstantUnion();
     63     ASSERT(initNode != NULL);
     64 
     65     info.initValue = evaluateIntConstant(initNode);
     66     info.currentValue = info.initValue;
     67 
     68     TIntermNode* cond = node->getCondition();
     69     ASSERT(cond != NULL);
     70     TIntermBinary* binOp = cond->getAsBinaryNode();
     71     ASSERT(binOp != NULL);
     72     ASSERT(binOp->getRight() != NULL);
     73     ASSERT(binOp->getRight()->getAsConstantUnion() != NULL);
     74 
     75     info.incrementValue = getLoopIncrement(node);
     76     info.stopValue = evaluateIntConstant(
     77         binOp->getRight()->getAsConstantUnion());
     78     info.op = binOp->getOp();
     79 }
     80 
     81 void ForLoopUnroll::Step()
     82 {
     83     ASSERT(mLoopIndexStack.size() > 0);
     84     TLoopIndexInfo& info = mLoopIndexStack[mLoopIndexStack.size() - 1];
     85     info.currentValue += info.incrementValue;
     86 }
     87 
     88 bool ForLoopUnroll::SatisfiesLoopCondition()
     89 {
     90     ASSERT(mLoopIndexStack.size() > 0);
     91     TLoopIndexInfo& info = mLoopIndexStack[mLoopIndexStack.size() - 1];
     92     // Relational operator is one of: > >= < <= == or !=.
     93     switch (info.op) {
     94       case EOpEqual:
     95         return (info.currentValue == info.stopValue);
     96       case EOpNotEqual:
     97         return (info.currentValue != info.stopValue);
     98       case EOpLessThan:
     99         return (info.currentValue < info.stopValue);
    100       case EOpGreaterThan:
    101         return (info.currentValue > info.stopValue);
    102       case EOpLessThanEqual:
    103         return (info.currentValue <= info.stopValue);
    104       case EOpGreaterThanEqual:
    105         return (info.currentValue >= info.stopValue);
    106       default:
    107         UNREACHABLE();
    108     }
    109     return false;
    110 }
    111 
    112 bool ForLoopUnroll::NeedsToReplaceSymbolWithValue(TIntermSymbol* symbol)
    113 {
    114     for (TVector<TLoopIndexInfo>::iterator i = mLoopIndexStack.begin();
    115          i != mLoopIndexStack.end();
    116          ++i) {
    117         if (i->id == symbol->getId())
    118             return true;
    119     }
    120     return false;
    121 }
    122 
    123 int ForLoopUnroll::GetLoopIndexValue(TIntermSymbol* symbol)
    124 {
    125     for (TVector<TLoopIndexInfo>::iterator i = mLoopIndexStack.begin();
    126          i != mLoopIndexStack.end();
    127          ++i) {
    128         if (i->id == symbol->getId())
    129             return i->currentValue;
    130     }
    131     UNREACHABLE();
    132     return false;
    133 }
    134 
    135 void ForLoopUnroll::Push(TLoopIndexInfo& info)
    136 {
    137     mLoopIndexStack.push_back(info);
    138 }
    139 
    140 void ForLoopUnroll::Pop()
    141 {
    142     mLoopIndexStack.pop_back();
    143 }
    144 
    145 // static
    146 void ForLoopUnroll::MarkForLoopsWithIntegerIndicesForUnrolling(
    147     TIntermNode* root)
    148 {
    149     ASSERT(root);
    150 
    151     IntegerForLoopUnrollMarker marker;
    152     root->traverse(&marker);
    153 }
    154 
    155 int ForLoopUnroll::getLoopIncrement(TIntermLoop* node)
    156 {
    157     TIntermNode* expr = node->getExpression();
    158     ASSERT(expr != NULL);
    159     // for expression has one of the following forms:
    160     //     loop_index++
    161     //     loop_index--
    162     //     loop_index += constant_expression
    163     //     loop_index -= constant_expression
    164     //     ++loop_index
    165     //     --loop_index
    166     // The last two forms are not specified in the spec, but I am assuming
    167     // its an oversight.
    168     TIntermUnary* unOp = expr->getAsUnaryNode();
    169     TIntermBinary* binOp = unOp ? NULL : expr->getAsBinaryNode();
    170 
    171     TOperator op = EOpNull;
    172     TIntermConstantUnion* incrementNode = NULL;
    173     if (unOp != NULL) {
    174         op = unOp->getOp();
    175     } else if (binOp != NULL) {
    176         op = binOp->getOp();
    177         ASSERT(binOp->getRight() != NULL);
    178         incrementNode = binOp->getRight()->getAsConstantUnion();
    179         ASSERT(incrementNode != NULL);
    180     }
    181 
    182     int increment = 0;
    183     // The operator is one of: ++ -- += -=.
    184     switch (op) {
    185         case EOpPostIncrement:
    186         case EOpPreIncrement:
    187             ASSERT((unOp != NULL) && (binOp == NULL));
    188             increment = 1;
    189             break;
    190         case EOpPostDecrement:
    191         case EOpPreDecrement:
    192             ASSERT((unOp != NULL) && (binOp == NULL));
    193             increment = -1;
    194             break;
    195         case EOpAddAssign:
    196             ASSERT((unOp == NULL) && (binOp != NULL));
    197             increment = evaluateIntConstant(incrementNode);
    198             break;
    199         case EOpSubAssign:
    200             ASSERT((unOp == NULL) && (binOp != NULL));
    201             increment = - evaluateIntConstant(incrementNode);
    202             break;
    203         default:
    204             ASSERT(false);
    205     }
    206 
    207     return increment;
    208 }
    209 
    210 int ForLoopUnroll::evaluateIntConstant(TIntermConstantUnion* node)
    211 {
    212     ASSERT((node != NULL) && (node->getUnionArrayPointer() != NULL));
    213     return node->getIConst(0);
    214 }
    215 
    216