1 // 2 // Copyright (c) 2002-2013 The ANGLE Project Authors. All rights reserved. 3 // Use of this source code is governed by a BSD-style license that can be 4 // found in the LICENSE file. 5 // 6 // UnfoldShortCircuit is an AST traverser to output short-circuiting operators as if-else statements. 7 // The results are assigned to s# temporaries, which are used by the main translator instead of 8 // the original expression. 9 // 10 11 #include "compiler/UnfoldShortCircuit.h" 12 13 #include "compiler/InfoSink.h" 14 #include "compiler/OutputHLSL.h" 15 16 namespace sh 17 { 18 UnfoldShortCircuit::UnfoldShortCircuit(TParseContext &context, OutputHLSL *outputHLSL) : mContext(context), mOutputHLSL(outputHLSL) 19 { 20 mTemporaryIndex = 0; 21 } 22 23 void UnfoldShortCircuit::traverse(TIntermNode *node) 24 { 25 int rewindIndex = mTemporaryIndex; 26 node->traverse(this); 27 mTemporaryIndex = rewindIndex; 28 } 29 30 bool UnfoldShortCircuit::visitBinary(Visit visit, TIntermBinary *node) 31 { 32 TInfoSinkBase &out = mOutputHLSL->getBodyStream(); 33 34 switch (node->getOp()) 35 { 36 case EOpLogicalOr: 37 // "x || y" is equivalent to "x ? true : y", which unfolds to "bool s; if(x) s = true; else s = y;", 38 // and then further simplifies down to "bool s = x; if(!s) s = y;". 39 { 40 int i = mTemporaryIndex; 41 42 out << "bool s" << i << ";\n"; 43 44 out << "{\n"; 45 46 mTemporaryIndex = i + 1; 47 node->getLeft()->traverse(this); 48 out << "s" << i << " = "; 49 mTemporaryIndex = i + 1; 50 node->getLeft()->traverse(mOutputHLSL); 51 out << ";\n"; 52 out << "if(!s" << i << ")\n" 53 "{\n"; 54 mTemporaryIndex = i + 1; 55 node->getRight()->traverse(this); 56 out << " s" << i << " = "; 57 mTemporaryIndex = i + 1; 58 node->getRight()->traverse(mOutputHLSL); 59 out << ";\n" 60 "}\n"; 61 62 out << "}\n"; 63 64 mTemporaryIndex = i + 1; 65 } 66 return false; 67 case EOpLogicalAnd: 68 // "x && y" is equivalent to "x ? y : false", which unfolds to "bool s; if(x) s = y; else s = false;", 69 // and then further simplifies down to "bool s = x; if(s) s = y;". 70 { 71 int i = mTemporaryIndex; 72 73 out << "bool s" << i << ";\n"; 74 75 out << "{\n"; 76 77 mTemporaryIndex = i + 1; 78 node->getLeft()->traverse(this); 79 out << "s" << i << " = "; 80 mTemporaryIndex = i + 1; 81 node->getLeft()->traverse(mOutputHLSL); 82 out << ";\n"; 83 out << "if(s" << i << ")\n" 84 "{\n"; 85 mTemporaryIndex = i + 1; 86 node->getRight()->traverse(this); 87 out << " s" << i << " = "; 88 mTemporaryIndex = i + 1; 89 node->getRight()->traverse(mOutputHLSL); 90 out << ";\n" 91 "}\n"; 92 93 out << "}\n"; 94 95 mTemporaryIndex = i + 1; 96 } 97 return false; 98 default: 99 return true; 100 } 101 } 102 103 bool UnfoldShortCircuit::visitSelection(Visit visit, TIntermSelection *node) 104 { 105 TInfoSinkBase &out = mOutputHLSL->getBodyStream(); 106 107 // Unfold "b ? x : y" into "type s; if(b) s = x; else s = y;" 108 if (node->usesTernaryOperator()) 109 { 110 int i = mTemporaryIndex; 111 112 out << mOutputHLSL->typeString(node->getType()) << " s" << i << ";\n"; 113 114 out << "{\n"; 115 116 mTemporaryIndex = i + 1; 117 node->getCondition()->traverse(this); 118 out << "if("; 119 mTemporaryIndex = i + 1; 120 node->getCondition()->traverse(mOutputHLSL); 121 out << ")\n" 122 "{\n"; 123 mTemporaryIndex = i + 1; 124 node->getTrueBlock()->traverse(this); 125 out << " s" << i << " = "; 126 mTemporaryIndex = i + 1; 127 node->getTrueBlock()->traverse(mOutputHLSL); 128 out << ";\n" 129 "}\n" 130 "else\n" 131 "{\n"; 132 mTemporaryIndex = i + 1; 133 node->getFalseBlock()->traverse(this); 134 out << " s" << i << " = "; 135 mTemporaryIndex = i + 1; 136 node->getFalseBlock()->traverse(mOutputHLSL); 137 out << ";\n" 138 "}\n"; 139 140 out << "}\n"; 141 142 mTemporaryIndex = i + 1; 143 } 144 145 return false; 146 } 147 148 bool UnfoldShortCircuit::visitLoop(Visit visit, TIntermLoop *node) 149 { 150 int rewindIndex = mTemporaryIndex; 151 152 if (node->getInit()) 153 { 154 node->getInit()->traverse(this); 155 } 156 157 if (node->getCondition()) 158 { 159 node->getCondition()->traverse(this); 160 } 161 162 if (node->getExpression()) 163 { 164 node->getExpression()->traverse(this); 165 } 166 167 mTemporaryIndex = rewindIndex; 168 169 return false; 170 } 171 172 int UnfoldShortCircuit::getNextTemporaryIndex() 173 { 174 return mTemporaryIndex++; 175 } 176 } 177