Home | History | Annotate | Download | only in compiler
      1 //
      2 // Copyright (c) 2002-2010 The ANGLE Project Authors. All rights reserved.
      3 // Use of this source code is governed by a BSD-style license that can be
      4 // found in the LICENSE file.
      5 //
      6 
      7 #include "compiler/VariableInfo.h"
      8 
      9 namespace {
     10 
     11 TString arrayBrackets(int index)
     12 {
     13     TStringStream stream;
     14     stream << "[" << index << "]";
     15     return stream.str();
     16 }
     17 
     18 // Returns the data type for an attribute, uniform, or varying.
     19 ShDataType getVariableDataType(const TType& type)
     20 {
     21     switch (type.getBasicType()) {
     22       case EbtFloat:
     23           if (type.isMatrix()) {
     24               switch (type.getNominalSize()) {
     25                 case 2: return SH_FLOAT_MAT2;
     26                 case 3: return SH_FLOAT_MAT3;
     27                 case 4: return SH_FLOAT_MAT4;
     28                 default: UNREACHABLE();
     29               }
     30           } else if (type.isVector()) {
     31               switch (type.getNominalSize()) {
     32                 case 2: return SH_FLOAT_VEC2;
     33                 case 3: return SH_FLOAT_VEC3;
     34                 case 4: return SH_FLOAT_VEC4;
     35                 default: UNREACHABLE();
     36               }
     37           } else {
     38               return SH_FLOAT;
     39           }
     40       case EbtInt:
     41           if (type.isMatrix()) {
     42               UNREACHABLE();
     43           } else if (type.isVector()) {
     44               switch (type.getNominalSize()) {
     45                 case 2: return SH_INT_VEC2;
     46                 case 3: return SH_INT_VEC3;
     47                 case 4: return SH_INT_VEC4;
     48                 default: UNREACHABLE();
     49               }
     50           } else {
     51               return SH_INT;
     52           }
     53       case EbtBool:
     54           if (type.isMatrix()) {
     55               UNREACHABLE();
     56           } else if (type.isVector()) {
     57               switch (type.getNominalSize()) {
     58                 case 2: return SH_BOOL_VEC2;
     59                 case 3: return SH_BOOL_VEC3;
     60                 case 4: return SH_BOOL_VEC4;
     61                 default: UNREACHABLE();
     62               }
     63           } else {
     64               return SH_BOOL;
     65           }
     66       case EbtSampler2D: return SH_SAMPLER_2D;
     67       case EbtSamplerCube: return SH_SAMPLER_CUBE;
     68       case EbtSamplerExternalOES: return SH_SAMPLER_EXTERNAL_OES;
     69       case EbtSampler2DRect: return SH_SAMPLER_2D_RECT_ARB;
     70       default: UNREACHABLE();
     71     }
     72     return SH_NONE;
     73 }
     74 
     75 void getBuiltInVariableInfo(const TType& type,
     76                             const TString& name,
     77                             const TString& mappedName,
     78                             TVariableInfoList& infoList);
     79 void getUserDefinedVariableInfo(const TType& type,
     80                                 const TString& name,
     81                                 const TString& mappedName,
     82                                 TVariableInfoList& infoList,
     83                                 ShHashFunction64 hashFunction);
     84 
     85 // Returns info for an attribute, uniform, or varying.
     86 void getVariableInfo(const TType& type,
     87                      const TString& name,
     88                      const TString& mappedName,
     89                      TVariableInfoList& infoList,
     90                      ShHashFunction64 hashFunction)
     91 {
     92     if (type.getBasicType() == EbtStruct) {
     93         if (type.isArray()) {
     94             for (int i = 0; i < type.getArraySize(); ++i) {
     95                 TString lname = name + arrayBrackets(i);
     96                 TString lmappedName = mappedName + arrayBrackets(i);
     97                 getUserDefinedVariableInfo(type, lname, lmappedName, infoList, hashFunction);
     98             }
     99         } else {
    100             getUserDefinedVariableInfo(type, name, mappedName, infoList, hashFunction);
    101         }
    102     } else {
    103         getBuiltInVariableInfo(type, name, mappedName, infoList);
    104     }
    105 }
    106 
    107 void getBuiltInVariableInfo(const TType& type,
    108                             const TString& name,
    109                             const TString& mappedName,
    110                             TVariableInfoList& infoList)
    111 {
    112     ASSERT(type.getBasicType() != EbtStruct);
    113 
    114     TVariableInfo varInfo;
    115     if (type.isArray()) {
    116         varInfo.name = (name + "[0]").c_str();
    117         varInfo.mappedName = (mappedName + "[0]").c_str();
    118         varInfo.size = type.getArraySize();
    119     } else {
    120         varInfo.name = name.c_str();
    121         varInfo.mappedName = mappedName.c_str();
    122         varInfo.size = 1;
    123     }
    124     varInfo.precision = type.getPrecision();
    125     varInfo.type = getVariableDataType(type);
    126     infoList.push_back(varInfo);
    127 }
    128 
    129 void getUserDefinedVariableInfo(const TType& type,
    130                                 const TString& name,
    131                                 const TString& mappedName,
    132                                 TVariableInfoList& infoList,
    133                                 ShHashFunction64 hashFunction)
    134 {
    135     ASSERT(type.getBasicType() == EbtStruct);
    136 
    137     const TFieldList& fields = type.getStruct()->fields();
    138     for (size_t i = 0; i < fields.size(); ++i) {
    139         const TType& fieldType = *(fields[i]->type());
    140         const TString& fieldName = fields[i]->name();
    141         getVariableInfo(fieldType,
    142                         name + "." + fieldName,
    143                         mappedName + "." + TIntermTraverser::hash(fieldName, hashFunction),
    144                         infoList,
    145                         hashFunction);
    146     }
    147 }
    148 
    149 TVariableInfo* findVariable(const TType& type,
    150                             const TString& name,
    151                             TVariableInfoList& infoList)
    152 {
    153     // TODO(zmo): optimize this function.
    154     TString myName = name;
    155     if (type.isArray())
    156         myName += "[0]";
    157     for (size_t ii = 0; ii < infoList.size(); ++ii)
    158     {
    159         if (infoList[ii].name.c_str() == myName)
    160             return &(infoList[ii]);
    161     }
    162     return NULL;
    163 }
    164 
    165 }  // namespace anonymous
    166 
    167 TVariableInfo::TVariableInfo()
    168     : type(SH_NONE),
    169       size(0),
    170       precision(EbpUndefined),
    171       staticUse(false)
    172 {
    173 }
    174 
    175 TVariableInfo::TVariableInfo(ShDataType type, int size)
    176     : type(type),
    177       size(size),
    178       precision(EbpUndefined),
    179       staticUse(false)
    180 {
    181 }
    182 
    183 CollectVariables::CollectVariables(TVariableInfoList& attribs,
    184                                    TVariableInfoList& uniforms,
    185                                    TVariableInfoList& varyings,
    186                                    ShHashFunction64 hashFunction)
    187     : mAttribs(attribs),
    188       mUniforms(uniforms),
    189       mVaryings(varyings),
    190       mPointCoordAdded(false),
    191       mFrontFacingAdded(false),
    192       mFragCoordAdded(false),
    193       mHashFunction(hashFunction)
    194 {
    195 }
    196 
    197 // We want to check whether a uniform/varying is statically used
    198 // because we only count the used ones in packing computing.
    199 // Also, gl_FragCoord, gl_PointCoord, and gl_FrontFacing count
    200 // toward varying counting if they are statically used in a fragment
    201 // shader.
    202 void CollectVariables::visitSymbol(TIntermSymbol* symbol)
    203 {
    204     ASSERT(symbol != NULL);
    205     TVariableInfo* var = NULL;
    206     switch (symbol->getQualifier())
    207     {
    208     case EvqVaryingOut:
    209     case EvqInvariantVaryingOut:
    210     case EvqVaryingIn:
    211     case EvqInvariantVaryingIn:
    212         var = findVariable(symbol->getType(), symbol->getSymbol(), mVaryings);
    213         break;
    214     case EvqUniform:
    215         var = findVariable(symbol->getType(), symbol->getSymbol(), mUniforms);
    216         break;
    217     case EvqFragCoord:
    218         if (!mFragCoordAdded) {
    219             TVariableInfo info;
    220             info.name = "gl_FragCoord";
    221             info.mappedName = "gl_FragCoord";
    222             info.type = SH_FLOAT_VEC4;
    223             info.size = 1;
    224             info.precision = EbpMedium;  // Use mediump as it doesn't really matter.
    225             info.staticUse = true;
    226 	    mVaryings.push_back(info);
    227             mFragCoordAdded = true;
    228         }
    229         return;
    230     case EvqFrontFacing:
    231         if (!mFrontFacingAdded) {
    232             TVariableInfo info;
    233             info.name = "gl_FrontFacing";
    234             info.mappedName = "gl_FrontFacing";
    235             info.type = SH_BOOL;
    236             info.size = 1;
    237             info.precision = EbpUndefined;
    238             info.staticUse = true;
    239 	    mVaryings.push_back(info);
    240             mFrontFacingAdded = true;
    241         }
    242         return;
    243     case EvqPointCoord:
    244         if (!mPointCoordAdded) {
    245             TVariableInfo info;
    246             info.name = "gl_PointCoord";
    247             info.mappedName = "gl_PointCoord";
    248             info.type = SH_FLOAT_VEC2;
    249             info.size = 1;
    250             info.precision = EbpMedium;  // Use mediump as it doesn't really matter.
    251             info.staticUse = true;
    252 	    mVaryings.push_back(info);
    253             mPointCoordAdded = true;
    254         }
    255         return;
    256     default:
    257         break;
    258     }
    259     if (var)
    260         var->staticUse = true;
    261 }
    262 
    263 bool CollectVariables::visitAggregate(Visit, TIntermAggregate* node)
    264 {
    265     bool visitChildren = true;
    266 
    267     switch (node->getOp())
    268     {
    269     case EOpDeclaration: {
    270         const TIntermSequence& sequence = node->getSequence();
    271         TQualifier qualifier = sequence.front()->getAsTyped()->getQualifier();
    272         if (qualifier == EvqAttribute || qualifier == EvqUniform ||
    273             qualifier == EvqVaryingIn || qualifier == EvqVaryingOut ||
    274             qualifier == EvqInvariantVaryingIn || qualifier == EvqInvariantVaryingOut)
    275         {
    276             TVariableInfoList& infoList = qualifier == EvqAttribute ? mAttribs :
    277                 (qualifier == EvqUniform ? mUniforms : mVaryings);
    278             for (TIntermSequence::const_iterator i = sequence.begin();
    279                  i != sequence.end(); ++i)
    280             {
    281                 const TIntermSymbol* variable = (*i)->getAsSymbolNode();
    282                 // The only case in which the sequence will not contain a
    283                 // TIntermSymbol node is initialization. It will contain a
    284                 // TInterBinary node in that case. Since attributes, uniforms,
    285                 // and varyings cannot be initialized in a shader, we must have
    286                 // only TIntermSymbol nodes in the sequence.
    287                 ASSERT(variable != NULL);
    288                 TString processedSymbol;
    289                 if (mHashFunction == NULL)
    290                     processedSymbol = variable->getSymbol();
    291                 else
    292                     processedSymbol = TIntermTraverser::hash(variable->getOriginalSymbol(), mHashFunction);
    293                 getVariableInfo(variable->getType(),
    294                                 variable->getOriginalSymbol(),
    295                                 processedSymbol,
    296                                 infoList,
    297                                 mHashFunction);
    298                 visitChildren = false;
    299             }
    300         }
    301         break;
    302     }
    303     default: break;
    304     }
    305 
    306     return visitChildren;
    307 }
    308 
    309