1 // Copyright 2016 The SwiftShader Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #include "ValidateLimitations.h" 16 #include "InfoSink.h" 17 #include "InitializeParseContext.h" 18 #include "ParseHelper.h" 19 20 namespace { 21 bool IsLoopIndex(const TIntermSymbol* symbol, const TLoopStack& stack) { 22 for (TLoopStack::const_iterator i = stack.begin(); i != stack.end(); ++i) { 23 if (i->index.id == symbol->getId()) 24 return true; 25 } 26 return false; 27 } 28 29 void MarkLoopForUnroll(const TIntermSymbol* symbol, TLoopStack& stack) { 30 for (TLoopStack::iterator i = stack.begin(); i != stack.end(); ++i) { 31 if (i->index.id == symbol->getId()) { 32 ASSERT(i->loop); 33 i->loop->setUnrollFlag(true); 34 return; 35 } 36 } 37 UNREACHABLE(0); 38 } 39 40 // Traverses a node to check if it represents a constant index expression. 41 // Definition: 42 // constant-index-expressions are a superset of constant-expressions. 43 // Constant-index-expressions can include loop indices as defined in 44 // GLSL ES 1.0 spec, Appendix A, section 4. 45 // The following are constant-index-expressions: 46 // - Constant expressions 47 // - Loop indices as defined in section 4 48 // - Expressions composed of both of the above 49 class ValidateConstIndexExpr : public TIntermTraverser { 50 public: 51 ValidateConstIndexExpr(const TLoopStack& stack) 52 : mValid(true), mLoopStack(stack) {} 53 54 // Returns true if the parsed node represents a constant index expression. 55 bool isValid() const { return mValid; } 56 57 virtual void visitSymbol(TIntermSymbol* symbol) { 58 // Only constants and loop indices are allowed in a 59 // constant index expression. 60 if (mValid) { 61 mValid = (symbol->getQualifier() == EvqConstExpr) || 62 IsLoopIndex(symbol, mLoopStack); 63 } 64 } 65 66 private: 67 bool mValid; 68 const TLoopStack& mLoopStack; 69 }; 70 71 // Traverses a node to check if it uses a loop index. 72 // If an int loop index is used in its body as a sampler array index, 73 // mark the loop for unroll. 74 class ValidateLoopIndexExpr : public TIntermTraverser { 75 public: 76 ValidateLoopIndexExpr(TLoopStack& stack) 77 : mUsesFloatLoopIndex(false), 78 mUsesIntLoopIndex(false), 79 mLoopStack(stack) {} 80 81 bool usesFloatLoopIndex() const { return mUsesFloatLoopIndex; } 82 bool usesIntLoopIndex() const { return mUsesIntLoopIndex; } 83 84 virtual void visitSymbol(TIntermSymbol* symbol) { 85 if (IsLoopIndex(symbol, mLoopStack)) { 86 switch (symbol->getBasicType()) { 87 case EbtFloat: 88 mUsesFloatLoopIndex = true; 89 break; 90 case EbtUInt: 91 mUsesIntLoopIndex = true; 92 MarkLoopForUnroll(symbol, mLoopStack); 93 break; 94 case EbtInt: 95 mUsesIntLoopIndex = true; 96 MarkLoopForUnroll(symbol, mLoopStack); 97 break; 98 default: 99 UNREACHABLE(symbol->getBasicType()); 100 } 101 } 102 } 103 104 private: 105 bool mUsesFloatLoopIndex; 106 bool mUsesIntLoopIndex; 107 TLoopStack& mLoopStack; 108 }; 109 } // namespace 110 111 ValidateLimitations::ValidateLimitations(GLenum shaderType, 112 TInfoSinkBase& sink) 113 : mShaderType(shaderType), 114 mSink(sink), 115 mNumErrors(0) 116 { 117 } 118 119 bool ValidateLimitations::visitBinary(Visit, TIntermBinary* node) 120 { 121 // Check if loop index is modified in the loop body. 122 validateOperation(node, node->getLeft()); 123 124 // Check indexing. 125 switch (node->getOp()) { 126 case EOpIndexDirect: 127 validateIndexing(node); 128 break; 129 case EOpIndexIndirect: 130 validateIndexing(node); 131 break; 132 default: break; 133 } 134 return true; 135 } 136 137 bool ValidateLimitations::visitUnary(Visit, TIntermUnary* node) 138 { 139 // Check if loop index is modified in the loop body. 140 validateOperation(node, node->getOperand()); 141 142 return true; 143 } 144 145 bool ValidateLimitations::visitAggregate(Visit, TIntermAggregate* node) 146 { 147 switch (node->getOp()) { 148 case EOpFunctionCall: 149 validateFunctionCall(node); 150 break; 151 default: 152 break; 153 } 154 return true; 155 } 156 157 bool ValidateLimitations::visitLoop(Visit, TIntermLoop* node) 158 { 159 if (!validateLoopType(node)) 160 return false; 161 162 TLoopInfo info; 163 memset(&info, 0, sizeof(TLoopInfo)); 164 info.loop = node; 165 if (!validateForLoopHeader(node, &info)) 166 return false; 167 168 TIntermNode* body = node->getBody(); 169 if (body) { 170 mLoopStack.push_back(info); 171 body->traverse(this); 172 mLoopStack.pop_back(); 173 } 174 175 // The loop is fully processed - no need to visit children. 176 return false; 177 } 178 179 void ValidateLimitations::error(TSourceLoc loc, 180 const char *reason, const char* token) 181 { 182 mSink.prefix(EPrefixError); 183 mSink.location(loc); 184 mSink << "'" << token << "' : " << reason << "\n"; 185 ++mNumErrors; 186 } 187 188 bool ValidateLimitations::withinLoopBody() const 189 { 190 return !mLoopStack.empty(); 191 } 192 193 bool ValidateLimitations::isLoopIndex(const TIntermSymbol* symbol) const 194 { 195 return IsLoopIndex(symbol, mLoopStack); 196 } 197 198 bool ValidateLimitations::validateLoopType(TIntermLoop* node) { 199 TLoopType type = node->getType(); 200 if (type == ELoopFor) 201 return true; 202 203 // Reject while and do-while loops. 204 error(node->getLine(), 205 "This type of loop is not allowed", 206 type == ELoopWhile ? "while" : "do"); 207 return false; 208 } 209 210 bool ValidateLimitations::validateForLoopHeader(TIntermLoop* node, 211 TLoopInfo* info) 212 { 213 ASSERT(node->getType() == ELoopFor); 214 215 // 216 // The for statement has the form: 217 // for ( init-declaration ; condition ; expression ) statement 218 // 219 if (!validateForLoopInit(node, info)) 220 return false; 221 if (!validateForLoopCond(node, info)) 222 return false; 223 if (!validateForLoopExpr(node, info)) 224 return false; 225 226 return true; 227 } 228 229 bool ValidateLimitations::validateForLoopInit(TIntermLoop* node, 230 TLoopInfo* info) 231 { 232 TIntermNode* init = node->getInit(); 233 if (!init) { 234 error(node->getLine(), "Missing init declaration", "for"); 235 return false; 236 } 237 238 // 239 // init-declaration has the form: 240 // type-specifier identifier = constant-expression 241 // 242 TIntermAggregate* decl = init->getAsAggregate(); 243 if (!decl || (decl->getOp() != EOpDeclaration)) { 244 error(init->getLine(), "Invalid init declaration", "for"); 245 return false; 246 } 247 // To keep things simple do not allow declaration list. 248 TIntermSequence& declSeq = decl->getSequence(); 249 if (declSeq.size() != 1) { 250 error(decl->getLine(), "Invalid init declaration", "for"); 251 return false; 252 } 253 TIntermBinary* declInit = declSeq[0]->getAsBinaryNode(); 254 if (!declInit || (declInit->getOp() != EOpInitialize)) { 255 error(decl->getLine(), "Invalid init declaration", "for"); 256 return false; 257 } 258 TIntermSymbol* symbol = declInit->getLeft()->getAsSymbolNode(); 259 if (!symbol) { 260 error(declInit->getLine(), "Invalid init declaration", "for"); 261 return false; 262 } 263 // The loop index has type int or float. 264 TBasicType type = symbol->getBasicType(); 265 if (!IsInteger(type) && (type != EbtFloat)) { 266 error(symbol->getLine(), 267 "Invalid type for loop index", getBasicString(type)); 268 return false; 269 } 270 // The loop index is initialized with constant expression. 271 if (!isConstExpr(declInit->getRight())) { 272 error(declInit->getLine(), 273 "Loop index cannot be initialized with non-constant expression", 274 symbol->getSymbol().c_str()); 275 return false; 276 } 277 278 info->index.id = symbol->getId(); 279 return true; 280 } 281 282 bool ValidateLimitations::validateForLoopCond(TIntermLoop* node, 283 TLoopInfo* info) 284 { 285 TIntermNode* cond = node->getCondition(); 286 if (!cond) { 287 error(node->getLine(), "Missing condition", "for"); 288 return false; 289 } 290 // 291 // condition has the form: 292 // loop_index relational_operator constant_expression 293 // 294 TIntermBinary* binOp = cond->getAsBinaryNode(); 295 if (!binOp) { 296 error(node->getLine(), "Invalid condition", "for"); 297 return false; 298 } 299 // Loop index should be to the left of relational operator. 300 TIntermSymbol* symbol = binOp->getLeft()->getAsSymbolNode(); 301 if (!symbol) { 302 error(binOp->getLine(), "Invalid condition", "for"); 303 return false; 304 } 305 if (symbol->getId() != info->index.id) { 306 error(symbol->getLine(), 307 "Expected loop index", symbol->getSymbol().c_str()); 308 return false; 309 } 310 // Relational operator is one of: > >= < <= == or !=. 311 switch (binOp->getOp()) { 312 case EOpEqual: 313 case EOpNotEqual: 314 case EOpLessThan: 315 case EOpGreaterThan: 316 case EOpLessThanEqual: 317 case EOpGreaterThanEqual: 318 break; 319 default: 320 error(binOp->getLine(), 321 "Invalid relational operator", 322 getOperatorString(binOp->getOp())); 323 break; 324 } 325 // Loop index must be compared with a constant. 326 if (!isConstExpr(binOp->getRight())) { 327 error(binOp->getLine(), 328 "Loop index cannot be compared with non-constant expression", 329 symbol->getSymbol().c_str()); 330 return false; 331 } 332 333 return true; 334 } 335 336 bool ValidateLimitations::validateForLoopExpr(TIntermLoop* node, 337 TLoopInfo* info) 338 { 339 TIntermNode* expr = node->getExpression(); 340 if (!expr) { 341 error(node->getLine(), "Missing expression", "for"); 342 return false; 343 } 344 345 // for expression has one of the following forms: 346 // loop_index++ 347 // loop_index-- 348 // loop_index += constant_expression 349 // loop_index -= constant_expression 350 // ++loop_index 351 // --loop_index 352 // The last two forms are not specified in the spec, but I am assuming 353 // its an oversight. 354 TIntermUnary* unOp = expr->getAsUnaryNode(); 355 TIntermBinary* binOp = unOp ? nullptr : expr->getAsBinaryNode(); 356 357 TOperator op = EOpNull; 358 TIntermSymbol* symbol = nullptr; 359 if (unOp) { 360 op = unOp->getOp(); 361 symbol = unOp->getOperand()->getAsSymbolNode(); 362 } else if (binOp) { 363 op = binOp->getOp(); 364 symbol = binOp->getLeft()->getAsSymbolNode(); 365 } 366 367 // The operand must be loop index. 368 if (!symbol) { 369 error(expr->getLine(), "Invalid expression", "for"); 370 return false; 371 } 372 if (symbol->getId() != info->index.id) { 373 error(symbol->getLine(), 374 "Expected loop index", symbol->getSymbol().c_str()); 375 return false; 376 } 377 378 // The operator is one of: ++ -- += -=. 379 switch (op) { 380 case EOpPostIncrement: 381 case EOpPostDecrement: 382 case EOpPreIncrement: 383 case EOpPreDecrement: 384 ASSERT((unOp != NULL) && (binOp == NULL)); 385 break; 386 case EOpAddAssign: 387 case EOpSubAssign: 388 ASSERT((unOp == NULL) && (binOp != NULL)); 389 break; 390 default: 391 error(expr->getLine(), "Invalid operator", getOperatorString(op)); 392 return false; 393 } 394 395 // Loop index must be incremented/decremented with a constant. 396 if (binOp != NULL) { 397 if (!isConstExpr(binOp->getRight())) { 398 error(binOp->getLine(), 399 "Loop index cannot be modified by non-constant expression", 400 symbol->getSymbol().c_str()); 401 return false; 402 } 403 } 404 405 return true; 406 } 407 408 bool ValidateLimitations::validateFunctionCall(TIntermAggregate* node) 409 { 410 ASSERT(node->getOp() == EOpFunctionCall); 411 412 // If not within loop body, there is nothing to check. 413 if (!withinLoopBody()) 414 return true; 415 416 // List of param indices for which loop indices are used as argument. 417 typedef std::vector<int> ParamIndex; 418 ParamIndex pIndex; 419 TIntermSequence& params = node->getSequence(); 420 for (TIntermSequence::size_type i = 0; i < params.size(); ++i) { 421 TIntermSymbol* symbol = params[i]->getAsSymbolNode(); 422 if (symbol && isLoopIndex(symbol)) 423 pIndex.push_back(i); 424 } 425 // If none of the loop indices are used as arguments, 426 // there is nothing to check. 427 if (pIndex.empty()) 428 return true; 429 430 bool valid = true; 431 TSymbolTable& symbolTable = GetGlobalParseContext()->symbolTable; 432 TSymbol* symbol = symbolTable.find(node->getName(), GetGlobalParseContext()->getShaderVersion()); 433 ASSERT(symbol && symbol->isFunction()); 434 TFunction* function = static_cast<TFunction*>(symbol); 435 for (ParamIndex::const_iterator i = pIndex.begin(); 436 i != pIndex.end(); ++i) { 437 const TParameter& param = function->getParam(*i); 438 TQualifier qual = param.type->getQualifier(); 439 if ((qual == EvqOut) || (qual == EvqInOut)) { 440 error(params[*i]->getLine(), 441 "Loop index cannot be used as argument to a function out or inout parameter", 442 params[*i]->getAsSymbolNode()->getSymbol().c_str()); 443 valid = false; 444 } 445 } 446 447 return valid; 448 } 449 450 bool ValidateLimitations::validateOperation(TIntermOperator* node, 451 TIntermNode* operand) { 452 // Check if loop index is modified in the loop body. 453 if (!withinLoopBody() || !node->modifiesState()) 454 return true; 455 456 const TIntermSymbol* symbol = operand->getAsSymbolNode(); 457 if (symbol && isLoopIndex(symbol)) { 458 error(node->getLine(), 459 "Loop index cannot be statically assigned to within the body of the loop", 460 symbol->getSymbol().c_str()); 461 } 462 return true; 463 } 464 465 bool ValidateLimitations::isConstExpr(TIntermNode* node) 466 { 467 ASSERT(node); 468 return node->getAsConstantUnion() != nullptr; 469 } 470 471 bool ValidateLimitations::isConstIndexExpr(TIntermNode* node) 472 { 473 ASSERT(node); 474 475 ValidateConstIndexExpr validate(mLoopStack); 476 node->traverse(&validate); 477 return validate.isValid(); 478 } 479 480 bool ValidateLimitations::validateIndexing(TIntermBinary* node) 481 { 482 ASSERT((node->getOp() == EOpIndexDirect) || 483 (node->getOp() == EOpIndexIndirect)); 484 485 bool valid = true; 486 TIntermTyped* index = node->getRight(); 487 // The index expression must have integral type. 488 if (!index->isScalarInt()) { 489 error(index->getLine(), 490 "Index expression must have integral type", 491 index->getCompleteString().c_str()); 492 valid = false; 493 } 494 // The index expession must be a constant-index-expression unless 495 // the operand is a uniform in a vertex shader. 496 TIntermTyped* operand = node->getLeft(); 497 bool skip = (mShaderType == GL_VERTEX_SHADER) && 498 (operand->getQualifier() == EvqUniform); 499 if (!skip && !isConstIndexExpr(index)) { 500 error(index->getLine(), "Index expression must be constant", "[]"); 501 valid = false; 502 } 503 return valid; 504 } 505 506