Home | History | Annotate | Download | only in compiler
      1 //
      2 // Copyright (c) 2002-2010 The ANGLE Project Authors. All rights reserved.
      3 // Use of this source code is governed by a BSD-style license that can be
      4 // found in the LICENSE file.
      5 //
      6 
      7 #include "compiler/ValidateLimitations.h"
      8 #include "compiler/InfoSink.h"
      9 #include "compiler/InitializeParseContext.h"
     10 #include "compiler/ParseContext.h"
     11 
     12 namespace {
     13 bool IsLoopIndex(const TIntermSymbol* symbol, const TLoopStack& stack) {
     14     for (TLoopStack::const_iterator i = stack.begin(); i != stack.end(); ++i) {
     15         if (i->index.id == symbol->getId())
     16             return true;
     17     }
     18     return false;
     19 }
     20 
     21 void MarkLoopForUnroll(const TIntermSymbol* symbol, TLoopStack& stack) {
     22     for (TLoopStack::iterator i = stack.begin(); i != stack.end(); ++i) {
     23         if (i->index.id == symbol->getId()) {
     24             ASSERT(i->loop != NULL);
     25             i->loop->setUnrollFlag(true);
     26             return;
     27         }
     28     }
     29     UNREACHABLE();
     30 }
     31 
     32 // Traverses a node to check if it represents a constant index expression.
     33 // Definition:
     34 // constant-index-expressions are a superset of constant-expressions.
     35 // Constant-index-expressions can include loop indices as defined in
     36 // GLSL ES 1.0 spec, Appendix A, section 4.
     37 // The following are constant-index-expressions:
     38 // - Constant expressions
     39 // - Loop indices as defined in section 4
     40 // - Expressions composed of both of the above
     41 class ValidateConstIndexExpr : public TIntermTraverser {
     42 public:
     43     ValidateConstIndexExpr(const TLoopStack& stack)
     44         : mValid(true), mLoopStack(stack) {}
     45 
     46     // Returns true if the parsed node represents a constant index expression.
     47     bool isValid() const { return mValid; }
     48 
     49     virtual void visitSymbol(TIntermSymbol* symbol) {
     50         // Only constants and loop indices are allowed in a
     51         // constant index expression.
     52         if (mValid) {
     53             mValid = (symbol->getQualifier() == EvqConst) ||
     54                      IsLoopIndex(symbol, mLoopStack);
     55         }
     56     }
     57 
     58 private:
     59     bool mValid;
     60     const TLoopStack& mLoopStack;
     61 };
     62 
     63 // Traverses a node to check if it uses a loop index.
     64 // If an int loop index is used in its body as a sampler array index,
     65 // mark the loop for unroll.
     66 class ValidateLoopIndexExpr : public TIntermTraverser {
     67 public:
     68     ValidateLoopIndexExpr(TLoopStack& stack)
     69         : mUsesFloatLoopIndex(false),
     70           mUsesIntLoopIndex(false),
     71           mLoopStack(stack) {}
     72 
     73     bool usesFloatLoopIndex() const { return mUsesFloatLoopIndex; }
     74     bool usesIntLoopIndex() const { return mUsesIntLoopIndex; }
     75 
     76     virtual void visitSymbol(TIntermSymbol* symbol) {
     77         if (IsLoopIndex(symbol, mLoopStack)) {
     78             switch (symbol->getBasicType()) {
     79               case EbtFloat:
     80                 mUsesFloatLoopIndex = true;
     81                 break;
     82               case EbtInt:
     83                 mUsesIntLoopIndex = true;
     84                 MarkLoopForUnroll(symbol, mLoopStack);
     85                 break;
     86               default:
     87                 UNREACHABLE();
     88             }
     89         }
     90     }
     91 
     92 private:
     93     bool mUsesFloatLoopIndex;
     94     bool mUsesIntLoopIndex;
     95     TLoopStack& mLoopStack;
     96 };
     97 }  // namespace
     98 
     99 ValidateLimitations::ValidateLimitations(ShShaderType shaderType,
    100                                          TInfoSinkBase& sink)
    101     : mShaderType(shaderType),
    102       mSink(sink),
    103       mNumErrors(0)
    104 {
    105 }
    106 
    107 bool ValidateLimitations::visitBinary(Visit, TIntermBinary* node)
    108 {
    109     // Check if loop index is modified in the loop body.
    110     validateOperation(node, node->getLeft());
    111 
    112     // Check indexing.
    113     switch (node->getOp()) {
    114       case EOpIndexDirect:
    115         validateIndexing(node);
    116         break;
    117       case EOpIndexIndirect:
    118 #if defined(__APPLE__)
    119         // Loop unrolling is a work-around for a Mac Cg compiler bug where it
    120         // crashes when a sampler array's index is also the loop index.
    121         // Once Apple fixes this bug, we should remove the code in this CL.
    122         // See http://codereview.appspot.com/4331048/.
    123         if ((node->getLeft() != NULL) && (node->getRight() != NULL) &&
    124             (node->getLeft()->getAsSymbolNode())) {
    125             TIntermSymbol* symbol = node->getLeft()->getAsSymbolNode();
    126             if (IsSampler(symbol->getBasicType()) && symbol->isArray()) {
    127                 ValidateLoopIndexExpr validate(mLoopStack);
    128                 node->getRight()->traverse(&validate);
    129                 if (validate.usesFloatLoopIndex()) {
    130                     error(node->getLine(),
    131                           "sampler array index is float loop index",
    132                           "for");
    133                 }
    134             }
    135         }
    136 #endif
    137         validateIndexing(node);
    138         break;
    139       default: break;
    140     }
    141     return true;
    142 }
    143 
    144 bool ValidateLimitations::visitUnary(Visit, TIntermUnary* node)
    145 {
    146     // Check if loop index is modified in the loop body.
    147     validateOperation(node, node->getOperand());
    148 
    149     return true;
    150 }
    151 
    152 bool ValidateLimitations::visitAggregate(Visit, TIntermAggregate* node)
    153 {
    154     switch (node->getOp()) {
    155       case EOpFunctionCall:
    156         validateFunctionCall(node);
    157         break;
    158       default:
    159         break;
    160     }
    161     return true;
    162 }
    163 
    164 bool ValidateLimitations::visitLoop(Visit, TIntermLoop* node)
    165 {
    166     if (!validateLoopType(node))
    167         return false;
    168 
    169     TLoopInfo info;
    170     memset(&info, 0, sizeof(TLoopInfo));
    171     info.loop = node;
    172     if (!validateForLoopHeader(node, &info))
    173         return false;
    174 
    175     TIntermNode* body = node->getBody();
    176     if (body != NULL) {
    177         mLoopStack.push_back(info);
    178         body->traverse(this);
    179         mLoopStack.pop_back();
    180     }
    181 
    182     // The loop is fully processed - no need to visit children.
    183     return false;
    184 }
    185 
    186 void ValidateLimitations::error(TSourceLoc loc,
    187                                 const char *reason, const char* token)
    188 {
    189     mSink.prefix(EPrefixError);
    190     mSink.location(loc);
    191     mSink << "'" << token << "' : " << reason << "\n";
    192     ++mNumErrors;
    193 }
    194 
    195 bool ValidateLimitations::withinLoopBody() const
    196 {
    197     return !mLoopStack.empty();
    198 }
    199 
    200 bool ValidateLimitations::isLoopIndex(const TIntermSymbol* symbol) const
    201 {
    202     return IsLoopIndex(symbol, mLoopStack);
    203 }
    204 
    205 bool ValidateLimitations::validateLoopType(TIntermLoop* node) {
    206     TLoopType type = node->getType();
    207     if (type == ELoopFor)
    208         return true;
    209 
    210     // Reject while and do-while loops.
    211     error(node->getLine(),
    212           "This type of loop is not allowed",
    213           type == ELoopWhile ? "while" : "do");
    214     return false;
    215 }
    216 
    217 bool ValidateLimitations::validateForLoopHeader(TIntermLoop* node,
    218                                                 TLoopInfo* info)
    219 {
    220     ASSERT(node->getType() == ELoopFor);
    221 
    222     //
    223     // The for statement has the form:
    224     //    for ( init-declaration ; condition ; expression ) statement
    225     //
    226     if (!validateForLoopInit(node, info))
    227         return false;
    228     if (!validateForLoopCond(node, info))
    229         return false;
    230     if (!validateForLoopExpr(node, info))
    231         return false;
    232 
    233     return true;
    234 }
    235 
    236 bool ValidateLimitations::validateForLoopInit(TIntermLoop* node,
    237                                               TLoopInfo* info)
    238 {
    239     TIntermNode* init = node->getInit();
    240     if (init == NULL) {
    241         error(node->getLine(), "Missing init declaration", "for");
    242         return false;
    243     }
    244 
    245     //
    246     // init-declaration has the form:
    247     //     type-specifier identifier = constant-expression
    248     //
    249     TIntermAggregate* decl = init->getAsAggregate();
    250     if ((decl == NULL) || (decl->getOp() != EOpDeclaration)) {
    251         error(init->getLine(), "Invalid init declaration", "for");
    252         return false;
    253     }
    254     // To keep things simple do not allow declaration list.
    255     TIntermSequence& declSeq = decl->getSequence();
    256     if (declSeq.size() != 1) {
    257         error(decl->getLine(), "Invalid init declaration", "for");
    258         return false;
    259     }
    260     TIntermBinary* declInit = declSeq[0]->getAsBinaryNode();
    261     if ((declInit == NULL) || (declInit->getOp() != EOpInitialize)) {
    262         error(decl->getLine(), "Invalid init declaration", "for");
    263         return false;
    264     }
    265     TIntermSymbol* symbol = declInit->getLeft()->getAsSymbolNode();
    266     if (symbol == NULL) {
    267         error(declInit->getLine(), "Invalid init declaration", "for");
    268         return false;
    269     }
    270     // The loop index has type int or float.
    271     TBasicType type = symbol->getBasicType();
    272     if ((type != EbtInt) && (type != EbtFloat)) {
    273         error(symbol->getLine(),
    274               "Invalid type for loop index", getBasicString(type));
    275         return false;
    276     }
    277     // The loop index is initialized with constant expression.
    278     if (!isConstExpr(declInit->getRight())) {
    279         error(declInit->getLine(),
    280               "Loop index cannot be initialized with non-constant expression",
    281               symbol->getSymbol().c_str());
    282         return false;
    283     }
    284 
    285     info->index.id = symbol->getId();
    286     return true;
    287 }
    288 
    289 bool ValidateLimitations::validateForLoopCond(TIntermLoop* node,
    290                                               TLoopInfo* info)
    291 {
    292     TIntermNode* cond = node->getCondition();
    293     if (cond == NULL) {
    294         error(node->getLine(), "Missing condition", "for");
    295         return false;
    296     }
    297     //
    298     // condition has the form:
    299     //     loop_index relational_operator constant_expression
    300     //
    301     TIntermBinary* binOp = cond->getAsBinaryNode();
    302     if (binOp == NULL) {
    303         error(node->getLine(), "Invalid condition", "for");
    304         return false;
    305     }
    306     // Loop index should be to the left of relational operator.
    307     TIntermSymbol* symbol = binOp->getLeft()->getAsSymbolNode();
    308     if (symbol == NULL) {
    309         error(binOp->getLine(), "Invalid condition", "for");
    310         return false;
    311     }
    312     if (symbol->getId() != info->index.id) {
    313         error(symbol->getLine(),
    314               "Expected loop index", symbol->getSymbol().c_str());
    315         return false;
    316     }
    317     // Relational operator is one of: > >= < <= == or !=.
    318     switch (binOp->getOp()) {
    319       case EOpEqual:
    320       case EOpNotEqual:
    321       case EOpLessThan:
    322       case EOpGreaterThan:
    323       case EOpLessThanEqual:
    324       case EOpGreaterThanEqual:
    325         break;
    326       default:
    327         error(binOp->getLine(),
    328               "Invalid relational operator",
    329               getOperatorString(binOp->getOp()));
    330         break;
    331     }
    332     // Loop index must be compared with a constant.
    333     if (!isConstExpr(binOp->getRight())) {
    334         error(binOp->getLine(),
    335               "Loop index cannot be compared with non-constant expression",
    336               symbol->getSymbol().c_str());
    337         return false;
    338     }
    339 
    340     return true;
    341 }
    342 
    343 bool ValidateLimitations::validateForLoopExpr(TIntermLoop* node,
    344                                               TLoopInfo* info)
    345 {
    346     TIntermNode* expr = node->getExpression();
    347     if (expr == NULL) {
    348         error(node->getLine(), "Missing expression", "for");
    349         return false;
    350     }
    351 
    352     // for expression has one of the following forms:
    353     //     loop_index++
    354     //     loop_index--
    355     //     loop_index += constant_expression
    356     //     loop_index -= constant_expression
    357     //     ++loop_index
    358     //     --loop_index
    359     // The last two forms are not specified in the spec, but I am assuming
    360     // its an oversight.
    361     TIntermUnary* unOp = expr->getAsUnaryNode();
    362     TIntermBinary* binOp = unOp ? NULL : expr->getAsBinaryNode();
    363 
    364     TOperator op = EOpNull;
    365     TIntermSymbol* symbol = NULL;
    366     if (unOp != NULL) {
    367         op = unOp->getOp();
    368         symbol = unOp->getOperand()->getAsSymbolNode();
    369     } else if (binOp != NULL) {
    370         op = binOp->getOp();
    371         symbol = binOp->getLeft()->getAsSymbolNode();
    372     }
    373 
    374     // The operand must be loop index.
    375     if (symbol == NULL) {
    376         error(expr->getLine(), "Invalid expression", "for");
    377         return false;
    378     }
    379     if (symbol->getId() != info->index.id) {
    380         error(symbol->getLine(),
    381               "Expected loop index", symbol->getSymbol().c_str());
    382         return false;
    383     }
    384 
    385     // The operator is one of: ++ -- += -=.
    386     switch (op) {
    387         case EOpPostIncrement:
    388         case EOpPostDecrement:
    389         case EOpPreIncrement:
    390         case EOpPreDecrement:
    391             ASSERT((unOp != NULL) && (binOp == NULL));
    392             break;
    393         case EOpAddAssign:
    394         case EOpSubAssign:
    395             ASSERT((unOp == NULL) && (binOp != NULL));
    396             break;
    397         default:
    398             error(expr->getLine(), "Invalid operator", getOperatorString(op));
    399             return false;
    400     }
    401 
    402     // Loop index must be incremented/decremented with a constant.
    403     if (binOp != NULL) {
    404         if (!isConstExpr(binOp->getRight())) {
    405             error(binOp->getLine(),
    406                   "Loop index cannot be modified by non-constant expression",
    407                   symbol->getSymbol().c_str());
    408             return false;
    409         }
    410     }
    411 
    412     return true;
    413 }
    414 
    415 bool ValidateLimitations::validateFunctionCall(TIntermAggregate* node)
    416 {
    417     ASSERT(node->getOp() == EOpFunctionCall);
    418 
    419     // If not within loop body, there is nothing to check.
    420     if (!withinLoopBody())
    421         return true;
    422 
    423     // List of param indices for which loop indices are used as argument.
    424     typedef std::vector<size_t> ParamIndex;
    425     ParamIndex pIndex;
    426     TIntermSequence& params = node->getSequence();
    427     for (TIntermSequence::size_type i = 0; i < params.size(); ++i) {
    428         TIntermSymbol* symbol = params[i]->getAsSymbolNode();
    429         if (symbol && isLoopIndex(symbol))
    430             pIndex.push_back(i);
    431     }
    432     // If none of the loop indices are used as arguments,
    433     // there is nothing to check.
    434     if (pIndex.empty())
    435         return true;
    436 
    437     bool valid = true;
    438     TSymbolTable& symbolTable = GetGlobalParseContext()->symbolTable;
    439     TSymbol* symbol = symbolTable.find(node->getName());
    440     ASSERT(symbol && symbol->isFunction());
    441     TFunction* function = static_cast<TFunction*>(symbol);
    442     for (ParamIndex::const_iterator i = pIndex.begin();
    443          i != pIndex.end(); ++i) {
    444         const TParameter& param = function->getParam(*i);
    445         TQualifier qual = param.type->getQualifier();
    446         if ((qual == EvqOut) || (qual == EvqInOut)) {
    447             error(params[*i]->getLine(),
    448                   "Loop index cannot be used as argument to a function out or inout parameter",
    449                   params[*i]->getAsSymbolNode()->getSymbol().c_str());
    450             valid = false;
    451         }
    452     }
    453 
    454     return valid;
    455 }
    456 
    457 bool ValidateLimitations::validateOperation(TIntermOperator* node,
    458                                             TIntermNode* operand) {
    459     // Check if loop index is modified in the loop body.
    460     if (!withinLoopBody() || !node->isAssignment())
    461         return true;
    462 
    463     const TIntermSymbol* symbol = operand->getAsSymbolNode();
    464     if (symbol && isLoopIndex(symbol)) {
    465         error(node->getLine(),
    466               "Loop index cannot be statically assigned to within the body of the loop",
    467               symbol->getSymbol().c_str());
    468     }
    469     return true;
    470 }
    471 
    472 bool ValidateLimitations::isConstExpr(TIntermNode* node)
    473 {
    474     ASSERT(node != NULL);
    475     return node->getAsConstantUnion() != NULL;
    476 }
    477 
    478 bool ValidateLimitations::isConstIndexExpr(TIntermNode* node)
    479 {
    480     ASSERT(node != NULL);
    481 
    482     ValidateConstIndexExpr validate(mLoopStack);
    483     node->traverse(&validate);
    484     return validate.isValid();
    485 }
    486 
    487 bool ValidateLimitations::validateIndexing(TIntermBinary* node)
    488 {
    489     ASSERT((node->getOp() == EOpIndexDirect) ||
    490            (node->getOp() == EOpIndexIndirect));
    491 
    492     bool valid = true;
    493     TIntermTyped* index = node->getRight();
    494     // The index expression must have integral type.
    495     if (!index->isScalar() || (index->getBasicType() != EbtInt)) {
    496         error(index->getLine(),
    497               "Index expression must have integral type",
    498               index->getCompleteString().c_str());
    499         valid = false;
    500     }
    501     // The index expession must be a constant-index-expression unless
    502     // the operand is a uniform in a vertex shader.
    503     TIntermTyped* operand = node->getLeft();
    504     bool skip = (mShaderType == SH_VERTEX_SHADER) &&
    505                 (operand->getQualifier() == EvqUniform);
    506     if (!skip && !isConstIndexExpr(index)) {
    507         error(index->getLine(), "Index expression must be constant", "[]");
    508         valid = false;
    509     }
    510     return valid;
    511 }
    512 
    513