//
// Copyright (C) 2002-2005  3Dlabs Inc. Ltd.
// Copyright (C) 2012-2013 LunarG, Inc.
//
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
//
//    Redistributions of source code must retain the above copyright
//    notice, this list of conditions and the following disclaimer.
//
//    Redistributions in binary form must reproduce the above
//    copyright notice, this list of conditions and the following
//    disclaimer in the documentation and/or other materials provided
//    with the distribution.
//
//    Neither the name of 3Dlabs Inc. Ltd. nor the names of its
//    contributors may be used to endorse or promote products derived
//    from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
// FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
// COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
// INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
// BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
// LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
// ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
35 // 36 37 #include "localintermediate.h" 38 #include <cmath> 39 #include <cfloat> 40 #include <cstdlib> 41 42 namespace { 43 44 using namespace glslang; 45 46 typedef union { 47 double d; 48 int i[2]; 49 } DoubleIntUnion; 50 51 // Some helper functions 52 53 bool isNan(double x) 54 { 55 DoubleIntUnion u; 56 // tough to find a platform independent library function, do it directly 57 u.d = x; 58 int bitPatternL = u.i[0]; 59 int bitPatternH = u.i[1]; 60 return (bitPatternH & 0x7ff80000) == 0x7ff80000 && 61 ((bitPatternH & 0xFFFFF) != 0 || bitPatternL != 0); 62 } 63 64 bool isInf(double x) 65 { 66 DoubleIntUnion u; 67 // tough to find a platform independent library function, do it directly 68 u.d = x; 69 int bitPatternL = u.i[0]; 70 int bitPatternH = u.i[1]; 71 return (bitPatternH & 0x7ff00000) == 0x7ff00000 && 72 (bitPatternH & 0xFFFFF) == 0 && bitPatternL == 0; 73 } 74 75 const double pi = 3.1415926535897932384626433832795; 76 77 } // end anonymous namespace 78 79 80 namespace glslang { 81 82 // 83 // The fold functions see if an operation on a constant can be done in place, 84 // without generating run-time code. 85 // 86 // Returns the node to keep using, which may or may not be the node passed in. 87 // 88 // Note: As of version 1.2, all constant operations must be folded. It is 89 // not opportunistic, but rather a semantic requirement. 90 // 91 92 // 93 // Do folding between a pair of nodes. 94 // 'this' is the left-hand operand and 'rightConstantNode' is the right-hand operand. 95 // 96 // Returns a new node representing the result. 97 // 98 TIntermTyped* TIntermConstantUnion::fold(TOperator op, const TIntermTyped* rightConstantNode) const 99 { 100 // For most cases, the return type matches the argument type, so set that 101 // up and just code to exceptions below. 
102 TType returnType; 103 returnType.shallowCopy(getType()); 104 105 // 106 // A pair of nodes is to be folded together 107 // 108 109 const TIntermConstantUnion *rightNode = rightConstantNode->getAsConstantUnion(); 110 TConstUnionArray leftUnionArray = getConstArray(); 111 TConstUnionArray rightUnionArray = rightNode->getConstArray(); 112 113 // Figure out the size of the result 114 int newComps; 115 int constComps; 116 switch(op) { 117 case EOpMatrixTimesMatrix: 118 newComps = rightNode->getMatrixCols() * getMatrixRows(); 119 break; 120 case EOpMatrixTimesVector: 121 newComps = getMatrixRows(); 122 break; 123 case EOpVectorTimesMatrix: 124 newComps = rightNode->getMatrixCols(); 125 break; 126 default: 127 newComps = getType().computeNumComponents(); 128 constComps = rightConstantNode->getType().computeNumComponents(); 129 if (constComps == 1 && newComps > 1) { 130 // for a case like vec4 f = vec4(2,3,4,5) + 1.2; 131 TConstUnionArray smearedArray(newComps, rightNode->getConstArray()[0]); 132 rightUnionArray = smearedArray; 133 } else if (constComps > 1 && newComps == 1) { 134 // for a case like vec4 f = 1.2 + vec4(2,3,4,5); 135 newComps = constComps; 136 rightUnionArray = rightNode->getConstArray(); 137 TConstUnionArray smearedArray(newComps, getConstArray()[0]); 138 leftUnionArray = smearedArray; 139 returnType.shallowCopy(rightNode->getType()); 140 } 141 break; 142 } 143 144 TConstUnionArray newConstArray(newComps); 145 TType constBool(EbtBool, EvqConst); 146 147 switch(op) { 148 case EOpAdd: 149 for (int i = 0; i < newComps; i++) 150 newConstArray[i] = leftUnionArray[i] + rightUnionArray[i]; 151 break; 152 case EOpSub: 153 for (int i = 0; i < newComps; i++) 154 newConstArray[i] = leftUnionArray[i] - rightUnionArray[i]; 155 break; 156 157 case EOpMul: 158 case EOpVectorTimesScalar: 159 case EOpMatrixTimesScalar: 160 for (int i = 0; i < newComps; i++) 161 newConstArray[i] = leftUnionArray[i] * rightUnionArray[i]; 162 break; 163 case EOpMatrixTimesMatrix: 164 for 
(int row = 0; row < getMatrixRows(); row++) { 165 for (int column = 0; column < rightNode->getMatrixCols(); column++) { 166 double sum = 0.0f; 167 for (int i = 0; i < rightNode->getMatrixRows(); i++) 168 sum += leftUnionArray[i * getMatrixRows() + row].getDConst() * rightUnionArray[column * rightNode->getMatrixRows() + i].getDConst(); 169 newConstArray[column * getMatrixRows() + row].setDConst(sum); 170 } 171 } 172 returnType.shallowCopy(TType(getType().getBasicType(), EvqConst, 0, rightNode->getMatrixCols(), getMatrixRows())); 173 break; 174 case EOpDiv: 175 for (int i = 0; i < newComps; i++) { 176 switch (getType().getBasicType()) { 177 case EbtDouble: 178 case EbtFloat: 179 #ifdef AMD_EXTENSIONS 180 case EbtFloat16: 181 #endif 182 newConstArray[i].setDConst(leftUnionArray[i].getDConst() / rightUnionArray[i].getDConst()); 183 break; 184 185 case EbtInt: 186 if (rightUnionArray[i] == 0) 187 newConstArray[i].setIConst(0x7FFFFFFF); 188 else if (rightUnionArray[i].getIConst() == -1 && leftUnionArray[i].getIConst() == (int)0x80000000) 189 newConstArray[i].setIConst(0x80000000); 190 else 191 newConstArray[i].setIConst(leftUnionArray[i].getIConst() / rightUnionArray[i].getIConst()); 192 break; 193 194 case EbtUint: 195 if (rightUnionArray[i] == 0) { 196 newConstArray[i].setUConst(0xFFFFFFFFu); 197 } else 198 newConstArray[i].setUConst(leftUnionArray[i].getUConst() / rightUnionArray[i].getUConst()); 199 break; 200 201 case EbtInt64: 202 if (rightUnionArray[i] == 0) 203 newConstArray[i].setI64Const(0x7FFFFFFFFFFFFFFFll); 204 else if (rightUnionArray[i].getI64Const() == -1 && leftUnionArray[i].getI64Const() == (long long)0x8000000000000000) 205 newConstArray[i].setI64Const(0x8000000000000000); 206 else 207 newConstArray[i].setI64Const(leftUnionArray[i].getI64Const() / rightUnionArray[i].getI64Const()); 208 break; 209 210 case EbtUint64: 211 if (rightUnionArray[i] == 0) { 212 newConstArray[i].setU64Const(0xFFFFFFFFFFFFFFFFull); 213 } else 214 
newConstArray[i].setU64Const(leftUnionArray[i].getU64Const() / rightUnionArray[i].getU64Const()); 215 break; 216 #ifdef AMD_EXTENSIONS 217 case EbtInt16: 218 if (rightUnionArray[i] == 0) 219 newConstArray[i].setIConst(0x7FFF); 220 else if (rightUnionArray[i].getIConst() == -1 && leftUnionArray[i].getIConst() == (int)0x8000) 221 newConstArray[i].setIConst(0x8000); 222 else 223 newConstArray[i].setIConst(leftUnionArray[i].getIConst() / rightUnionArray[i].getIConst()); 224 break; 225 226 case EbtUint16: 227 if (rightUnionArray[i] == 0) { 228 newConstArray[i].setUConst(0xFFFFu); 229 } else 230 newConstArray[i].setUConst(leftUnionArray[i].getUConst() / rightUnionArray[i].getUConst()); 231 break; 232 #endif 233 default: 234 return 0; 235 } 236 } 237 break; 238 239 case EOpMatrixTimesVector: 240 for (int i = 0; i < getMatrixRows(); i++) { 241 double sum = 0.0f; 242 for (int j = 0; j < rightNode->getVectorSize(); j++) { 243 sum += leftUnionArray[j*getMatrixRows() + i].getDConst() * rightUnionArray[j].getDConst(); 244 } 245 newConstArray[i].setDConst(sum); 246 } 247 248 returnType.shallowCopy(TType(getBasicType(), EvqConst, getMatrixRows())); 249 break; 250 251 case EOpVectorTimesMatrix: 252 for (int i = 0; i < rightNode->getMatrixCols(); i++) { 253 double sum = 0.0f; 254 for (int j = 0; j < getVectorSize(); j++) 255 sum += leftUnionArray[j].getDConst() * rightUnionArray[i*rightNode->getMatrixRows() + j].getDConst(); 256 newConstArray[i].setDConst(sum); 257 } 258 259 returnType.shallowCopy(TType(getBasicType(), EvqConst, rightNode->getMatrixCols())); 260 break; 261 262 case EOpMod: 263 for (int i = 0; i < newComps; i++) { 264 if (rightUnionArray[i] == 0) 265 newConstArray[i] = leftUnionArray[i]; 266 else 267 newConstArray[i] = leftUnionArray[i] % rightUnionArray[i]; 268 } 269 break; 270 271 case EOpRightShift: 272 for (int i = 0; i < newComps; i++) 273 newConstArray[i] = leftUnionArray[i] >> rightUnionArray[i]; 274 break; 275 276 case EOpLeftShift: 277 for (int i = 0; i < 
newComps; i++) 278 newConstArray[i] = leftUnionArray[i] << rightUnionArray[i]; 279 break; 280 281 case EOpAnd: 282 for (int i = 0; i < newComps; i++) 283 newConstArray[i] = leftUnionArray[i] & rightUnionArray[i]; 284 break; 285 case EOpInclusiveOr: 286 for (int i = 0; i < newComps; i++) 287 newConstArray[i] = leftUnionArray[i] | rightUnionArray[i]; 288 break; 289 case EOpExclusiveOr: 290 for (int i = 0; i < newComps; i++) 291 newConstArray[i] = leftUnionArray[i] ^ rightUnionArray[i]; 292 break; 293 294 case EOpLogicalAnd: // this code is written for possible future use, will not get executed currently 295 for (int i = 0; i < newComps; i++) 296 newConstArray[i] = leftUnionArray[i] && rightUnionArray[i]; 297 break; 298 299 case EOpLogicalOr: // this code is written for possible future use, will not get executed currently 300 for (int i = 0; i < newComps; i++) 301 newConstArray[i] = leftUnionArray[i] || rightUnionArray[i]; 302 break; 303 304 case EOpLogicalXor: 305 for (int i = 0; i < newComps; i++) { 306 switch (getType().getBasicType()) { 307 case EbtBool: newConstArray[i].setBConst((leftUnionArray[i] == rightUnionArray[i]) ? false : true); break; 308 default: assert(false && "Default missing"); 309 } 310 } 311 break; 312 313 case EOpLessThan: 314 newConstArray[0].setBConst(leftUnionArray[0] < rightUnionArray[0]); 315 returnType.shallowCopy(constBool); 316 break; 317 case EOpGreaterThan: 318 newConstArray[0].setBConst(leftUnionArray[0] > rightUnionArray[0]); 319 returnType.shallowCopy(constBool); 320 break; 321 case EOpLessThanEqual: 322 newConstArray[0].setBConst(! (leftUnionArray[0] > rightUnionArray[0])); 323 returnType.shallowCopy(constBool); 324 break; 325 case EOpGreaterThanEqual: 326 newConstArray[0].setBConst(! 
(leftUnionArray[0] < rightUnionArray[0])); 327 returnType.shallowCopy(constBool); 328 break; 329 case EOpEqual: 330 newConstArray[0].setBConst(rightNode->getConstArray() == leftUnionArray); 331 returnType.shallowCopy(constBool); 332 break; 333 case EOpNotEqual: 334 newConstArray[0].setBConst(rightNode->getConstArray() != leftUnionArray); 335 returnType.shallowCopy(constBool); 336 break; 337 338 default: 339 return 0; 340 } 341 342 TIntermConstantUnion *newNode = new TIntermConstantUnion(newConstArray, returnType); 343 newNode->setLoc(getLoc()); 344 345 return newNode; 346 } 347 348 // 349 // Do single unary node folding 350 // 351 // Returns a new node representing the result. 352 // 353 TIntermTyped* TIntermConstantUnion::fold(TOperator op, const TType& returnType) const 354 { 355 // First, size the result, which is mostly the same as the argument's size, 356 // but not always, and classify what is componentwise. 357 // Also, eliminate cases that can't be compile-time constant. 358 int resultSize; 359 bool componentWise = true; 360 361 int objectSize = getType().computeNumComponents(); 362 switch (op) { 363 case EOpDeterminant: 364 case EOpAny: 365 case EOpAll: 366 case EOpLength: 367 componentWise = false; 368 resultSize = 1; 369 break; 370 371 case EOpEmitStreamVertex: 372 case EOpEndStreamPrimitive: 373 // These don't actually fold 374 return 0; 375 376 case EOpPackSnorm2x16: 377 case EOpPackUnorm2x16: 378 case EOpPackHalf2x16: 379 componentWise = false; 380 resultSize = 1; 381 break; 382 383 case EOpUnpackSnorm2x16: 384 case EOpUnpackUnorm2x16: 385 case EOpUnpackHalf2x16: 386 componentWise = false; 387 resultSize = 2; 388 break; 389 390 case EOpNormalize: 391 componentWise = false; 392 resultSize = objectSize; 393 break; 394 395 default: 396 resultSize = objectSize; 397 break; 398 } 399 400 // Set up for processing 401 TConstUnionArray newConstArray(resultSize); 402 const TConstUnionArray& unionArray = getConstArray(); 403 404 // Process non-component-wise 
operations 405 switch (op) { 406 case EOpLength: 407 case EOpNormalize: 408 { 409 double sum = 0; 410 for (int i = 0; i < objectSize; i++) 411 sum += unionArray[i].getDConst() * unionArray[i].getDConst(); 412 double length = sqrt(sum); 413 if (op == EOpLength) 414 newConstArray[0].setDConst(length); 415 else { 416 for (int i = 0; i < objectSize; i++) 417 newConstArray[i].setDConst(unionArray[i].getDConst() / length); 418 } 419 break; 420 } 421 422 case EOpAny: 423 { 424 bool result = false; 425 for (int i = 0; i < objectSize; i++) { 426 if (unionArray[i].getBConst()) 427 result = true; 428 } 429 newConstArray[0].setBConst(result); 430 break; 431 } 432 case EOpAll: 433 { 434 bool result = true; 435 for (int i = 0; i < objectSize; i++) { 436 if (! unionArray[i].getBConst()) 437 result = false; 438 } 439 newConstArray[0].setBConst(result); 440 break; 441 } 442 443 // TODO: 3.0 Functionality: unary constant folding: the rest of the ops have to be fleshed out 444 445 case EOpPackSnorm2x16: 446 case EOpPackUnorm2x16: 447 case EOpPackHalf2x16: 448 449 case EOpUnpackSnorm2x16: 450 case EOpUnpackUnorm2x16: 451 case EOpUnpackHalf2x16: 452 453 case EOpDeterminant: 454 case EOpMatrixInverse: 455 case EOpTranspose: 456 return 0; 457 458 default: 459 assert(componentWise); 460 break; 461 } 462 463 // Turn off the componentwise loop 464 if (! 
componentWise) 465 objectSize = 0; 466 467 // Process component-wise operations 468 for (int i = 0; i < objectSize; i++) { 469 switch (op) { 470 case EOpNegative: 471 switch (getType().getBasicType()) { 472 case EbtDouble: 473 #ifdef AMD_EXTENSIONS 474 case EbtFloat16: 475 #endif 476 case EbtFloat: newConstArray[i].setDConst(-unionArray[i].getDConst()); break; 477 #ifdef AMD_EXTENSIONS 478 case EbtInt16: 479 #endif 480 case EbtInt: newConstArray[i].setIConst(-unionArray[i].getIConst()); break; 481 #ifdef AMD_EXTENSIONS 482 case EbtUint16: 483 #endif 484 case EbtUint: newConstArray[i].setUConst(static_cast<unsigned int>(-static_cast<int>(unionArray[i].getUConst()))); break; 485 case EbtInt64: newConstArray[i].setI64Const(-unionArray[i].getI64Const()); break; 486 case EbtUint64: newConstArray[i].setU64Const(static_cast<unsigned long long>(-static_cast<long long>(unionArray[i].getU64Const()))); break; 487 default: 488 return 0; 489 } 490 break; 491 case EOpLogicalNot: 492 case EOpVectorLogicalNot: 493 switch (getType().getBasicType()) { 494 case EbtBool: newConstArray[i].setBConst(!unionArray[i].getBConst()); break; 495 default: 496 return 0; 497 } 498 break; 499 case EOpBitwiseNot: 500 newConstArray[i] = ~unionArray[i]; 501 break; 502 case EOpRadians: 503 newConstArray[i].setDConst(unionArray[i].getDConst() * pi / 180.0); 504 break; 505 case EOpDegrees: 506 newConstArray[i].setDConst(unionArray[i].getDConst() * 180.0 / pi); 507 break; 508 case EOpSin: 509 newConstArray[i].setDConst(sin(unionArray[i].getDConst())); 510 break; 511 case EOpCos: 512 newConstArray[i].setDConst(cos(unionArray[i].getDConst())); 513 break; 514 case EOpTan: 515 newConstArray[i].setDConst(tan(unionArray[i].getDConst())); 516 break; 517 case EOpAsin: 518 newConstArray[i].setDConst(asin(unionArray[i].getDConst())); 519 break; 520 case EOpAcos: 521 newConstArray[i].setDConst(acos(unionArray[i].getDConst())); 522 break; 523 case EOpAtan: 524 
newConstArray[i].setDConst(atan(unionArray[i].getDConst())); 525 break; 526 527 case EOpDPdx: 528 case EOpDPdy: 529 case EOpFwidth: 530 case EOpDPdxFine: 531 case EOpDPdyFine: 532 case EOpFwidthFine: 533 case EOpDPdxCoarse: 534 case EOpDPdyCoarse: 535 case EOpFwidthCoarse: 536 // The derivatives are all mandated to create a constant 0. 537 newConstArray[i].setDConst(0.0); 538 break; 539 540 case EOpExp: 541 newConstArray[i].setDConst(exp(unionArray[i].getDConst())); 542 break; 543 case EOpLog: 544 newConstArray[i].setDConst(log(unionArray[i].getDConst())); 545 break; 546 case EOpExp2: 547 { 548 const double inv_log2_e = 0.69314718055994530941723212145818; 549 newConstArray[i].setDConst(exp(unionArray[i].getDConst() * inv_log2_e)); 550 break; 551 } 552 case EOpLog2: 553 { 554 const double log2_e = 1.4426950408889634073599246810019; 555 newConstArray[i].setDConst(log2_e * log(unionArray[i].getDConst())); 556 break; 557 } 558 case EOpSqrt: 559 newConstArray[i].setDConst(sqrt(unionArray[i].getDConst())); 560 break; 561 case EOpInverseSqrt: 562 newConstArray[i].setDConst(1.0 / sqrt(unionArray[i].getDConst())); 563 break; 564 565 case EOpAbs: 566 if (unionArray[i].getType() == EbtDouble) 567 newConstArray[i].setDConst(fabs(unionArray[i].getDConst())); 568 else if (unionArray[i].getType() == EbtInt) 569 newConstArray[i].setIConst(abs(unionArray[i].getIConst())); 570 else 571 newConstArray[i] = unionArray[i]; 572 break; 573 case EOpSign: 574 #define SIGN(X) (X == 0 ? 0 : (X < 0 ? 
-1 : 1)) 575 if (unionArray[i].getType() == EbtDouble) 576 newConstArray[i].setDConst(SIGN(unionArray[i].getDConst())); 577 else 578 newConstArray[i].setIConst(SIGN(unionArray[i].getIConst())); 579 break; 580 case EOpFloor: 581 newConstArray[i].setDConst(floor(unionArray[i].getDConst())); 582 break; 583 case EOpTrunc: 584 if (unionArray[i].getDConst() > 0) 585 newConstArray[i].setDConst(floor(unionArray[i].getDConst())); 586 else 587 newConstArray[i].setDConst(ceil(unionArray[i].getDConst())); 588 break; 589 case EOpRound: 590 newConstArray[i].setDConst(floor(0.5 + unionArray[i].getDConst())); 591 break; 592 case EOpRoundEven: 593 { 594 double flr = floor(unionArray[i].getDConst()); 595 bool even = flr / 2.0 == floor(flr / 2.0); 596 double rounded = even ? ceil(unionArray[i].getDConst() - 0.5) : floor(unionArray[i].getDConst() + 0.5); 597 newConstArray[i].setDConst(rounded); 598 break; 599 } 600 case EOpCeil: 601 newConstArray[i].setDConst(ceil(unionArray[i].getDConst())); 602 break; 603 case EOpFract: 604 { 605 double x = unionArray[i].getDConst(); 606 newConstArray[i].setDConst(x - floor(x)); 607 break; 608 } 609 610 case EOpIsNan: 611 { 612 newConstArray[i].setBConst(isNan(unionArray[i].getDConst())); 613 break; 614 } 615 case EOpIsInf: 616 { 617 newConstArray[i].setBConst(isInf(unionArray[i].getDConst())); 618 break; 619 } 620 621 // TODO: 3.0 Functionality: unary constant folding: the rest of the ops have to be fleshed out 622 623 case EOpSinh: 624 case EOpCosh: 625 case EOpTanh: 626 case EOpAsinh: 627 case EOpAcosh: 628 case EOpAtanh: 629 630 case EOpFloatBitsToInt: 631 case EOpFloatBitsToUint: 632 case EOpIntBitsToFloat: 633 case EOpUintBitsToFloat: 634 case EOpDoubleBitsToInt64: 635 case EOpDoubleBitsToUint64: 636 case EOpInt64BitsToDouble: 637 case EOpUint64BitsToDouble: 638 #ifdef AMD_EXTENSIONS 639 case EOpFloat16BitsToInt16: 640 case EOpFloat16BitsToUint16: 641 case EOpInt16BitsToFloat16: 642 case EOpUint16BitsToFloat16: 643 #endif 644 645 default: 646 
return 0; 647 } 648 } 649 650 TIntermConstantUnion *newNode = new TIntermConstantUnion(newConstArray, returnType); 651 newNode->getWritableType().getQualifier().storage = EvqConst; 652 newNode->setLoc(getLoc()); 653 654 return newNode; 655 } 656 657 // 658 // Do constant folding for an aggregate node that has all its children 659 // as constants and an operator that requires constant folding. 660 // 661 TIntermTyped* TIntermediate::fold(TIntermAggregate* aggrNode) 662 { 663 if (aggrNode == nullptr) 664 return aggrNode; 665 666 if (! areAllChildConst(aggrNode)) 667 return aggrNode; 668 669 if (aggrNode->isConstructor()) 670 return foldConstructor(aggrNode); 671 672 TIntermSequence& children = aggrNode->getSequence(); 673 674 // First, see if this is an operation to constant fold, kick out if not, 675 // see what size the result is if so. 676 677 bool componentwise = false; // will also say componentwise if a scalar argument gets repeated to make per-component results 678 int objectSize; 679 switch (aggrNode->getOp()) { 680 case EOpAtan: 681 case EOpPow: 682 case EOpMin: 683 case EOpMax: 684 case EOpMix: 685 case EOpClamp: 686 case EOpLessThan: 687 case EOpGreaterThan: 688 case EOpLessThanEqual: 689 case EOpGreaterThanEqual: 690 case EOpVectorEqual: 691 case EOpVectorNotEqual: 692 componentwise = true; 693 objectSize = children[0]->getAsConstantUnion()->getType().computeNumComponents(); 694 break; 695 case EOpCross: 696 case EOpReflect: 697 case EOpRefract: 698 case EOpFaceForward: 699 objectSize = children[0]->getAsConstantUnion()->getType().computeNumComponents(); 700 break; 701 case EOpDistance: 702 case EOpDot: 703 objectSize = 1; 704 break; 705 case EOpOuterProduct: 706 objectSize = children[0]->getAsTyped()->getType().getVectorSize() * 707 children[1]->getAsTyped()->getType().getVectorSize(); 708 break; 709 case EOpStep: 710 componentwise = true; 711 objectSize = std::max(children[0]->getAsTyped()->getType().getVectorSize(), 712 
children[1]->getAsTyped()->getType().getVectorSize()); 713 break; 714 case EOpSmoothStep: 715 componentwise = true; 716 objectSize = std::max(children[0]->getAsTyped()->getType().getVectorSize(), 717 children[2]->getAsTyped()->getType().getVectorSize()); 718 break; 719 default: 720 return aggrNode; 721 } 722 TConstUnionArray newConstArray(objectSize); 723 724 TVector<TConstUnionArray> childConstUnions; 725 for (unsigned int arg = 0; arg < children.size(); ++arg) 726 childConstUnions.push_back(children[arg]->getAsConstantUnion()->getConstArray()); 727 728 // Second, do the actual folding 729 730 bool isFloatingPoint = children[0]->getAsTyped()->getBasicType() == EbtFloat || 731 #ifdef AMD_EXTENSIONS 732 children[0]->getAsTyped()->getBasicType() == EbtFloat16 || 733 #endif 734 children[0]->getAsTyped()->getBasicType() == EbtDouble; 735 bool isSigned = children[0]->getAsTyped()->getBasicType() == EbtInt || 736 #ifdef AMD_EXTENSIONS 737 children[0]->getAsTyped()->getBasicType() == EbtInt16 || 738 #endif 739 children[0]->getAsTyped()->getBasicType() == EbtInt64; 740 bool isInt64 = children[0]->getAsTyped()->getBasicType() == EbtInt64 || 741 children[0]->getAsTyped()->getBasicType() == EbtUint64; 742 if (componentwise) { 743 for (int comp = 0; comp < objectSize; comp++) { 744 745 // some arguments are scalars instead of matching vectors; simulate a smear 746 int arg0comp = std::min(comp, children[0]->getAsTyped()->getType().getVectorSize() - 1); 747 int arg1comp = 0; 748 if (children.size() > 1) 749 arg1comp = std::min(comp, children[1]->getAsTyped()->getType().getVectorSize() - 1); 750 int arg2comp = 0; 751 if (children.size() > 2) 752 arg2comp = std::min(comp, children[2]->getAsTyped()->getType().getVectorSize() - 1); 753 754 switch (aggrNode->getOp()) { 755 case EOpAtan: 756 newConstArray[comp].setDConst(atan2(childConstUnions[0][arg0comp].getDConst(), childConstUnions[1][arg1comp].getDConst())); 757 break; 758 case EOpPow: 759 
newConstArray[comp].setDConst(pow(childConstUnions[0][arg0comp].getDConst(), childConstUnions[1][arg1comp].getDConst())); 760 break; 761 case EOpMin: 762 if (isFloatingPoint) 763 newConstArray[comp].setDConst(std::min(childConstUnions[0][arg0comp].getDConst(), childConstUnions[1][arg1comp].getDConst())); 764 else if (isSigned) { 765 if (isInt64) 766 newConstArray[comp].setI64Const(std::min(childConstUnions[0][arg0comp].getI64Const(), childConstUnions[1][arg1comp].getI64Const())); 767 else 768 newConstArray[comp].setIConst(std::min(childConstUnions[0][arg0comp].getIConst(), childConstUnions[1][arg1comp].getIConst())); 769 } else { 770 if (isInt64) 771 newConstArray[comp].setU64Const(std::min(childConstUnions[0][arg0comp].getU64Const(), childConstUnions[1][arg1comp].getU64Const())); 772 else 773 newConstArray[comp].setUConst(std::min(childConstUnions[0][arg0comp].getUConst(), childConstUnions[1][arg1comp].getUConst())); 774 } 775 break; 776 case EOpMax: 777 if (isFloatingPoint) 778 newConstArray[comp].setDConst(std::max(childConstUnions[0][arg0comp].getDConst(), childConstUnions[1][arg1comp].getDConst())); 779 else if (isSigned) { 780 if (isInt64) 781 newConstArray[comp].setI64Const(std::max(childConstUnions[0][arg0comp].getI64Const(), childConstUnions[1][arg1comp].getI64Const())); 782 else 783 newConstArray[comp].setIConst(std::max(childConstUnions[0][arg0comp].getIConst(), childConstUnions[1][arg1comp].getIConst())); 784 } else { 785 if (isInt64) 786 newConstArray[comp].setU64Const(std::max(childConstUnions[0][arg0comp].getU64Const(), childConstUnions[1][arg1comp].getU64Const())); 787 else 788 newConstArray[comp].setUConst(std::max(childConstUnions[0][arg0comp].getUConst(), childConstUnions[1][arg1comp].getUConst())); 789 } 790 break; 791 case EOpClamp: 792 if (isFloatingPoint) 793 newConstArray[comp].setDConst(std::min(std::max(childConstUnions[0][arg0comp].getDConst(), childConstUnions[1][arg1comp].getDConst()), 794 childConstUnions[2][arg2comp].getDConst())); 
795 else if (isSigned) { 796 if (isInt64) 797 newConstArray[comp].setI64Const(std::min(std::max(childConstUnions[0][arg0comp].getI64Const(), childConstUnions[1][arg1comp].getI64Const()), 798 childConstUnions[2][arg2comp].getI64Const())); 799 else 800 newConstArray[comp].setIConst(std::min(std::max(childConstUnions[0][arg0comp].getIConst(), childConstUnions[1][arg1comp].getIConst()), 801 childConstUnions[2][arg2comp].getIConst())); 802 } else { 803 if (isInt64) 804 newConstArray[comp].setU64Const(std::min(std::max(childConstUnions[0][arg0comp].getU64Const(), childConstUnions[1][arg1comp].getU64Const()), 805 childConstUnions[2][arg2comp].getU64Const())); 806 else 807 newConstArray[comp].setUConst(std::min(std::max(childConstUnions[0][arg0comp].getUConst(), childConstUnions[1][arg1comp].getUConst()), 808 childConstUnions[2][arg2comp].getUConst())); 809 } 810 break; 811 case EOpLessThan: 812 newConstArray[comp].setBConst(childConstUnions[0][arg0comp] < childConstUnions[1][arg1comp]); 813 break; 814 case EOpGreaterThan: 815 newConstArray[comp].setBConst(childConstUnions[0][arg0comp] > childConstUnions[1][arg1comp]); 816 break; 817 case EOpLessThanEqual: 818 newConstArray[comp].setBConst(! (childConstUnions[0][arg0comp] > childConstUnions[1][arg1comp])); 819 break; 820 case EOpGreaterThanEqual: 821 newConstArray[comp].setBConst(! (childConstUnions[0][arg0comp] < childConstUnions[1][arg1comp])); 822 break; 823 case EOpVectorEqual: 824 newConstArray[comp].setBConst(childConstUnions[0][arg0comp] == childConstUnions[1][arg1comp]); 825 break; 826 case EOpVectorNotEqual: 827 newConstArray[comp].setBConst(childConstUnions[0][arg0comp] != childConstUnions[1][arg1comp]); 828 break; 829 case EOpMix: 830 if (children[2]->getAsTyped()->getBasicType() == EbtBool) 831 newConstArray[comp].setDConst(childConstUnions[2][arg2comp].getBConst() ? 
childConstUnions[1][arg1comp].getDConst() : 832 childConstUnions[0][arg0comp].getDConst()); 833 else 834 newConstArray[comp].setDConst(childConstUnions[0][arg0comp].getDConst() * (1.0 - childConstUnions[2][arg2comp].getDConst()) + 835 childConstUnions[1][arg1comp].getDConst() * childConstUnions[2][arg2comp].getDConst()); 836 break; 837 case EOpStep: 838 newConstArray[comp].setDConst(childConstUnions[1][arg1comp].getDConst() < childConstUnions[0][arg0comp].getDConst() ? 0.0 : 1.0); 839 break; 840 case EOpSmoothStep: 841 { 842 double t = (childConstUnions[2][arg2comp].getDConst() - childConstUnions[0][arg0comp].getDConst()) / 843 (childConstUnions[1][arg1comp].getDConst() - childConstUnions[0][arg0comp].getDConst()); 844 if (t < 0.0) 845 t = 0.0; 846 if (t > 1.0) 847 t = 1.0; 848 newConstArray[comp].setDConst(t * t * (3.0 - 2.0 * t)); 849 break; 850 } 851 default: 852 return aggrNode; 853 } 854 } 855 } else { 856 // Non-componentwise... 857 858 int numComps = children[0]->getAsConstantUnion()->getType().computeNumComponents(); 859 double dot; 860 861 switch (aggrNode->getOp()) { 862 case EOpDistance: 863 { 864 double sum = 0.0; 865 for (int comp = 0; comp < numComps; ++comp) { 866 double diff = childConstUnions[1][comp].getDConst() - childConstUnions[0][comp].getDConst(); 867 sum += diff * diff; 868 } 869 newConstArray[0].setDConst(sqrt(sum)); 870 break; 871 } 872 case EOpDot: 873 newConstArray[0].setDConst(childConstUnions[0].dot(childConstUnions[1])); 874 break; 875 case EOpCross: 876 newConstArray[0] = childConstUnions[0][1] * childConstUnions[1][2] - childConstUnions[0][2] * childConstUnions[1][1]; 877 newConstArray[1] = childConstUnions[0][2] * childConstUnions[1][0] - childConstUnions[0][0] * childConstUnions[1][2]; 878 newConstArray[2] = childConstUnions[0][0] * childConstUnions[1][1] - childConstUnions[0][1] * childConstUnions[1][0]; 879 break; 880 case EOpFaceForward: 881 // If dot(Nref, I) < 0 return N, otherwise return -N: Arguments are (N, I, Nref). 
882 dot = childConstUnions[1].dot(childConstUnions[2]); 883 for (int comp = 0; comp < numComps; ++comp) { 884 if (dot < 0.0) 885 newConstArray[comp] = childConstUnions[0][comp]; 886 else 887 newConstArray[comp].setDConst(-childConstUnions[0][comp].getDConst()); 888 } 889 break; 890 case EOpReflect: 891 // I - 2 * dot(N, I) * N: Arguments are (I, N). 892 dot = childConstUnions[0].dot(childConstUnions[1]); 893 dot *= 2.0; 894 for (int comp = 0; comp < numComps; ++comp) 895 newConstArray[comp].setDConst(childConstUnions[0][comp].getDConst() - dot * childConstUnions[1][comp].getDConst()); 896 break; 897 case EOpRefract: 898 { 899 // Arguments are (I, N, eta). 900 // k = 1.0 - eta * eta * (1.0 - dot(N, I) * dot(N, I)) 901 // if (k < 0.0) 902 // return dvec(0.0) 903 // else 904 // return eta * I - (eta * dot(N, I) + sqrt(k)) * N 905 dot = childConstUnions[0].dot(childConstUnions[1]); 906 double eta = childConstUnions[2][0].getDConst(); 907 double k = 1.0 - eta * eta * (1.0 - dot * dot); 908 if (k < 0.0) { 909 for (int comp = 0; comp < numComps; ++comp) 910 newConstArray[comp].setDConst(0.0); 911 } else { 912 for (int comp = 0; comp < numComps; ++comp) 913 newConstArray[comp].setDConst(eta * childConstUnions[0][comp].getDConst() - (eta * dot + sqrt(k)) * childConstUnions[1][comp].getDConst()); 914 } 915 break; 916 } 917 case EOpOuterProduct: 918 { 919 int numRows = numComps; 920 int numCols = children[1]->getAsConstantUnion()->getType().computeNumComponents(); 921 for (int row = 0; row < numRows; ++row) 922 for (int col = 0; col < numCols; ++col) 923 newConstArray[col * numRows + row] = childConstUnions[0][row] * childConstUnions[1][col]; 924 break; 925 } 926 default: 927 return aggrNode; 928 } 929 } 930 931 TIntermConstantUnion *newNode = new TIntermConstantUnion(newConstArray, aggrNode->getType()); 932 newNode->getWritableType().getQualifier().storage = EvqConst; 933 newNode->setLoc(aggrNode->getLoc()); 934 935 return newNode; 936 } 937 938 bool 
TIntermediate::areAllChildConst(TIntermAggregate* aggrNode) 939 { 940 bool allConstant = true; 941 942 // check if all the child nodes are constants so that they can be inserted into 943 // the parent node 944 if (aggrNode) { 945 TIntermSequence& childSequenceVector = aggrNode->getSequence(); 946 for (TIntermSequence::iterator p = childSequenceVector.begin(); 947 p != childSequenceVector.end(); p++) { 948 if (!(*p)->getAsTyped()->getAsConstantUnion()) 949 return false; 950 } 951 } 952 953 return allConstant; 954 } 955 956 TIntermTyped* TIntermediate::foldConstructor(TIntermAggregate* aggrNode) 957 { 958 bool error = false; 959 960 TConstUnionArray unionArray(aggrNode->getType().computeNumComponents()); 961 if (aggrNode->getSequence().size() == 1) 962 error = parseConstTree(aggrNode, unionArray, aggrNode->getOp(), aggrNode->getType(), true); 963 else 964 error = parseConstTree(aggrNode, unionArray, aggrNode->getOp(), aggrNode->getType()); 965 966 if (error) 967 return aggrNode; 968 969 return addConstantUnion(unionArray, aggrNode->getType(), aggrNode->getLoc()); 970 } 971 972 // 973 // Constant folding of a bracket (array-style) dereference or struct-like dot 974 // dereference. Can handle anything except a multi-character swizzle, though 975 // all swizzles may go to foldSwizzle(). 976 // 977 TIntermTyped* TIntermediate::foldDereference(TIntermTyped* node, int index, const TSourceLoc& loc) 978 { 979 TType dereferencedType(node->getType(), index); 980 dereferencedType.getQualifier().storage = EvqConst; 981 TIntermTyped* result = 0; 982 int size = dereferencedType.computeNumComponents(); 983 984 // arrays, vectors, matrices, all use simple multiplicative math 985 // while structures need to add up heterogeneous members 986 int start; 987 if (node->isArray() || ! 
node->isStruct()) 988 start = size * index; 989 else { 990 // it is a structure 991 assert(node->isStruct()); 992 start = 0; 993 for (int i = 0; i < index; ++i) 994 start += (*node->getType().getStruct())[i].type->computeNumComponents(); 995 } 996 997 result = addConstantUnion(TConstUnionArray(node->getAsConstantUnion()->getConstArray(), start, size), node->getType(), loc); 998 999 if (result == 0) 1000 result = node; 1001 else 1002 result->setType(dereferencedType); 1003 1004 return result; 1005 } 1006 1007 // 1008 // Make a constant vector node or constant scalar node, representing a given 1009 // constant vector and constant swizzle into it. 1010 // 1011 TIntermTyped* TIntermediate::foldSwizzle(TIntermTyped* node, TSwizzleSelectors<TVectorSelector>& selectors, const TSourceLoc& loc) 1012 { 1013 const TConstUnionArray& unionArray = node->getAsConstantUnion()->getConstArray(); 1014 TConstUnionArray constArray(selectors.size()); 1015 1016 for (int i = 0; i < selectors.size(); i++) 1017 constArray[i] = unionArray[selectors[i]]; 1018 1019 TIntermTyped* result = addConstantUnion(constArray, node->getType(), loc); 1020 1021 if (result == 0) 1022 result = node; 1023 else 1024 result->setType(TType(node->getBasicType(), EvqConst, selectors.size())); 1025 1026 return result; 1027 } 1028 1029 } // end namespace glslang 1030