Home | History | Annotate | Download | only in compiler
      1 //
      2 // Copyright (c) 2002-2013 The ANGLE Project Authors. All rights reserved.
      3 // Use of this source code is governed by a BSD-style license that can be
      4 // found in the LICENSE file.
      5 //
      6 // UnfoldShortCircuit is an AST traverser to output short-circuiting operators as if-else statements.
      7 // The results are assigned to s# temporaries, which are used by the main translator instead of
      8 // the original expression.
      9 //
     10 
     11 #include "compiler/UnfoldShortCircuit.h"
     12 
     13 #include "compiler/InfoSink.h"
     14 #include "compiler/OutputHLSL.h"
     15 
     16 namespace sh
     17 {
     18 UnfoldShortCircuit::UnfoldShortCircuit(TParseContext &context, OutputHLSL *outputHLSL) : mContext(context), mOutputHLSL(outputHLSL)
     19 {
     20     mTemporaryIndex = 0;
     21 }
     22 
     23 void UnfoldShortCircuit::traverse(TIntermNode *node)
     24 {
     25     int rewindIndex = mTemporaryIndex;
     26     node->traverse(this);
     27     mTemporaryIndex = rewindIndex;
     28 }
     29 
     30 bool UnfoldShortCircuit::visitBinary(Visit visit, TIntermBinary *node)
     31 {
     32     TInfoSinkBase &out = mOutputHLSL->getBodyStream();
     33 
     34     // If our right node doesn't have side effects, we know we don't need to unfold this
     35     // expression: there will be no short-circuiting side effects to avoid
     36     // (note: unfolding doesn't depend on the left node -- it will always be evaluated)
     37     if (!node->getRight()->hasSideEffects())
     38     {
     39         return true;
     40     }
     41 
     42     switch (node->getOp())
     43     {
     44       case EOpLogicalOr:
     45         // "x || y" is equivalent to "x ? true : y", which unfolds to "bool s; if(x) s = true; else s = y;",
     46         // and then further simplifies down to "bool s = x; if(!s) s = y;".
     47         {
     48             int i = mTemporaryIndex;
     49 
     50             out << "bool s" << i << ";\n";
     51 
     52             out << "{\n";
     53 
     54             mTemporaryIndex = i + 1;
     55             node->getLeft()->traverse(this);
     56             out << "s" << i << " = ";
     57             mTemporaryIndex = i + 1;
     58             node->getLeft()->traverse(mOutputHLSL);
     59             out << ";\n";
     60             out << "if (!s" << i << ")\n"
     61                    "{\n";
     62             mTemporaryIndex = i + 1;
     63             node->getRight()->traverse(this);
     64             out << "    s" << i << " = ";
     65             mTemporaryIndex = i + 1;
     66             node->getRight()->traverse(mOutputHLSL);
     67             out << ";\n"
     68                    "}\n";
     69 
     70             out << "}\n";
     71 
     72             mTemporaryIndex = i + 1;
     73         }
     74         return false;
     75       case EOpLogicalAnd:
     76         // "x && y" is equivalent to "x ? y : false", which unfolds to "bool s; if(x) s = y; else s = false;",
     77         // and then further simplifies down to "bool s = x; if(s) s = y;".
     78         {
     79             int i = mTemporaryIndex;
     80 
     81             out << "bool s" << i << ";\n";
     82 
     83             out << "{\n";
     84 
     85             mTemporaryIndex = i + 1;
     86             node->getLeft()->traverse(this);
     87             out << "s" << i << " = ";
     88             mTemporaryIndex = i + 1;
     89             node->getLeft()->traverse(mOutputHLSL);
     90             out << ";\n";
     91             out << "if (s" << i << ")\n"
     92                    "{\n";
     93             mTemporaryIndex = i + 1;
     94             node->getRight()->traverse(this);
     95             out << "    s" << i << " = ";
     96             mTemporaryIndex = i + 1;
     97             node->getRight()->traverse(mOutputHLSL);
     98             out << ";\n"
     99                    "}\n";
    100 
    101             out << "}\n";
    102 
    103             mTemporaryIndex = i + 1;
    104         }
    105         return false;
    106       default:
    107         return true;
    108     }
    109 }
    110 
    111 bool UnfoldShortCircuit::visitSelection(Visit visit, TIntermSelection *node)
    112 {
    113     TInfoSinkBase &out = mOutputHLSL->getBodyStream();
    114 
    115     // Unfold "b ? x : y" into "type s; if(b) s = x; else s = y;"
    116     if (node->usesTernaryOperator())
    117     {
    118         int i = mTemporaryIndex;
    119 
    120         out << mOutputHLSL->typeString(node->getType()) << " s" << i << ";\n";
    121 
    122         out << "{\n";
    123 
    124         mTemporaryIndex = i + 1;
    125         node->getCondition()->traverse(this);
    126         out << "if (";
    127         mTemporaryIndex = i + 1;
    128         node->getCondition()->traverse(mOutputHLSL);
    129         out << ")\n"
    130                "{\n";
    131         mTemporaryIndex = i + 1;
    132         node->getTrueBlock()->traverse(this);
    133         out << "    s" << i << " = ";
    134         mTemporaryIndex = i + 1;
    135         node->getTrueBlock()->traverse(mOutputHLSL);
    136         out << ";\n"
    137                "}\n"
    138                "else\n"
    139                "{\n";
    140         mTemporaryIndex = i + 1;
    141         node->getFalseBlock()->traverse(this);
    142         out << "    s" << i << " = ";
    143         mTemporaryIndex = i + 1;
    144         node->getFalseBlock()->traverse(mOutputHLSL);
    145         out << ";\n"
    146                "}\n";
    147 
    148         out << "}\n";
    149 
    150         mTemporaryIndex = i + 1;
    151     }
    152 
    153     return false;
    154 }
    155 
    156 bool UnfoldShortCircuit::visitLoop(Visit visit, TIntermLoop *node)
    157 {
    158     int rewindIndex = mTemporaryIndex;
    159 
    160     if (node->getInit())
    161     {
    162         node->getInit()->traverse(this);
    163     }
    164 
    165     if (node->getCondition())
    166     {
    167         node->getCondition()->traverse(this);
    168     }
    169 
    170     if (node->getExpression())
    171     {
    172         node->getExpression()->traverse(this);
    173     }
    174 
    175     mTemporaryIndex = rewindIndex;
    176 
    177     return false;
    178 }
    179 
    180 int UnfoldShortCircuit::getNextTemporaryIndex()
    181 {
    182     return mTemporaryIndex++;
    183 }
    184 }
    185