Home | History | Annotate | Download | only in MachineIndependent
      1 //
      2 //Copyright (C) 2002-2005  3Dlabs Inc. Ltd.
      3 //Copyright (C) 2012-2013 LunarG, Inc.
      4 //
      5 //All rights reserved.
      6 //
      7 //Redistribution and use in source and binary forms, with or without
      8 //modification, are permitted provided that the following conditions
      9 //are met:
     10 //
     11 //    Redistributions of source code must retain the above copyright
     12 //    notice, this list of conditions and the following disclaimer.
     13 //
     14 //    Redistributions in binary form must reproduce the above
     15 //    copyright notice, this list of conditions and the following
     16 //    disclaimer in the documentation and/or other materials provided
     17 //    with the distribution.
     18 //
     19 //    Neither the name of 3Dlabs Inc. Ltd. nor the names of its
     20 //    contributors may be used to endorse or promote products derived
     21 //    from this software without specific prior written permission.
     22 //
     23 //THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
     24 //"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
     25 //LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
     26 //FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
     27 //COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
     28 //INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
     29 //BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
     30 //LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
     31 //CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
     32 //LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
     33 //ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
     34 //POSSIBILITY OF SUCH DAMAGE.
     35 //
     36 
     37 #include "localintermediate.h"
     38 #include <cmath>
     39 #include <cfloat>
     40 #include <cstdlib>
     41 
     42 namespace {
     43 
     44 using namespace glslang;
     45 
// Overlays one double with an array of two ints, so the same storage can be
// written as a double and read back as two integer words.
typedef union {
    double d;   // the value viewed as a double
    int i[2];   // the same storage viewed as two ints
} DoubleIntUnion;
     50 
     51 // Some helper functions
     52 
     53 bool isNan(double x)
     54 {
     55     DoubleIntUnion u;
     56     // tough to find a platform independent library function, do it directly
     57     u.d = x;
     58     int bitPatternL = u.i[0];
     59     int bitPatternH = u.i[1];
     60     return (bitPatternH & 0x7ff80000) == 0x7ff80000 &&
     61            ((bitPatternH & 0xFFFFF) != 0 || bitPatternL != 0);
     62 }
     63 
     64 bool isInf(double x)
     65 {
     66     DoubleIntUnion u;
     67     // tough to find a platform independent library function, do it directly
     68     u.d = x;
     69     int bitPatternL = u.i[0];
     70     int bitPatternH = u.i[1];
     71     return (bitPatternH & 0x7ff00000) == 0x7ff00000 &&
     72            (bitPatternH & 0xFFFFF) == 0 && bitPatternL == 0;
     73 }
     74 
// The constant pi, written with more digits than a double can hold.
const double pi = 3.1415926535897932384626433832795;
     76 
     77 } // end anonymous namespace
     78 
     79 
     80 namespace glslang {
     81 
     82 //
     83 // The fold functions see if an operation on a constant can be done in place,
     84 // without generating run-time code.
     85 //
     86 // Returns the node to keep using, which may or may not be the node passed in.
     87 //
     88 // Note: As of version 1.2, all constant operations must be folded.  It is
     89 // not opportunistic, but rather a semantic requirement.
     90 //
     91 
     92 //
     93 // Do folding between a pair of nodes.
     94 // 'this' is the left-hand operand and 'rightConstantNode' is the right-hand operand.
     95 //
     96 // Returns a new node representing the result.
     97 //
     98 TIntermTyped* TIntermConstantUnion::fold(TOperator op, const TIntermTyped* rightConstantNode) const
     99 {
    100     // For most cases, the return type matches the argument type, so set that
    101     // up and just code to exceptions below.
    102     TType returnType;
    103     returnType.shallowCopy(getType());
    104 
    105     //
    106     // A pair of nodes is to be folded together
    107     //
    108 
    109     const TIntermConstantUnion *rightNode = rightConstantNode->getAsConstantUnion();
    110     TConstUnionArray leftUnionArray = getConstArray();
    111     TConstUnionArray rightUnionArray = rightNode->getConstArray();
    112 
    113     // Figure out the size of the result
    114     int newComps;
    115     int constComps;
    116     switch(op) {
    117     case EOpMatrixTimesMatrix:
    118         newComps = rightNode->getMatrixCols() * getMatrixRows();
    119         break;
    120     case EOpMatrixTimesVector:
    121         newComps = getMatrixRows();
    122         break;
    123     case EOpVectorTimesMatrix:
    124         newComps = rightNode->getMatrixCols();
    125         break;
    126     default:
    127         newComps = getType().computeNumComponents();
    128         constComps = rightConstantNode->getType().computeNumComponents();
    129         if (constComps == 1 && newComps > 1) {
    130             // for a case like vec4 f = vec4(2,3,4,5) + 1.2;
    131             TConstUnionArray smearedArray(newComps, rightNode->getConstArray()[0]);
    132             rightUnionArray = smearedArray;
    133         } else if (constComps > 1 && newComps == 1) {
    134             // for a case like vec4 f = 1.2 + vec4(2,3,4,5);
    135             newComps = constComps;
    136             rightUnionArray = rightNode->getConstArray();
    137             TConstUnionArray smearedArray(newComps, getConstArray()[0]);
    138             leftUnionArray = smearedArray;
    139             returnType.shallowCopy(rightNode->getType());
    140         }
    141         break;
    142     }
    143 
    144     TConstUnionArray newConstArray(newComps);
    145     TType constBool(EbtBool, EvqConst);
    146 
    147     switch(op) {
    148     case EOpAdd:
    149         for (int i = 0; i < newComps; i++)
    150             newConstArray[i] = leftUnionArray[i] + rightUnionArray[i];
    151         break;
    152     case EOpSub:
    153         for (int i = 0; i < newComps; i++)
    154             newConstArray[i] = leftUnionArray[i] - rightUnionArray[i];
    155         break;
    156 
    157     case EOpMul:
    158     case EOpVectorTimesScalar:
    159     case EOpMatrixTimesScalar:
    160         for (int i = 0; i < newComps; i++)
    161             newConstArray[i] = leftUnionArray[i] * rightUnionArray[i];
    162         break;
    163     case EOpMatrixTimesMatrix:
    164         for (int row = 0; row < getMatrixRows(); row++) {
    165             for (int column = 0; column < rightNode->getMatrixCols(); column++) {
    166                 double sum = 0.0f;
    167                 for (int i = 0; i < rightNode->getMatrixRows(); i++)
    168                     sum += leftUnionArray[i * getMatrixRows() + row].getDConst() * rightUnionArray[column * rightNode->getMatrixRows() + i].getDConst();
    169                 newConstArray[column * getMatrixRows() + row].setDConst(sum);
    170             }
    171         }
    172         returnType.shallowCopy(TType(getType().getBasicType(), EvqConst, 0, rightNode->getMatrixCols(), getMatrixRows()));
    173         break;
    174     case EOpDiv:
    175         for (int i = 0; i < newComps; i++) {
    176             switch (getType().getBasicType()) {
    177             case EbtDouble:
    178             case EbtFloat:
    179                 newConstArray[i].setDConst(leftUnionArray[i].getDConst() / rightUnionArray[i].getDConst());
    180                 break;
    181 
    182             case EbtInt:
    183                 if (rightUnionArray[i] == 0)
    184                     newConstArray[i].setIConst(0x7FFFFFFF);
    185                 else if (rightUnionArray[i].getIConst() == -1 && leftUnionArray[i].getIConst() == (int)0x80000000)
    186                     newConstArray[i].setIConst(0x80000000);
    187                 else
    188                     newConstArray[i].setIConst(leftUnionArray[i].getIConst() / rightUnionArray[i].getIConst());
    189                 break;
    190 
    191             case EbtUint:
    192                 if (rightUnionArray[i] == 0) {
    193                     newConstArray[i].setUConst(0xFFFFFFFF);
    194                 } else
    195                     newConstArray[i].setUConst(leftUnionArray[i].getUConst() / rightUnionArray[i].getUConst());
    196                 break;
    197 
    198             case EbtInt64:
    199                 if (rightUnionArray[i] == 0)
    200                     newConstArray[i].setI64Const(0x7FFFFFFFFFFFFFFFll);
    201                 else if (rightUnionArray[i].getI64Const() == -1 && leftUnionArray[i].getI64Const() == (long long)0x8000000000000000)
    202                     newConstArray[i].setI64Const(0x8000000000000000);
    203                 else
    204                     newConstArray[i].setI64Const(leftUnionArray[i].getI64Const() / rightUnionArray[i].getI64Const());
    205                 break;
    206 
    207             case EbtUint64:
    208                 if (rightUnionArray[i] == 0) {
    209                     newConstArray[i].setU64Const(0xFFFFFFFFFFFFFFFFull);
    210                 } else
    211                     newConstArray[i].setU64Const(leftUnionArray[i].getU64Const() / rightUnionArray[i].getU64Const());
    212                 break;
    213             default:
    214                 return 0;
    215             }
    216         }
    217         break;
    218 
    219     case EOpMatrixTimesVector:
    220         for (int i = 0; i < getMatrixRows(); i++) {
    221             double sum = 0.0f;
    222             for (int j = 0; j < rightNode->getVectorSize(); j++) {
    223                 sum += leftUnionArray[j*getMatrixRows() + i].getDConst() * rightUnionArray[j].getDConst();
    224             }
    225             newConstArray[i].setDConst(sum);
    226         }
    227 
    228         returnType.shallowCopy(TType(getBasicType(), EvqConst, getMatrixRows()));
    229         break;
    230 
    231     case EOpVectorTimesMatrix:
    232         for (int i = 0; i < rightNode->getMatrixCols(); i++) {
    233             double sum = 0.0f;
    234             for (int j = 0; j < getVectorSize(); j++)
    235                 sum += leftUnionArray[j].getDConst() * rightUnionArray[i*rightNode->getMatrixRows() + j].getDConst();
    236             newConstArray[i].setDConst(sum);
    237         }
    238 
    239         returnType.shallowCopy(TType(getBasicType(), EvqConst, rightNode->getMatrixCols()));
    240         break;
    241 
    242     case EOpMod:
    243         for (int i = 0; i < newComps; i++) {
    244             if (rightUnionArray[i] == 0)
    245                 newConstArray[i] = leftUnionArray[i];
    246             else
    247                 newConstArray[i] = leftUnionArray[i] % rightUnionArray[i];
    248         }
    249         break;
    250 
    251     case EOpRightShift:
    252         for (int i = 0; i < newComps; i++)
    253             newConstArray[i] = leftUnionArray[i] >> rightUnionArray[i];
    254         break;
    255 
    256     case EOpLeftShift:
    257         for (int i = 0; i < newComps; i++)
    258             newConstArray[i] = leftUnionArray[i] << rightUnionArray[i];
    259         break;
    260 
    261     case EOpAnd:
    262         for (int i = 0; i < newComps; i++)
    263             newConstArray[i] = leftUnionArray[i] & rightUnionArray[i];
    264         break;
    265     case EOpInclusiveOr:
    266         for (int i = 0; i < newComps; i++)
    267             newConstArray[i] = leftUnionArray[i] | rightUnionArray[i];
    268         break;
    269     case EOpExclusiveOr:
    270         for (int i = 0; i < newComps; i++)
    271             newConstArray[i] = leftUnionArray[i] ^ rightUnionArray[i];
    272         break;
    273 
    274     case EOpLogicalAnd: // this code is written for possible future use, will not get executed currently
    275         for (int i = 0; i < newComps; i++)
    276             newConstArray[i] = leftUnionArray[i] && rightUnionArray[i];
    277         break;
    278 
    279     case EOpLogicalOr: // this code is written for possible future use, will not get executed currently
    280         for (int i = 0; i < newComps; i++)
    281             newConstArray[i] = leftUnionArray[i] || rightUnionArray[i];
    282         break;
    283 
    284     case EOpLogicalXor:
    285         for (int i = 0; i < newComps; i++) {
    286             switch (getType().getBasicType()) {
    287             case EbtBool: newConstArray[i].setBConst((leftUnionArray[i] == rightUnionArray[i]) ? false : true); break;
    288             default: assert(false && "Default missing");
    289             }
    290         }
    291         break;
    292 
    293     case EOpLessThan:
    294         newConstArray[0].setBConst(leftUnionArray[0] < rightUnionArray[0]);
    295         returnType.shallowCopy(constBool);
    296         break;
    297     case EOpGreaterThan:
    298         newConstArray[0].setBConst(leftUnionArray[0] > rightUnionArray[0]);
    299         returnType.shallowCopy(constBool);
    300         break;
    301     case EOpLessThanEqual:
    302         newConstArray[0].setBConst(! (leftUnionArray[0] > rightUnionArray[0]));
    303         returnType.shallowCopy(constBool);
    304         break;
    305     case EOpGreaterThanEqual:
    306         newConstArray[0].setBConst(! (leftUnionArray[0] < rightUnionArray[0]));
    307         returnType.shallowCopy(constBool);
    308         break;
    309     case EOpEqual:
    310         newConstArray[0].setBConst(rightNode->getConstArray() == leftUnionArray);
    311         returnType.shallowCopy(constBool);
    312         break;
    313     case EOpNotEqual:
    314         newConstArray[0].setBConst(rightNode->getConstArray() != leftUnionArray);
    315         returnType.shallowCopy(constBool);
    316         break;
    317 
    318     default:
    319         return 0;
    320     }
    321 
    322     TIntermConstantUnion *newNode = new TIntermConstantUnion(newConstArray, returnType);
    323     newNode->setLoc(getLoc());
    324 
    325     return newNode;
    326 }
    327 
    328 //
    329 // Do single unary node folding
    330 //
    331 // Returns a new node representing the result.
    332 //
    333 TIntermTyped* TIntermConstantUnion::fold(TOperator op, const TType& returnType) const
    334 {
    335     // First, size the result, which is mostly the same as the argument's size,
    336     // but not always, and classify what is componentwise.
    337     // Also, eliminate cases that can't be compile-time constant.
    338     int resultSize;
    339     bool componentWise = true;
    340 
    341     int objectSize = getType().computeNumComponents();
    342     switch (op) {
    343     case EOpDeterminant:
    344     case EOpAny:
    345     case EOpAll:
    346     case EOpLength:
    347         componentWise = false;
    348         resultSize = 1;
    349         break;
    350 
    351     case EOpEmitStreamVertex:
    352     case EOpEndStreamPrimitive:
    353         // These don't actually fold
    354         return 0;
    355 
    356     case EOpPackSnorm2x16:
    357     case EOpPackUnorm2x16:
    358     case EOpPackHalf2x16:
    359         componentWise = false;
    360         resultSize = 1;
    361         break;
    362 
    363     case EOpUnpackSnorm2x16:
    364     case EOpUnpackUnorm2x16:
    365     case EOpUnpackHalf2x16:
    366         componentWise = false;
    367         resultSize = 2;
    368         break;
    369 
    370     case EOpNormalize:
    371         componentWise = false;
    372         resultSize = objectSize;
    373         break;
    374 
    375     default:
    376         resultSize = objectSize;
    377         break;
    378     }
    379 
    380     // Set up for processing
    381     TConstUnionArray newConstArray(resultSize);
    382     const TConstUnionArray& unionArray = getConstArray();
    383 
    384     // Process non-component-wise operations
    385     switch (op) {
    386     case EOpLength:
    387     case EOpNormalize:
    388     {
    389         double sum = 0;
    390         for (int i = 0; i < objectSize; i++)
    391             sum += unionArray[i].getDConst() * unionArray[i].getDConst();
    392         double length = sqrt(sum);
    393         if (op == EOpLength)
    394             newConstArray[0].setDConst(length);
    395         else {
    396             for (int i = 0; i < objectSize; i++)
    397                 newConstArray[i].setDConst(unionArray[i].getDConst() / length);
    398         }
    399         break;
    400     }
    401 
    402     case EOpAny:
    403     {
    404         bool result = false;
    405         for (int i = 0; i < objectSize; i++) {
    406             if (unionArray[i].getBConst())
    407                 result = true;
    408         }
    409         newConstArray[0].setBConst(result);
    410         break;
    411     }
    412     case EOpAll:
    413     {
    414         bool result = true;
    415         for (int i = 0; i < objectSize; i++) {
    416             if (! unionArray[i].getBConst())
    417                 result = false;
    418         }
    419         newConstArray[0].setBConst(result);
    420         break;
    421     }
    422 
    423     // TODO: 3.0 Functionality: unary constant folding: the rest of the ops have to be fleshed out
    424 
    425     case EOpPackSnorm2x16:
    426     case EOpPackUnorm2x16:
    427     case EOpPackHalf2x16:
    428 
    429     case EOpUnpackSnorm2x16:
    430     case EOpUnpackUnorm2x16:
    431     case EOpUnpackHalf2x16:
    432 
    433     case EOpDeterminant:
    434     case EOpMatrixInverse:
    435     case EOpTranspose:
    436         return 0;
    437 
    438     default:
    439         assert(componentWise);
    440         break;
    441     }
    442 
    443     // Turn off the componentwise loop
    444     if (! componentWise)
    445         objectSize = 0;
    446 
    447     // Process component-wise operations
    448     for (int i = 0; i < objectSize; i++) {
    449         switch (op) {
    450         case EOpNegative:
    451             switch (getType().getBasicType()) {
    452             case EbtDouble:
    453             case EbtFloat: newConstArray[i].setDConst(-unionArray[i].getDConst()); break;
    454             case EbtInt:   newConstArray[i].setIConst(-unionArray[i].getIConst()); break;
    455             case EbtUint:  newConstArray[i].setUConst(static_cast<unsigned int>(-static_cast<int>(unionArray[i].getUConst())));  break;
    456             case EbtInt64: newConstArray[i].setI64Const(-unionArray[i].getI64Const()); break;
    457             case EbtUint64: newConstArray[i].setU64Const(static_cast<unsigned int>(-static_cast<int>(unionArray[i].getU64Const())));  break;
    458             default:
    459                 return 0;
    460             }
    461             break;
    462         case EOpLogicalNot:
    463         case EOpVectorLogicalNot:
    464             switch (getType().getBasicType()) {
    465             case EbtBool:  newConstArray[i].setBConst(!unionArray[i].getBConst()); break;
    466             default:
    467                 return 0;
    468             }
    469             break;
    470         case EOpBitwiseNot:
    471             newConstArray[i] = ~unionArray[i];
    472             break;
    473         case EOpRadians:
    474             newConstArray[i].setDConst(unionArray[i].getDConst() * pi / 180.0);
    475             break;
    476         case EOpDegrees:
    477             newConstArray[i].setDConst(unionArray[i].getDConst() * 180.0 / pi);
    478             break;
    479         case EOpSin:
    480             newConstArray[i].setDConst(sin(unionArray[i].getDConst()));
    481             break;
    482         case EOpCos:
    483             newConstArray[i].setDConst(cos(unionArray[i].getDConst()));
    484             break;
    485         case EOpTan:
    486             newConstArray[i].setDConst(tan(unionArray[i].getDConst()));
    487             break;
    488         case EOpAsin:
    489             newConstArray[i].setDConst(asin(unionArray[i].getDConst()));
    490             break;
    491         case EOpAcos:
    492             newConstArray[i].setDConst(acos(unionArray[i].getDConst()));
    493             break;
    494         case EOpAtan:
    495             newConstArray[i].setDConst(atan(unionArray[i].getDConst()));
    496             break;
    497 
    498         case EOpDPdx:
    499         case EOpDPdy:
    500         case EOpFwidth:
    501         case EOpDPdxFine:
    502         case EOpDPdyFine:
    503         case EOpFwidthFine:
    504         case EOpDPdxCoarse:
    505         case EOpDPdyCoarse:
    506         case EOpFwidthCoarse:
    507             // The derivatives are all mandated to create a constant 0.
    508             newConstArray[i].setDConst(0.0);
    509             break;
    510 
    511         case EOpExp:
    512             newConstArray[i].setDConst(exp(unionArray[i].getDConst()));
    513             break;
    514         case EOpLog:
    515             newConstArray[i].setDConst(log(unionArray[i].getDConst()));
    516             break;
    517         case EOpExp2:
    518             {
    519                 const double inv_log2_e = 0.69314718055994530941723212145818;
    520                 newConstArray[i].setDConst(exp(unionArray[i].getDConst() * inv_log2_e));
    521                 break;
    522             }
    523         case EOpLog2:
    524             {
    525                 const double log2_e = 1.4426950408889634073599246810019;
    526                 newConstArray[i].setDConst(log2_e * log(unionArray[i].getDConst()));
    527                 break;
    528             }
    529         case EOpSqrt:
    530             newConstArray[i].setDConst(sqrt(unionArray[i].getDConst()));
    531             break;
    532         case EOpInverseSqrt:
    533             newConstArray[i].setDConst(1.0 / sqrt(unionArray[i].getDConst()));
    534             break;
    535 
    536         case EOpAbs:
    537             if (unionArray[i].getType() == EbtDouble)
    538                 newConstArray[i].setDConst(fabs(unionArray[i].getDConst()));
    539             else if (unionArray[i].getType() == EbtInt)
    540                 newConstArray[i].setIConst(abs(unionArray[i].getIConst()));
    541             else
    542                 newConstArray[i] = unionArray[i];
    543             break;
    544         case EOpSign:
    545             #define SIGN(X) (X == 0 ? 0 : (X < 0 ? -1 : 1))
    546             if (unionArray[i].getType() == EbtDouble)
    547                 newConstArray[i].setDConst(SIGN(unionArray[i].getDConst()));
    548             else
    549                 newConstArray[i].setIConst(SIGN(unionArray[i].getIConst()));
    550             break;
    551         case EOpFloor:
    552             newConstArray[i].setDConst(floor(unionArray[i].getDConst()));
    553             break;
    554         case EOpTrunc:
    555             if (unionArray[i].getDConst() > 0)
    556                 newConstArray[i].setDConst(floor(unionArray[i].getDConst()));
    557             else
    558                 newConstArray[i].setDConst(ceil(unionArray[i].getDConst()));
    559             break;
    560         case EOpRound:
    561             newConstArray[i].setDConst(floor(0.5 + unionArray[i].getDConst()));
    562             break;
    563         case EOpRoundEven:
    564         {
    565             double flr = floor(unionArray[i].getDConst());
    566             bool even = flr / 2.0 == floor(flr / 2.0);
    567             double rounded = even ? ceil(unionArray[i].getDConst() - 0.5) : floor(unionArray[i].getDConst() + 0.5);
    568             newConstArray[i].setDConst(rounded);
    569             break;
    570         }
    571         case EOpCeil:
    572             newConstArray[i].setDConst(ceil(unionArray[i].getDConst()));
    573             break;
    574         case EOpFract:
    575         {
    576             double x = unionArray[i].getDConst();
    577             newConstArray[i].setDConst(x - floor(x));
    578             break;
    579         }
    580 
    581         case EOpIsNan:
    582         {
    583             newConstArray[i].setBConst(isNan(unionArray[i].getDConst()));
    584             break;
    585         }
    586         case EOpIsInf:
    587         {
    588             newConstArray[i].setBConst(isInf(unionArray[i].getDConst()));
    589             break;
    590         }
    591 
    592         // TODO: 3.0 Functionality: unary constant folding: the rest of the ops have to be fleshed out
    593 
    594         case EOpSinh:
    595         case EOpCosh:
    596         case EOpTanh:
    597         case EOpAsinh:
    598         case EOpAcosh:
    599         case EOpAtanh:
    600 
    601         case EOpFloatBitsToInt:
    602         case EOpFloatBitsToUint:
    603         case EOpIntBitsToFloat:
    604         case EOpUintBitsToFloat:
    605         case EOpDoubleBitsToInt64:
    606         case EOpDoubleBitsToUint64:
    607 
    608         default:
    609             return 0;
    610         }
    611     }
    612 
    613     TIntermConstantUnion *newNode = new TIntermConstantUnion(newConstArray, returnType);
    614     newNode->getWritableType().getQualifier().storage = EvqConst;
    615     newNode->setLoc(getLoc());
    616 
    617     return newNode;
    618 }
    619 
    620 //
    621 // Do constant folding for an aggregate node that has all its children
    622 // as constants and an operator that requires constant folding.
    623 //
    624 TIntermTyped* TIntermediate::fold(TIntermAggregate* aggrNode)
    625 {
    626     if (! areAllChildConst(aggrNode))
    627         return aggrNode;
    628 
    629     if (aggrNode->isConstructor())
    630         return foldConstructor(aggrNode);
    631 
    632     TIntermSequence& children = aggrNode->getSequence();
    633 
    634     // First, see if this is an operation to constant fold, kick out if not,
    635     // see what size the result is if so.
    636 
    637     bool componentwise = false;  // will also say componentwise if a scalar argument gets repeated to make per-component results
    638     int objectSize;
    639     switch (aggrNode->getOp()) {
    640     case EOpAtan:
    641     case EOpPow:
    642     case EOpMin:
    643     case EOpMax:
    644     case EOpMix:
    645     case EOpClamp:
    646     case EOpLessThan:
    647     case EOpGreaterThan:
    648     case EOpLessThanEqual:
    649     case EOpGreaterThanEqual:
    650     case EOpVectorEqual:
    651     case EOpVectorNotEqual:
    652         componentwise = true;
    653         objectSize = children[0]->getAsConstantUnion()->getType().computeNumComponents();
    654         break;
    655     case EOpCross:
    656     case EOpReflect:
    657     case EOpRefract:
    658     case EOpFaceForward:
    659         objectSize = children[0]->getAsConstantUnion()->getType().computeNumComponents();
    660         break;
    661     case EOpDistance:
    662     case EOpDot:
    663         objectSize = 1;
    664         break;
    665     case EOpOuterProduct:
    666         objectSize = children[0]->getAsTyped()->getType().getVectorSize() *
    667                      children[1]->getAsTyped()->getType().getVectorSize();
    668         break;
    669     case EOpStep:
    670         componentwise = true;
    671         objectSize = std::max(children[0]->getAsTyped()->getType().getVectorSize(),
    672                               children[1]->getAsTyped()->getType().getVectorSize());
    673         break;
    674     case EOpSmoothStep:
    675         componentwise = true;
    676         objectSize = std::max(children[0]->getAsTyped()->getType().getVectorSize(),
    677                               children[2]->getAsTyped()->getType().getVectorSize());
    678         break;
    679     default:
    680         return aggrNode;
    681     }
    682     TConstUnionArray newConstArray(objectSize);
    683 
    684     TVector<TConstUnionArray> childConstUnions;
    685     for (unsigned int arg = 0; arg < children.size(); ++arg)
    686         childConstUnions.push_back(children[arg]->getAsConstantUnion()->getConstArray());
    687 
    688     // Second, do the actual folding
    689 
    690     bool isFloatingPoint = children[0]->getAsTyped()->getBasicType() == EbtFloat ||
    691                            children[0]->getAsTyped()->getBasicType() == EbtDouble;
    692     bool isSigned = children[0]->getAsTyped()->getBasicType() == EbtInt ||
    693                     children[0]->getAsTyped()->getBasicType() == EbtInt64;
    694     bool isInt64 = children[0]->getAsTyped()->getBasicType() == EbtInt64 ||
    695                    children[0]->getAsTyped()->getBasicType() == EbtUint64;
    696     if (componentwise) {
    697         for (int comp = 0; comp < objectSize; comp++) {
    698 
    699             // some arguments are scalars instead of matching vectors; simulate a smear
    700             int arg0comp = std::min(comp, children[0]->getAsTyped()->getType().getVectorSize() - 1);
    701             int arg1comp = 0;
    702             if (children.size() > 1)
    703                 arg1comp = std::min(comp, children[1]->getAsTyped()->getType().getVectorSize() - 1);
    704             int arg2comp = 0;
    705             if (children.size() > 2)
    706                 arg2comp = std::min(comp, children[2]->getAsTyped()->getType().getVectorSize() - 1);
    707 
    708             switch (aggrNode->getOp()) {
    709             case EOpAtan:
    710                 newConstArray[comp].setDConst(atan2(childConstUnions[0][arg0comp].getDConst(), childConstUnions[1][arg1comp].getDConst()));
    711                 break;
    712             case EOpPow:
    713                 newConstArray[comp].setDConst(pow(childConstUnions[0][arg0comp].getDConst(), childConstUnions[1][arg1comp].getDConst()));
    714                 break;
    715             case EOpMin:
    716                 if (isFloatingPoint)
    717                     newConstArray[comp].setDConst(std::min(childConstUnions[0][arg0comp].getDConst(), childConstUnions[1][arg1comp].getDConst()));
    718                 else if (isSigned) {
    719                     if (isInt64)
    720                         newConstArray[comp].setI64Const(std::min(childConstUnions[0][arg0comp].getI64Const(), childConstUnions[1][arg1comp].getI64Const()));
    721                     else
    722                         newConstArray[comp].setIConst(std::min(childConstUnions[0][arg0comp].getIConst(), childConstUnions[1][arg1comp].getIConst()));
    723                 } else {
    724                     if (isInt64)
    725                         newConstArray[comp].setU64Const(std::min(childConstUnions[0][arg0comp].getU64Const(), childConstUnions[1][arg1comp].getU64Const()));
    726                     else
    727                         newConstArray[comp].setUConst(std::min(childConstUnions[0][arg0comp].getUConst(), childConstUnions[1][arg1comp].getUConst()));
    728                 }
    729                 break;
    730             case EOpMax:
    731                 if (isFloatingPoint)
    732                     newConstArray[comp].setDConst(std::max(childConstUnions[0][arg0comp].getDConst(), childConstUnions[1][arg1comp].getDConst()));
    733                 else if (isSigned) {
    734                     if (isInt64)
    735                         newConstArray[comp].setI64Const(std::max(childConstUnions[0][arg0comp].getI64Const(), childConstUnions[1][arg1comp].getI64Const()));
    736                     else
    737                         newConstArray[comp].setIConst(std::max(childConstUnions[0][arg0comp].getIConst(), childConstUnions[1][arg1comp].getIConst()));
    738                 } else {
    739                     if (isInt64)
    740                         newConstArray[comp].setU64Const(std::max(childConstUnions[0][arg0comp].getU64Const(), childConstUnions[1][arg1comp].getU64Const()));
    741                     else
    742                         newConstArray[comp].setUConst(std::max(childConstUnions[0][arg0comp].getUConst(), childConstUnions[1][arg1comp].getUConst()));
    743                 }
    744                 break;
    745             case EOpClamp:
    746                 if (isFloatingPoint)
    747                     newConstArray[comp].setDConst(std::min(std::max(childConstUnions[0][arg0comp].getDConst(), childConstUnions[1][arg1comp].getDConst()),
    748                                                                                                                childConstUnions[2][arg2comp].getDConst()));
    749                 else if (isSigned) {
    750                     if (isInt64)
    751                         newConstArray[comp].setI64Const(std::min(std::max(childConstUnions[0][arg0comp].getI64Const(), childConstUnions[1][arg1comp].getI64Const()),
    752                                                                                                                        childConstUnions[2][arg2comp].getI64Const()));
    753                     else
    754                         newConstArray[comp].setIConst(std::min(std::max(childConstUnions[0][arg0comp].getIConst(), childConstUnions[1][arg1comp].getIConst()),
    755                                                                                                                    childConstUnions[2][arg2comp].getIConst()));
    756                 } else {
    757                     if (isInt64)
    758                         newConstArray[comp].setU64Const(std::min(std::max(childConstUnions[0][arg0comp].getU64Const(), childConstUnions[1][arg1comp].getU64Const()),
    759                                                                                                                        childConstUnions[2][arg2comp].getU64Const()));
    760                     else
    761                         newConstArray[comp].setUConst(std::min(std::max(childConstUnions[0][arg0comp].getUConst(), childConstUnions[1][arg1comp].getUConst()),
    762                                                                                                                    childConstUnions[2][arg2comp].getUConst()));
    763                 }
    764                 break;
    765             case EOpLessThan:
    766                 newConstArray[comp].setBConst(childConstUnions[0][arg0comp] < childConstUnions[1][arg1comp]);
    767                 break;
    768             case EOpGreaterThan:
    769                 newConstArray[comp].setBConst(childConstUnions[0][arg0comp] > childConstUnions[1][arg1comp]);
    770                 break;
    771             case EOpLessThanEqual:
    772                 newConstArray[comp].setBConst(! (childConstUnions[0][arg0comp] > childConstUnions[1][arg1comp]));
    773                 break;
    774             case EOpGreaterThanEqual:
    775                 newConstArray[comp].setBConst(! (childConstUnions[0][arg0comp] < childConstUnions[1][arg1comp]));
    776                 break;
    777             case EOpVectorEqual:
    778                 newConstArray[comp].setBConst(childConstUnions[0][arg0comp] == childConstUnions[1][arg1comp]);
    779                 break;
    780             case EOpVectorNotEqual:
    781                 newConstArray[comp].setBConst(childConstUnions[0][arg0comp] != childConstUnions[1][arg1comp]);
    782                 break;
    783             case EOpMix:
    784                 if (children[2]->getAsTyped()->getBasicType() == EbtBool)
    785                     newConstArray[comp].setDConst(childConstUnions[2][arg2comp].getBConst() ? childConstUnions[1][arg1comp].getDConst() :
    786                                                                                               childConstUnions[0][arg0comp].getDConst());
    787                 else
    788                     newConstArray[comp].setDConst(childConstUnions[0][arg0comp].getDConst() * (1.0 - childConstUnions[2][arg2comp].getDConst()) +
    789                                                   childConstUnions[1][arg1comp].getDConst() *        childConstUnions[2][arg2comp].getDConst());
    790                 break;
    791             case EOpStep:
    792                 newConstArray[comp].setDConst(childConstUnions[1][arg1comp].getDConst() < childConstUnions[0][arg0comp].getDConst() ? 0.0 : 1.0);
    793                 break;
    794             case EOpSmoothStep:
    795             {
    796                 double t = (childConstUnions[2][arg2comp].getDConst() - childConstUnions[0][arg0comp].getDConst()) /
    797                            (childConstUnions[1][arg1comp].getDConst() - childConstUnions[0][arg0comp].getDConst());
    798                 if (t < 0.0)
    799                     t = 0.0;
    800                 if (t > 1.0)
    801                     t = 1.0;
    802                 newConstArray[comp].setDConst(t * t * (3.0 - 2.0 * t));
    803                 break;
    804             }
    805             default:
    806                 return aggrNode;
    807             }
    808         }
    809     } else {
    810         // Non-componentwise...
    811 
    812         int numComps = children[0]->getAsConstantUnion()->getType().computeNumComponents();
    813         double dot;
    814 
    815         switch (aggrNode->getOp()) {
    816         case EOpDistance:
    817         {
    818             double sum = 0.0;
    819             for (int comp = 0; comp < numComps; ++comp) {
    820                 double diff = childConstUnions[1][comp].getDConst() - childConstUnions[0][comp].getDConst();
    821                 sum += diff * diff;
    822             }
    823             newConstArray[0].setDConst(sqrt(sum));
    824             break;
    825         }
    826         case EOpDot:
    827             newConstArray[0].setDConst(childConstUnions[0].dot(childConstUnions[1]));
    828             break;
    829         case EOpCross:
    830             newConstArray[0] = childConstUnions[0][1] * childConstUnions[1][2] - childConstUnions[0][2] * childConstUnions[1][1];
    831             newConstArray[1] = childConstUnions[0][2] * childConstUnions[1][0] - childConstUnions[0][0] * childConstUnions[1][2];
    832             newConstArray[2] = childConstUnions[0][0] * childConstUnions[1][1] - childConstUnions[0][1] * childConstUnions[1][0];
    833             break;
    834         case EOpFaceForward:
    835             // If dot(Nref, I) < 0 return N, otherwise return -N:  Arguments are (N, I, Nref).
    836             dot = childConstUnions[1].dot(childConstUnions[2]);
    837             for (int comp = 0; comp < numComps; ++comp) {
    838                 if (dot < 0.0)
    839                     newConstArray[comp] = childConstUnions[0][comp];
    840                 else
    841                     newConstArray[comp].setDConst(-childConstUnions[0][comp].getDConst());
    842             }
    843             break;
    844         case EOpReflect:
    845             // I - 2 * dot(N, I) * N:  Arguments are (I, N).
    846             dot = childConstUnions[0].dot(childConstUnions[1]);
    847             dot *= 2.0;
    848             for (int comp = 0; comp < numComps; ++comp)
    849                 newConstArray[comp].setDConst(childConstUnions[0][comp].getDConst() - dot * childConstUnions[1][comp].getDConst());
    850             break;
    851         case EOpRefract:
    852         {
    853             // Arguments are (I, N, eta).
    854             // k = 1.0 - eta * eta * (1.0 - dot(N, I) * dot(N, I))
    855             // if (k < 0.0)
    856             //     return dvec(0.0)
    857             // else
    858             //     return eta * I - (eta * dot(N, I) + sqrt(k)) * N
    859             dot = childConstUnions[0].dot(childConstUnions[1]);
    860             double eta = childConstUnions[2][0].getDConst();
    861             double k = 1.0 - eta * eta * (1.0 - dot * dot);
    862             if (k < 0.0) {
    863                 for (int comp = 0; comp < numComps; ++comp)
    864                     newConstArray[comp].setDConst(0.0);
    865             } else {
    866                 for (int comp = 0; comp < numComps; ++comp)
    867                     newConstArray[comp].setDConst(eta * childConstUnions[0][comp].getDConst() - (eta * dot + sqrt(k)) * childConstUnions[1][comp].getDConst());
    868             }
    869             break;
    870         }
    871         case EOpOuterProduct:
    872         {
    873             int numRows = numComps;
    874             int numCols = children[1]->getAsConstantUnion()->getType().computeNumComponents();
    875             for (int row = 0; row < numRows; ++row)
    876                 for (int col = 0; col < numCols; ++col)
    877                     newConstArray[col * numRows + row] = childConstUnions[0][row] * childConstUnions[1][col];
    878             break;
    879         }
    880         default:
    881             return aggrNode;
    882         }
    883     }
    884 
    885     TIntermConstantUnion *newNode = new TIntermConstantUnion(newConstArray, aggrNode->getType());
    886     newNode->getWritableType().getQualifier().storage = EvqConst;
    887     newNode->setLoc(aggrNode->getLoc());
    888 
    889     return newNode;
    890 }
    891 
    892 bool TIntermediate::areAllChildConst(TIntermAggregate* aggrNode)
    893 {
    894     bool allConstant = true;
    895 
    896     // check if all the child nodes are constants so that they can be inserted into
    897     // the parent node
    898     if (aggrNode) {
    899         TIntermSequence& childSequenceVector = aggrNode->getSequence();
    900         for (TIntermSequence::iterator p  = childSequenceVector.begin();
    901                                        p != childSequenceVector.end(); p++) {
    902             if (!(*p)->getAsTyped()->getAsConstantUnion())
    903                 return false;
    904         }
    905     }
    906 
    907     return allConstant;
    908 }
    909 
    910 TIntermTyped* TIntermediate::foldConstructor(TIntermAggregate* aggrNode)
    911 {
    912     bool error = false;
    913 
    914     TConstUnionArray unionArray(aggrNode->getType().computeNumComponents());
    915     if (aggrNode->getSequence().size() == 1)
    916         error = parseConstTree(aggrNode, unionArray, aggrNode->getOp(), aggrNode->getType(), true);
    917     else
    918         error = parseConstTree(aggrNode, unionArray, aggrNode->getOp(), aggrNode->getType());
    919 
    920     if (error)
    921         return aggrNode;
    922 
    923     return addConstantUnion(unionArray, aggrNode->getType(), aggrNode->getLoc());
    924 }
    925 
    926 //
    927 // Constant folding of a bracket (array-style) dereference or struct-like dot
    928 // dereference.  Can handle anything except a multi-character swizzle, though
    929 // all swizzles may go to foldSwizzle().
    930 //
    931 TIntermTyped* TIntermediate::foldDereference(TIntermTyped* node, int index, const TSourceLoc& loc)
    932 {
    933     TType dereferencedType(node->getType(), index);
    934     dereferencedType.getQualifier().storage = EvqConst;
    935     TIntermTyped* result = 0;
    936     int size = dereferencedType.computeNumComponents();
    937 
    938     // arrays, vectors, matrices, all use simple multiplicative math
    939     // while structures need to add up heterogeneous members
    940     int start;
    941     if (node->isArray() || ! node->isStruct())
    942         start = size * index;
    943     else {
    944         // it is a structure
    945         assert(node->isStruct());
    946         start = 0;
    947         for (int i = 0; i < index; ++i)
    948             start += (*node->getType().getStruct())[i].type->computeNumComponents();
    949     }
    950 
    951     result = addConstantUnion(TConstUnionArray(node->getAsConstantUnion()->getConstArray(), start, size), node->getType(), loc);
    952 
    953     if (result == 0)
    954         result = node;
    955     else
    956         result->setType(dereferencedType);
    957 
    958     return result;
    959 }
    960 
    961 //
    962 // Make a constant vector node or constant scalar node, representing a given
    963 // constant vector and constant swizzle into it.
    964 //
    965 TIntermTyped* TIntermediate::foldSwizzle(TIntermTyped* node, TVectorFields& fields, const TSourceLoc& loc)
    966 {
    967     const TConstUnionArray& unionArray = node->getAsConstantUnion()->getConstArray();
    968     TConstUnionArray constArray(fields.num);
    969 
    970     for (int i = 0; i < fields.num; i++)
    971         constArray[i] = unionArray[fields.offsets[i]];
    972 
    973     TIntermTyped* result = addConstantUnion(constArray, node->getType(), loc);
    974 
    975     if (result == 0)
    976         result = node;
    977     else
    978         result->setType(TType(node->getBasicType(), EvqConst, fields.num));
    979 
    980     return result;
    981 }
    982 
    983 } // end namespace glslang
    984