Home | History | Annotate | Download | only in compiler
      1 // Copyright 2016 The SwiftShader Authors. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //    http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
     15 #include "AnalyzeCallDepth.h"
     16 
     17 static TIntermSequence::iterator
     18 traverseCaseBody(AnalyzeCallDepth* analysis,
     19 				 TIntermSequence::iterator& start,
     20 				 const TIntermSequence::iterator& end) {
     21 	TIntermSequence::iterator current = start;
     22 	for (++current; current != end; ++current)
     23 	{
     24 		(*current)->traverse(analysis);
     25 		if((*current)->getAsBranchNode()) // Kill, Break, Continue or Return
     26 		{
     27 			break;
     28 		}
     29 	}
     30 	return current;
     31 }
     32 
     33 
     34 AnalyzeCallDepth::FunctionNode::FunctionNode(TIntermAggregate *node) : node(node)
     35 {
     36 	visit = PreVisit;
     37 	callDepth = 0;
     38 }
     39 
     40 const TString &AnalyzeCallDepth::FunctionNode::getName() const
     41 {
     42 	return node->getName();
     43 }
     44 
     45 void AnalyzeCallDepth::FunctionNode::addCallee(AnalyzeCallDepth::FunctionNode *callee)
     46 {
     47 	for(size_t i = 0; i < callees.size(); i++)
     48 	{
     49 		if(callees[i] == callee)
     50 		{
     51 			return;
     52 		}
     53 	}
     54 
     55 	callees.push_back(callee);
     56 }
     57 
     58 unsigned int AnalyzeCallDepth::FunctionNode::analyzeCallDepth(AnalyzeCallDepth *analyzeCallDepth)
     59 {
     60 	ASSERT(visit == PreVisit);
     61 	ASSERT(analyzeCallDepth);
     62 
     63 	callDepth = 0;
     64 	visit = InVisit;
     65 
     66 	for(size_t i = 0; i < callees.size(); i++)
     67 	{
     68 		unsigned int calleeDepth = 0;
     69 		switch(callees[i]->visit)
     70 		{
     71 		case InVisit:
     72 			// Cycle detected (recursion)
     73 			return UINT_MAX;
     74 		case PostVisit:
     75 			calleeDepth = callees[i]->getLastDepth();
     76 			break;
     77 		case PreVisit:
     78 			calleeDepth = callees[i]->analyzeCallDepth(analyzeCallDepth);
     79 			break;
     80 		default:
     81 			UNREACHABLE(callees[i]->visit);
     82 			break;
     83 		}
     84 		if(calleeDepth != UINT_MAX) ++calleeDepth;
     85 		callDepth = std::max(callDepth, calleeDepth);
     86 	}
     87 
     88 	visit = PostVisit;
     89 	return callDepth;
     90 }
     91 
     92 unsigned int AnalyzeCallDepth::FunctionNode::getLastDepth() const
     93 {
     94 	return callDepth;
     95 }
     96 
     97 void AnalyzeCallDepth::FunctionNode::removeIfUnreachable()
     98 {
     99 	if(visit == PreVisit)
    100 	{
    101 		node->setOp(EOpPrototype);
    102 		node->getSequence().resize(1);   // Remove function body
    103 	}
    104 }
    105 
    106 AnalyzeCallDepth::AnalyzeCallDepth(TIntermNode *root)
    107 	: TIntermTraverser(true, false, true, false),
    108 	  currentFunction(0)
    109 {
    110 	root->traverse(this);
    111 }
    112 
    113 AnalyzeCallDepth::~AnalyzeCallDepth()
    114 {
    115 	for(size_t i = 0; i < functions.size(); i++)
    116 	{
    117 		delete functions[i];
    118 	}
    119 }
    120 
    121 bool AnalyzeCallDepth::visitSwitch(Visit visit, TIntermSwitch *node)
    122 {
    123 	TIntermTyped* switchValue = node->getInit();
    124 	TIntermAggregate* opList = node->getStatementList();
    125 
    126 	if(!switchValue || !opList)
    127 	{
    128 		return false;
    129 	}
    130 
    131 	// TODO: We need to dig into switch statement cases from
    132 	// visitSwitch for all traversers. Is there a way to
    133 	// preserve existing functionality while moving the iteration
    134 	// to the general traverser?
    135 	TIntermSequence& sequence = opList->getSequence();
    136 	TIntermSequence::iterator it = sequence.begin();
    137 	TIntermSequence::iterator defaultIt = sequence.end();
    138 	for(; it != sequence.end(); ++it)
    139 	{
    140 		TIntermCase* currentCase = (*it)->getAsCaseNode();
    141 		if(currentCase)
    142 		{
    143 			TIntermSequence::iterator caseIt = it;
    144 			TIntermTyped* condition = currentCase->getCondition();
    145 			if(condition) // non default case
    146 			{
    147 				condition->traverse(this);
    148 				traverseCaseBody(this, caseIt, sequence.end());
    149 			}
    150 			else
    151 			{
    152 				defaultIt = it; // The default case might not be the last case, keep it for last
    153 			}
    154 		}
    155 	}
    156 
    157 	// If there's a default case, traverse it here
    158 	if(defaultIt != sequence.end())
    159 	{
    160 		traverseCaseBody(this, defaultIt, sequence.end());
    161 	}
    162 	return false;
    163 }
    164 
    165 bool AnalyzeCallDepth::visitAggregate(Visit visit, TIntermAggregate *node)
    166 {
    167 	switch(node->getOp())
    168 	{
    169 	case EOpFunction:   // Function definition
    170 		{
    171 			if(visit == PreVisit)
    172 			{
    173 				currentFunction = findFunctionByName(node->getName());
    174 
    175 				if(!currentFunction)
    176 				{
    177 					currentFunction = new FunctionNode(node);
    178 					functions.push_back(currentFunction);
    179 				}
    180 			}
    181 			else if(visit == PostVisit)
    182 			{
    183 				currentFunction = 0;
    184 			}
    185 		}
    186 		break;
    187 	case EOpFunctionCall:
    188 		{
    189 			if(!node->isUserDefined())
    190 			{
    191 				return true;   // Check the arguments for function calls
    192 			}
    193 
    194 			if(visit == PreVisit)
    195 			{
    196 				FunctionNode *function = findFunctionByName(node->getName());
    197 
    198 				if(!function)
    199 				{
    200 					function = new FunctionNode(node);
    201 					functions.push_back(function);
    202 				}
    203 
    204 				if(currentFunction)
    205 				{
    206 					currentFunction->addCallee(function);
    207 				}
    208 				else
    209 				{
    210 					globalFunctionCalls.insert(function);
    211 				}
    212 			}
    213 		}
    214 		break;
    215 	default:
    216 		break;
    217 	}
    218 
    219 	return true;
    220 }
    221 
    222 unsigned int AnalyzeCallDepth::analyzeCallDepth()
    223 {
    224 	FunctionNode *main = findFunctionByName("main(");
    225 
    226 	if(!main)
    227 	{
    228 		return 0;
    229 	}
    230 
    231 	unsigned int depth = main->analyzeCallDepth(this);
    232 	if(depth != UINT_MAX) ++depth;
    233 
    234 	for(FunctionSet::iterator globalCall = globalFunctionCalls.begin(); globalCall != globalFunctionCalls.end(); globalCall++)
    235 	{
    236 		unsigned int globalDepth = (*globalCall)->analyzeCallDepth(this);
    237 		if(globalDepth != UINT_MAX) ++globalDepth;
    238 
    239 		if(globalDepth > depth)
    240 		{
    241 			depth = globalDepth;
    242 		}
    243 	}
    244 
    245 	for(size_t i = 0; i < functions.size(); i++)
    246 	{
    247 		functions[i]->removeIfUnreachable();
    248 	}
    249 
    250 	return depth;
    251 }
    252 
    253 AnalyzeCallDepth::FunctionNode *AnalyzeCallDepth::findFunctionByName(const TString &name)
    254 {
    255 	for(size_t i = 0; i < functions.size(); i++)
    256 	{
    257 		if(functions[i]->getName() == name)
    258 		{
    259 			return functions[i];
    260 		}
    261 	}
    262 
    263 	return 0;
    264 }
    265 
    266