1 // Copyright 2016 The SwiftShader Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #include "AnalyzeCallDepth.h" 16 17 static TIntermSequence::iterator 18 traverseCaseBody(AnalyzeCallDepth* analysis, 19 TIntermSequence::iterator& start, 20 const TIntermSequence::iterator& end) { 21 TIntermSequence::iterator current = start; 22 for (++current; current != end; ++current) 23 { 24 (*current)->traverse(analysis); 25 if((*current)->getAsBranchNode()) // Kill, Break, Continue or Return 26 { 27 break; 28 } 29 } 30 return current; 31 } 32 33 34 AnalyzeCallDepth::FunctionNode::FunctionNode(TIntermAggregate *node) : node(node) 35 { 36 visit = PreVisit; 37 callDepth = 0; 38 } 39 40 const TString &AnalyzeCallDepth::FunctionNode::getName() const 41 { 42 return node->getName(); 43 } 44 45 void AnalyzeCallDepth::FunctionNode::addCallee(AnalyzeCallDepth::FunctionNode *callee) 46 { 47 for(size_t i = 0; i < callees.size(); i++) 48 { 49 if(callees[i] == callee) 50 { 51 return; 52 } 53 } 54 55 callees.push_back(callee); 56 } 57 58 unsigned int AnalyzeCallDepth::FunctionNode::analyzeCallDepth(AnalyzeCallDepth *analyzeCallDepth) 59 { 60 ASSERT(visit == PreVisit); 61 ASSERT(analyzeCallDepth); 62 63 callDepth = 0; 64 visit = InVisit; 65 66 for(size_t i = 0; i < callees.size(); i++) 67 { 68 unsigned int calleeDepth = 0; 69 switch(callees[i]->visit) 70 { 71 case InVisit: 72 // Cycle detected (recursion) 73 return UINT_MAX; 74 case PostVisit: 75 calleeDepth = callees[i]->getLastDepth(); 76 break; 77 case PreVisit: 78 calleeDepth = callees[i]->analyzeCallDepth(analyzeCallDepth); 79 break; 80 default: 81 UNREACHABLE(callees[i]->visit); 82 break; 83 } 84 if(calleeDepth != UINT_MAX) ++calleeDepth; 85 callDepth = std::max(callDepth, calleeDepth); 86 } 87 88 visit = PostVisit; 89 return callDepth; 90 } 91 92 unsigned int AnalyzeCallDepth::FunctionNode::getLastDepth() const 93 { 94 return callDepth; 95 } 96 97 void AnalyzeCallDepth::FunctionNode::removeIfUnreachable() 98 { 99 if(visit == PreVisit) 100 { 101 node->setOp(EOpPrototype); 102 node->getSequence().resize(1); // Remove function body 103 } 104 } 105 106 AnalyzeCallDepth::AnalyzeCallDepth(TIntermNode *root) 107 : TIntermTraverser(true, false, true, false), 108 currentFunction(0) 109 { 110 root->traverse(this); 111 } 112 113 AnalyzeCallDepth::~AnalyzeCallDepth() 114 { 115 for(size_t i = 0; i < functions.size(); i++) 116 { 117 delete functions[i]; 118 } 119 } 120 121 bool AnalyzeCallDepth::visitSwitch(Visit visit, TIntermSwitch *node) 122 { 123 TIntermTyped* switchValue = node->getInit(); 124 TIntermAggregate* opList = node->getStatementList(); 125 126 if(!switchValue || !opList) 127 { 128 return false; 129 } 130 131 // TODO: We need to dig into switch statement cases from 132 // visitSwitch for all traversers. Is there a way to 133 // preserve existing functionality while moving the iteration 134 // to the general traverser? 135 TIntermSequence& sequence = opList->getSequence(); 136 TIntermSequence::iterator it = sequence.begin(); 137 TIntermSequence::iterator defaultIt = sequence.end(); 138 for(; it != sequence.end(); ++it) 139 { 140 TIntermCase* currentCase = (*it)->getAsCaseNode(); 141 if(currentCase) 142 { 143 TIntermSequence::iterator caseIt = it; 144 TIntermTyped* condition = currentCase->getCondition(); 145 if(condition) // non default case 146 { 147 condition->traverse(this); 148 traverseCaseBody(this, caseIt, sequence.end()); 149 } 150 else 151 { 152 defaultIt = it; // The default case might not be the last case, keep it for last 153 } 154 } 155 } 156 157 // If there's a default case, traverse it here 158 if(defaultIt != sequence.end()) 159 { 160 traverseCaseBody(this, defaultIt, sequence.end()); 161 } 162 return false; 163 } 164 165 bool AnalyzeCallDepth::visitAggregate(Visit visit, TIntermAggregate *node) 166 { 167 switch(node->getOp()) 168 { 169 case EOpFunction: // Function definition 170 { 171 if(visit == PreVisit) 172 { 173 currentFunction = findFunctionByName(node->getName()); 174 175 if(!currentFunction) 176 { 177 currentFunction = new FunctionNode(node); 178 functions.push_back(currentFunction); 179 } 180 } 181 else if(visit == PostVisit) 182 { 183 currentFunction = 0; 184 } 185 } 186 break; 187 case EOpFunctionCall: 188 { 189 if(!node->isUserDefined()) 190 { 191 return true; // Check the arguments for function calls 192 } 193 194 if(visit == PreVisit) 195 { 196 FunctionNode *function = findFunctionByName(node->getName()); 197 198 if(!function) 199 { 200 function = new FunctionNode(node); 201 functions.push_back(function); 202 } 203 204 if(currentFunction) 205 { 206 currentFunction->addCallee(function); 207 } 208 else 209 { 210 globalFunctionCalls.insert(function); 211 } 212 } 213 } 214 break; 215 default: 216 break; 217 } 218 219 return true; 220 } 221 222 unsigned int AnalyzeCallDepth::analyzeCallDepth() 223 { 224 FunctionNode *main = findFunctionByName("main("); 225 226 if(!main) 227 { 228 return 0; 229 } 230 231 unsigned int depth = main->analyzeCallDepth(this); 232 if(depth != UINT_MAX) ++depth; 233 234 for(FunctionSet::iterator globalCall = globalFunctionCalls.begin(); globalCall != globalFunctionCalls.end(); globalCall++) 235 { 236 unsigned int globalDepth = (*globalCall)->analyzeCallDepth(this); 237 if(globalDepth != UINT_MAX) ++globalDepth; 238 239 if(globalDepth > depth) 240 { 241 depth = globalDepth; 242 } 243 } 244 245 for(size_t i = 0; i < functions.size(); i++) 246 { 247 functions[i]->removeIfUnreachable(); 248 } 249 250 return depth; 251 } 252 253 AnalyzeCallDepth::FunctionNode *AnalyzeCallDepth::findFunctionByName(const TString &name) 254 { 255 for(size_t i = 0; i < functions.size(); i++) 256 { 257 if(functions[i]->getName() == name) 258 { 259 return functions[i]; 260 } 261 } 262 263 return 0; 264 } 265 266