Home | History | Annotate | Download | only in ARCMigrate
      1 //===--- TransProtectedScope.cpp - Transformations to ARC mode ------------===//
      2 //
      3 //                     The LLVM Compiler Infrastructure
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 //===----------------------------------------------------------------------===//
      9 //
     10 // Adds brackets in case statements that "contain" initialization of retaining
     11 // variable, thus emitting the "switch case is in protected scope" error.
     12 //
     13 //===----------------------------------------------------------------------===//
     14 
     15 #include "Transforms.h"
     16 #include "Internals.h"
     17 #include "clang/AST/ASTContext.h"
     18 #include "clang/Sema/SemaDiagnostic.h"
     19 
     20 using namespace clang;
     21 using namespace arcmt;
     22 using namespace trans;
     23 
     24 namespace {
     25 
     26 class LocalRefsCollector : public RecursiveASTVisitor<LocalRefsCollector> {
     27   SmallVectorImpl<DeclRefExpr *> &Refs;
     28 
     29 public:
     30   LocalRefsCollector(SmallVectorImpl<DeclRefExpr *> &refs)
     31     : Refs(refs) { }
     32 
     33   bool VisitDeclRefExpr(DeclRefExpr *E) {
     34     if (ValueDecl *D = E->getDecl())
     35       if (D->getDeclContext()->getRedeclContext()->isFunctionOrMethod())
     36         Refs.push_back(E);
     37     return true;
     38   }
     39 };
     40 
     41 struct CaseInfo {
     42   SwitchCase *SC;
     43   SourceRange Range;
     44   enum {
     45     St_Unchecked,
     46     St_CannotFix,
     47     St_Fixed
     48   } State;
     49 
     50   CaseInfo() : SC(nullptr), State(St_Unchecked) {}
     51   CaseInfo(SwitchCase *S, SourceRange Range)
     52     : SC(S), Range(Range), State(St_Unchecked) {}
     53 };
     54 
     55 class CaseCollector : public RecursiveASTVisitor<CaseCollector> {
     56   ParentMap &PMap;
     57   SmallVectorImpl<CaseInfo> &Cases;
     58 
     59 public:
     60   CaseCollector(ParentMap &PMap, SmallVectorImpl<CaseInfo> &Cases)
     61     : PMap(PMap), Cases(Cases) { }
     62 
     63   bool VisitSwitchStmt(SwitchStmt *S) {
     64     SwitchCase *Curr = S->getSwitchCaseList();
     65     if (!Curr)
     66       return true;
     67     Stmt *Parent = getCaseParent(Curr);
     68     Curr = Curr->getNextSwitchCase();
     69     // Make sure all case statements are in the same scope.
     70     while (Curr) {
     71       if (getCaseParent(Curr) != Parent)
     72         return true;
     73       Curr = Curr->getNextSwitchCase();
     74     }
     75 
     76     SourceLocation NextLoc = S->getLocEnd();
     77     Curr = S->getSwitchCaseList();
     78     // We iterate over case statements in reverse source-order.
     79     while (Curr) {
     80       Cases.push_back(CaseInfo(Curr,SourceRange(Curr->getLocStart(), NextLoc)));
     81       NextLoc = Curr->getLocStart();
     82       Curr = Curr->getNextSwitchCase();
     83     }
     84     return true;
     85   }
     86 
     87   Stmt *getCaseParent(SwitchCase *S) {
     88     Stmt *Parent = PMap.getParent(S);
     89     while (Parent && (isa<SwitchCase>(Parent) || isa<LabelStmt>(Parent)))
     90       Parent = PMap.getParent(Parent);
     91     return Parent;
     92   }
     93 };
     94 
     95 class ProtectedScopeFixer {
     96   MigrationPass &Pass;
     97   SourceManager &SM;
     98   SmallVector<CaseInfo, 16> Cases;
     99   SmallVector<DeclRefExpr *, 16> LocalRefs;
    100 
    101 public:
    102   ProtectedScopeFixer(BodyContext &BodyCtx)
    103     : Pass(BodyCtx.getMigrationContext().Pass),
    104       SM(Pass.Ctx.getSourceManager()) {
    105 
    106     CaseCollector(BodyCtx.getParentMap(), Cases)
    107         .TraverseStmt(BodyCtx.getTopStmt());
    108     LocalRefsCollector(LocalRefs).TraverseStmt(BodyCtx.getTopStmt());
    109 
    110     SourceRange BodyRange = BodyCtx.getTopStmt()->getSourceRange();
    111     const CapturedDiagList &DiagList = Pass.getDiags();
    112     // Copy the diagnostics so we don't have to worry about invaliding iterators
    113     // from the diagnostic list.
    114     SmallVector<StoredDiagnostic, 16> StoredDiags;
    115     StoredDiags.append(DiagList.begin(), DiagList.end());
    116     SmallVectorImpl<StoredDiagnostic>::iterator
    117         I = StoredDiags.begin(), E = StoredDiags.end();
    118     while (I != E) {
    119       if (I->getID() == diag::err_switch_into_protected_scope &&
    120           isInRange(I->getLocation(), BodyRange)) {
    121         handleProtectedScopeError(I, E);
    122         continue;
    123       }
    124       ++I;
    125     }
    126   }
    127 
    128   void handleProtectedScopeError(
    129                              SmallVectorImpl<StoredDiagnostic>::iterator &DiagI,
    130                              SmallVectorImpl<StoredDiagnostic>::iterator DiagE){
    131     Transaction Trans(Pass.TA);
    132     assert(DiagI->getID() == diag::err_switch_into_protected_scope);
    133     SourceLocation ErrLoc = DiagI->getLocation();
    134     bool handledAllNotes = true;
    135     ++DiagI;
    136     for (; DiagI != DiagE && DiagI->getLevel() == DiagnosticsEngine::Note;
    137          ++DiagI) {
    138       if (!handleProtectedNote(*DiagI))
    139         handledAllNotes = false;
    140     }
    141 
    142     if (handledAllNotes)
    143       Pass.TA.clearDiagnostic(diag::err_switch_into_protected_scope, ErrLoc);
    144   }
    145 
    146   bool handleProtectedNote(const StoredDiagnostic &Diag) {
    147     assert(Diag.getLevel() == DiagnosticsEngine::Note);
    148 
    149     for (unsigned i = 0; i != Cases.size(); i++) {
    150       CaseInfo &info = Cases[i];
    151       if (isInRange(Diag.getLocation(), info.Range)) {
    152 
    153         if (info.State == CaseInfo::St_Unchecked)
    154           tryFixing(info);
    155         assert(info.State != CaseInfo::St_Unchecked);
    156 
    157         if (info.State == CaseInfo::St_Fixed) {
    158           Pass.TA.clearDiagnostic(Diag.getID(), Diag.getLocation());
    159           return true;
    160         }
    161         return false;
    162       }
    163     }
    164 
    165     return false;
    166   }
    167 
    168   void tryFixing(CaseInfo &info) {
    169     assert(info.State == CaseInfo::St_Unchecked);
    170     if (hasVarReferencedOutside(info)) {
    171       info.State = CaseInfo::St_CannotFix;
    172       return;
    173     }
    174 
    175     Pass.TA.insertAfterToken(info.SC->getColonLoc(), " {");
    176     Pass.TA.insert(info.Range.getEnd(), "}\n");
    177     info.State = CaseInfo::St_Fixed;
    178   }
    179 
    180   bool hasVarReferencedOutside(CaseInfo &info) {
    181     for (unsigned i = 0, e = LocalRefs.size(); i != e; ++i) {
    182       DeclRefExpr *DRE = LocalRefs[i];
    183       if (isInRange(DRE->getDecl()->getLocation(), info.Range) &&
    184           !isInRange(DRE->getLocation(), info.Range))
    185         return true;
    186     }
    187     return false;
    188   }
    189 
    190   bool isInRange(SourceLocation Loc, SourceRange R) {
    191     if (Loc.isInvalid())
    192       return false;
    193     return !SM.isBeforeInTranslationUnit(Loc, R.getBegin()) &&
    194             SM.isBeforeInTranslationUnit(Loc, R.getEnd());
    195   }
    196 };
    197 
    198 } // anonymous namespace
    199 
    200 void ProtectedScopeTraverser::traverseBody(BodyContext &BodyCtx) {
    201   ProtectedScopeFixer Fix(BodyCtx);
    202 }
    203