Home | History | Annotate | Download | only in compiler
      1 //
      2 // Copyright (c) 2002-2013 The ANGLE Project Authors. All rights reserved.
      3 // Use of this source code is governed by a BSD-style license that can be
      4 // found in the LICENSE file.
      5 //
      6 // UnfoldShortCircuit is an AST traverser to output short-circuiting operators as if-else statements.
      7 // The results are assigned to s# temporaries, which are used by the main translator instead of
      8 // the original expression.
      9 //
     10 
     11 #include "compiler/UnfoldShortCircuit.h"
     12 
     13 #include "compiler/InfoSink.h"
     14 #include "compiler/OutputHLSL.h"
     15 
     16 namespace sh
     17 {
     18 UnfoldShortCircuit::UnfoldShortCircuit(TParseContext &context, OutputHLSL *outputHLSL) : mContext(context), mOutputHLSL(outputHLSL)
     19 {
     20     mTemporaryIndex = 0;
     21 }
     22 
     23 void UnfoldShortCircuit::traverse(TIntermNode *node)
     24 {
     25     int rewindIndex = mTemporaryIndex;
     26     node->traverse(this);
     27     mTemporaryIndex = rewindIndex;
     28 }
     29 
     30 bool UnfoldShortCircuit::visitBinary(Visit visit, TIntermBinary *node)
     31 {
     32     TInfoSinkBase &out = mOutputHLSL->getBodyStream();
     33 
     34     switch (node->getOp())
     35     {
     36       case EOpLogicalOr:
     37         // "x || y" is equivalent to "x ? true : y", which unfolds to "bool s; if(x) s = true; else s = y;",
     38         // and then further simplifies down to "bool s = x; if(!s) s = y;".
     39         {
     40             int i = mTemporaryIndex;
     41 
     42             out << "bool s" << i << ";\n";
     43 
     44             out << "{\n";
     45 
     46             mTemporaryIndex = i + 1;
     47             node->getLeft()->traverse(this);
     48             out << "s" << i << " = ";
     49             mTemporaryIndex = i + 1;
     50             node->getLeft()->traverse(mOutputHLSL);
     51             out << ";\n";
     52             out << "if(!s" << i << ")\n"
     53                    "{\n";
     54             mTemporaryIndex = i + 1;
     55             node->getRight()->traverse(this);
     56             out << "    s" << i << " = ";
     57             mTemporaryIndex = i + 1;
     58             node->getRight()->traverse(mOutputHLSL);
     59             out << ";\n"
     60                    "}\n";
     61 
     62             out << "}\n";
     63 
     64             mTemporaryIndex = i + 1;
     65         }
     66         return false;
     67       case EOpLogicalAnd:
     68         // "x && y" is equivalent to "x ? y : false", which unfolds to "bool s; if(x) s = y; else s = false;",
     69         // and then further simplifies down to "bool s = x; if(s) s = y;".
     70         {
     71             int i = mTemporaryIndex;
     72 
     73             out << "bool s" << i << ";\n";
     74 
     75             out << "{\n";
     76 
     77             mTemporaryIndex = i + 1;
     78             node->getLeft()->traverse(this);
     79             out << "s" << i << " = ";
     80             mTemporaryIndex = i + 1;
     81             node->getLeft()->traverse(mOutputHLSL);
     82             out << ";\n";
     83             out << "if(s" << i << ")\n"
     84                    "{\n";
     85             mTemporaryIndex = i + 1;
     86             node->getRight()->traverse(this);
     87             out << "    s" << i << " = ";
     88             mTemporaryIndex = i + 1;
     89             node->getRight()->traverse(mOutputHLSL);
     90             out << ";\n"
     91                    "}\n";
     92 
     93             out << "}\n";
     94 
     95             mTemporaryIndex = i + 1;
     96         }
     97         return false;
     98       default:
     99         return true;
    100     }
    101 }
    102 
    103 bool UnfoldShortCircuit::visitSelection(Visit visit, TIntermSelection *node)
    104 {
    105     TInfoSinkBase &out = mOutputHLSL->getBodyStream();
    106 
    107     // Unfold "b ? x : y" into "type s; if(b) s = x; else s = y;"
    108     if (node->usesTernaryOperator())
    109     {
    110         int i = mTemporaryIndex;
    111 
    112         out << mOutputHLSL->typeString(node->getType()) << " s" << i << ";\n";
    113 
    114         out << "{\n";
    115 
    116         mTemporaryIndex = i + 1;
    117         node->getCondition()->traverse(this);
    118         out << "if(";
    119         mTemporaryIndex = i + 1;
    120         node->getCondition()->traverse(mOutputHLSL);
    121         out << ")\n"
    122                "{\n";
    123         mTemporaryIndex = i + 1;
    124         node->getTrueBlock()->traverse(this);
    125         out << "    s" << i << " = ";
    126         mTemporaryIndex = i + 1;
    127         node->getTrueBlock()->traverse(mOutputHLSL);
    128         out << ";\n"
    129                "}\n"
    130                "else\n"
    131                "{\n";
    132         mTemporaryIndex = i + 1;
    133         node->getFalseBlock()->traverse(this);
    134         out << "    s" << i << " = ";
    135         mTemporaryIndex = i + 1;
    136         node->getFalseBlock()->traverse(mOutputHLSL);
    137         out << ";\n"
    138                "}\n";
    139 
    140         out << "}\n";
    141 
    142         mTemporaryIndex = i + 1;
    143     }
    144 
    145     return false;
    146 }
    147 
    148 bool UnfoldShortCircuit::visitLoop(Visit visit, TIntermLoop *node)
    149 {
    150     int rewindIndex = mTemporaryIndex;
    151 
    152     if (node->getInit())
    153     {
    154         node->getInit()->traverse(this);
    155     }
    156 
    157     if (node->getCondition())
    158     {
    159         node->getCondition()->traverse(this);
    160     }
    161 
    162     if (node->getExpression())
    163     {
    164         node->getExpression()->traverse(this);
    165     }
    166 
    167     mTemporaryIndex = rewindIndex;
    168 
    169     return false;
    170 }
    171 
    172 int UnfoldShortCircuit::getNextTemporaryIndex()
    173 {
    174     return mTemporaryIndex++;
    175 }
    176 }
    177