1 //===--- TransProtectedScope.cpp - Transformations to ARC mode ------------===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // Adds brackets in case statements that "contain" initialization of retaining 11 // variable, thus emitting the "switch case is in protected scope" error. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "Transforms.h" 16 #include "Internals.h" 17 #include "clang/AST/ASTContext.h" 18 #include "clang/Sema/SemaDiagnostic.h" 19 20 using namespace clang; 21 using namespace arcmt; 22 using namespace trans; 23 24 namespace { 25 26 class LocalRefsCollector : public RecursiveASTVisitor<LocalRefsCollector> { 27 SmallVectorImpl<DeclRefExpr *> &Refs; 28 29 public: 30 LocalRefsCollector(SmallVectorImpl<DeclRefExpr *> &refs) 31 : Refs(refs) { } 32 33 bool VisitDeclRefExpr(DeclRefExpr *E) { 34 if (ValueDecl *D = E->getDecl()) 35 if (D->getDeclContext()->getRedeclContext()->isFunctionOrMethod()) 36 Refs.push_back(E); 37 return true; 38 } 39 }; 40 41 struct CaseInfo { 42 SwitchCase *SC; 43 SourceRange Range; 44 enum { 45 St_Unchecked, 46 St_CannotFix, 47 St_Fixed 48 } State; 49 50 CaseInfo() : SC(0), State(St_Unchecked) {} 51 CaseInfo(SwitchCase *S, SourceRange Range) 52 : SC(S), Range(Range), State(St_Unchecked) {} 53 }; 54 55 class CaseCollector : public RecursiveASTVisitor<CaseCollector> { 56 ParentMap &PMap; 57 SmallVectorImpl<CaseInfo> &Cases; 58 59 public: 60 CaseCollector(ParentMap &PMap, SmallVectorImpl<CaseInfo> &Cases) 61 : PMap(PMap), Cases(Cases) { } 62 63 bool VisitSwitchStmt(SwitchStmt *S) { 64 SwitchCase *Curr = S->getSwitchCaseList(); 65 if (!Curr) 66 return true; 67 Stmt *Parent = getCaseParent(Curr); 68 Curr = Curr->getNextSwitchCase(); 69 // Make sure all case statements are in the same scope. 70 while (Curr) { 71 if (getCaseParent(Curr) != Parent) 72 return true; 73 Curr = Curr->getNextSwitchCase(); 74 } 75 76 SourceLocation NextLoc = S->getLocEnd(); 77 Curr = S->getSwitchCaseList(); 78 // We iterate over case statements in reverse source-order. 79 while (Curr) { 80 Cases.push_back(CaseInfo(Curr,SourceRange(Curr->getLocStart(), NextLoc))); 81 NextLoc = Curr->getLocStart(); 82 Curr = Curr->getNextSwitchCase(); 83 } 84 return true; 85 } 86 87 Stmt *getCaseParent(SwitchCase *S) { 88 Stmt *Parent = PMap.getParent(S); 89 while (Parent && (isa<SwitchCase>(Parent) || isa<LabelStmt>(Parent))) 90 Parent = PMap.getParent(Parent); 91 return Parent; 92 } 93 }; 94 95 class ProtectedScopeFixer { 96 MigrationPass &Pass; 97 SourceManager &SM; 98 SmallVector<CaseInfo, 16> Cases; 99 SmallVector<DeclRefExpr *, 16> LocalRefs; 100 101 public: 102 ProtectedScopeFixer(BodyContext &BodyCtx) 103 : Pass(BodyCtx.getMigrationContext().Pass), 104 SM(Pass.Ctx.getSourceManager()) { 105 106 CaseCollector(BodyCtx.getParentMap(), Cases) 107 .TraverseStmt(BodyCtx.getTopStmt()); 108 LocalRefsCollector(LocalRefs).TraverseStmt(BodyCtx.getTopStmt()); 109 110 SourceRange BodyRange = BodyCtx.getTopStmt()->getSourceRange(); 111 const CapturedDiagList &DiagList = Pass.getDiags(); 112 // Copy the diagnostics so we don't have to worry about invaliding iterators 113 // from the diagnostic list. 114 SmallVector<StoredDiagnostic, 16> StoredDiags; 115 StoredDiags.append(DiagList.begin(), DiagList.end()); 116 SmallVectorImpl<StoredDiagnostic>::iterator 117 I = StoredDiags.begin(), E = StoredDiags.end(); 118 while (I != E) { 119 if (I->getID() == diag::err_switch_into_protected_scope && 120 isInRange(I->getLocation(), BodyRange)) { 121 handleProtectedScopeError(I, E); 122 continue; 123 } 124 ++I; 125 } 126 } 127 128 void handleProtectedScopeError( 129 SmallVectorImpl<StoredDiagnostic>::iterator &DiagI, 130 SmallVectorImpl<StoredDiagnostic>::iterator DiagE){ 131 Transaction Trans(Pass.TA); 132 assert(DiagI->getID() == diag::err_switch_into_protected_scope); 133 SourceLocation ErrLoc = DiagI->getLocation(); 134 bool handledAllNotes = true; 135 ++DiagI; 136 for (; DiagI != DiagE && DiagI->getLevel() == DiagnosticsEngine::Note; 137 ++DiagI) { 138 if (!handleProtectedNote(*DiagI)) 139 handledAllNotes = false; 140 } 141 142 if (handledAllNotes) 143 Pass.TA.clearDiagnostic(diag::err_switch_into_protected_scope, ErrLoc); 144 } 145 146 bool handleProtectedNote(const StoredDiagnostic &Diag) { 147 assert(Diag.getLevel() == DiagnosticsEngine::Note); 148 149 for (unsigned i = 0; i != Cases.size(); i++) { 150 CaseInfo &info = Cases[i]; 151 if (isInRange(Diag.getLocation(), info.Range)) { 152 153 if (info.State == CaseInfo::St_Unchecked) 154 tryFixing(info); 155 assert(info.State != CaseInfo::St_Unchecked); 156 157 if (info.State == CaseInfo::St_Fixed) { 158 Pass.TA.clearDiagnostic(Diag.getID(), Diag.getLocation()); 159 return true; 160 } 161 return false; 162 } 163 } 164 165 return false; 166 } 167 168 void tryFixing(CaseInfo &info) { 169 assert(info.State == CaseInfo::St_Unchecked); 170 if (hasVarReferencedOutside(info)) { 171 info.State = CaseInfo::St_CannotFix; 172 return; 173 } 174 175 Pass.TA.insertAfterToken(info.SC->getColonLoc(), " {"); 176 Pass.TA.insert(info.Range.getEnd(), "}\n"); 177 info.State = CaseInfo::St_Fixed; 178 } 179 180 bool hasVarReferencedOutside(CaseInfo &info) { 181 for (unsigned i = 0, e = LocalRefs.size(); i != e; ++i) { 182 DeclRefExpr *DRE = LocalRefs[i]; 183 if (isInRange(DRE->getDecl()->getLocation(), info.Range) && 184 !isInRange(DRE->getLocation(), info.Range)) 185 return true; 186 } 187 return false; 188 } 189 190 bool isInRange(SourceLocation Loc, SourceRange R) { 191 if (Loc.isInvalid()) 192 return false; 193 return !SM.isBeforeInTranslationUnit(Loc, R.getBegin()) && 194 SM.isBeforeInTranslationUnit(Loc, R.getEnd()); 195 } 196 }; 197 198 } // anonymous namespace 199 200 void ProtectedScopeTraverser::traverseBody(BodyContext &BodyCtx) { 201 ProtectedScopeFixer Fix(BodyCtx); 202 } 203