Home | History | Annotate | Download | only in compiler
      1 //
      2 // Copyright (c) 2002-2011 The ANGLE Project Authors. All rights reserved.
      3 // Use of this source code is governed by a BSD-style license that can be
      4 // found in the LICENSE file.
      5 //
      6 
      7 #include "compiler/DetectCallDepth.h"
      8 #include "compiler/InfoSink.h"
      9 
     10 DetectCallDepth::FunctionNode::FunctionNode(const TString& fname)
     11     : name(fname),
     12       visit(PreVisit)
     13 {
     14 }
     15 
     16 const TString& DetectCallDepth::FunctionNode::getName() const
     17 {
     18     return name;
     19 }
     20 
     21 void DetectCallDepth::FunctionNode::addCallee(
     22     DetectCallDepth::FunctionNode* callee)
     23 {
     24     for (size_t i = 0; i < callees.size(); ++i) {
     25         if (callees[i] == callee)
     26             return;
     27     }
     28     callees.push_back(callee);
     29 }
     30 
     31 int DetectCallDepth::FunctionNode::detectCallDepth(DetectCallDepth* detectCallDepth, int depth)
     32 {
     33     ASSERT(visit == PreVisit);
     34     ASSERT(detectCallDepth);
     35 
     36     int maxDepth = depth;
     37     visit = InVisit;
     38     for (size_t i = 0; i < callees.size(); ++i) {
     39         switch (callees[i]->visit) {
     40             case InVisit:
     41                 // cycle detected, i.e., recursion detected.
     42                 return kInfiniteCallDepth;
     43             case PostVisit:
     44                 break;
     45             case PreVisit: {
     46                 // Check before we recurse so we don't go too depth
     47                 if (detectCallDepth->checkExceedsMaxDepth(depth))
     48                     return depth;
     49                 int callDepth = callees[i]->detectCallDepth(detectCallDepth, depth + 1);
     50                 // Check after we recurse so we can exit immediately and provide info.
     51                 if (detectCallDepth->checkExceedsMaxDepth(callDepth)) {
     52                     detectCallDepth->getInfoSink().info << "<-" << callees[i]->getName();
     53                     return callDepth;
     54                 }
     55                 maxDepth = std::max(callDepth, maxDepth);
     56                 break;
     57             }
     58             default:
     59                 UNREACHABLE();
     60                 break;
     61         }
     62     }
     63     visit = PostVisit;
     64     return maxDepth;
     65 }
     66 
     67 void DetectCallDepth::FunctionNode::reset()
     68 {
     69     visit = PreVisit;
     70 }
     71 
     72 DetectCallDepth::DetectCallDepth(TInfoSink& infoSink, bool limitCallStackDepth, int maxCallStackDepth)
     73     : TIntermTraverser(true, false, true, false),
     74       currentFunction(NULL),
     75       infoSink(infoSink),
     76       maxDepth(limitCallStackDepth ? maxCallStackDepth : FunctionNode::kInfiniteCallDepth)
     77 {
     78 }
     79 
     80 DetectCallDepth::~DetectCallDepth()
     81 {
     82     for (size_t i = 0; i < functions.size(); ++i)
     83         delete functions[i];
     84 }
     85 
     86 bool DetectCallDepth::visitAggregate(Visit visit, TIntermAggregate* node)
     87 {
     88     switch (node->getOp())
     89     {
     90         case EOpPrototype:
     91             // Function declaration.
     92             // Don't add FunctionNode here because node->getName() is the
     93             // unmangled function name.
     94             break;
     95         case EOpFunction: {
     96             // Function definition.
     97             if (visit == PreVisit) {
     98                 currentFunction = findFunctionByName(node->getName());
     99                 if (currentFunction == NULL) {
    100                     currentFunction = new FunctionNode(node->getName());
    101                     functions.push_back(currentFunction);
    102                 }
    103             } else if (visit == PostVisit) {
    104                 currentFunction = NULL;
    105             }
    106             break;
    107         }
    108         case EOpFunctionCall: {
    109             // Function call.
    110             if (visit == PreVisit) {
    111                 FunctionNode* func = findFunctionByName(node->getName());
    112                 if (func == NULL) {
    113                     func = new FunctionNode(node->getName());
    114                     functions.push_back(func);
    115                 }
    116                 if (currentFunction)
    117                     currentFunction->addCallee(func);
    118             }
    119             break;
    120         }
    121         default:
    122             break;
    123     }
    124     return true;
    125 }
    126 
    127 bool DetectCallDepth::checkExceedsMaxDepth(int depth)
    128 {
    129     return depth >= maxDepth;
    130 }
    131 
    132 void DetectCallDepth::resetFunctionNodes()
    133 {
    134     for (size_t i = 0; i < functions.size(); ++i) {
    135         functions[i]->reset();
    136     }
    137 }
    138 
    139 DetectCallDepth::ErrorCode DetectCallDepth::detectCallDepthForFunction(FunctionNode* func)
    140 {
    141     currentFunction = NULL;
    142     resetFunctionNodes();
    143 
    144     int maxCallDepth = func->detectCallDepth(this, 1);
    145 
    146     if (maxCallDepth == FunctionNode::kInfiniteCallDepth)
    147         return kErrorRecursion;
    148 
    149     if (maxCallDepth >= maxDepth)
    150         return kErrorMaxDepthExceeded;
    151 
    152     return kErrorNone;
    153 }
    154 
    155 DetectCallDepth::ErrorCode DetectCallDepth::detectCallDepth()
    156 {
    157     if (maxDepth != FunctionNode::kInfiniteCallDepth) {
    158         // Check all functions because the driver may fail on them
    159         // TODO: Before detectingRecursion, strip unused functions.
    160         for (size_t i = 0; i < functions.size(); ++i) {
    161             ErrorCode error = detectCallDepthForFunction(functions[i]);
    162             if (error != kErrorNone)
    163                 return error;
    164         }
    165     } else {
    166         FunctionNode* main = findFunctionByName("main(");
    167         if (main == NULL)
    168             return kErrorMissingMain;
    169 
    170         return detectCallDepthForFunction(main);
    171     }
    172 
    173     return kErrorNone;
    174 }
    175 
    176 DetectCallDepth::FunctionNode* DetectCallDepth::findFunctionByName(
    177     const TString& name)
    178 {
    179     for (size_t i = 0; i < functions.size(); ++i) {
    180         if (functions[i]->getName() == name)
    181             return functions[i];
    182     }
    183     return NULL;
    184 }
    185 
    186