Home | History | Annotate | Download | only in ARCMigrate
//===--- Transforms.cpp - Transformations to ARC mode ---------------------===//
      2 //
      3 //                     The LLVM Compiler Infrastructure
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 //===----------------------------------------------------------------------===//
      9 
     10 #include "Transforms.h"
     11 #include "Internals.h"
     12 #include "clang/Sema/SemaDiagnostic.h"
     13 #include "clang/AST/RecursiveASTVisitor.h"
     14 #include "clang/AST/StmtVisitor.h"
     15 #include "clang/Lex/Lexer.h"
     16 #include "clang/Basic/SourceManager.h"
     17 #include "llvm/ADT/StringSwitch.h"
     18 #include "llvm/ADT/DenseSet.h"
     19 #include <map>
     20 
     21 using namespace clang;
     22 using namespace arcmt;
     23 using namespace trans;
     24 
     25 ASTTraverser::~ASTTraverser() { }
     26 
     27 //===----------------------------------------------------------------------===//
     28 // Helpers.
     29 //===----------------------------------------------------------------------===//
     30 
     31 bool trans::canApplyWeak(ASTContext &Ctx, QualType type,
     32                          bool AllowOnUnknownClass) {
     33   if (!Ctx.getLangOpts().ObjCRuntimeHasWeak)
     34     return false;
     35 
     36   QualType T = type;
     37   if (T.isNull())
     38     return false;
     39 
     40   // iOS is always safe to use 'weak'.
     41   if (Ctx.getTargetInfo().getTriple().getOS() == llvm::Triple::IOS)
     42     AllowOnUnknownClass = true;
     43 
     44   while (const PointerType *ptr = T->getAs<PointerType>())
     45     T = ptr->getPointeeType();
     46   if (const ObjCObjectPointerType *ObjT = T->getAs<ObjCObjectPointerType>()) {
     47     ObjCInterfaceDecl *Class = ObjT->getInterfaceDecl();
     48     if (!AllowOnUnknownClass && (!Class || Class->getName() == "NSObject"))
     49       return false; // id/NSObject is not safe for weak.
     50     if (!AllowOnUnknownClass && !Class->hasDefinition())
     51       return false; // forward classes are not verifiable, therefore not safe.
     52     if (Class->isArcWeakrefUnavailable())
     53       return false;
     54   }
     55 
     56   return true;
     57 }
     58 
     59 /// \brief 'Loc' is the end of a statement range. This returns the location
     60 /// immediately after the semicolon following the statement.
     61 /// If no semicolon is found or the location is inside a macro, the returned
     62 /// source location will be invalid.
     63 SourceLocation trans::findLocationAfterSemi(SourceLocation loc,
     64                                             ASTContext &Ctx) {
     65   SourceLocation SemiLoc = findSemiAfterLocation(loc, Ctx);
     66   if (SemiLoc.isInvalid())
     67     return SourceLocation();
     68   return SemiLoc.getLocWithOffset(1);
     69 }
     70 
     71 /// \brief \arg Loc is the end of a statement range. This returns the location
     72 /// of the semicolon following the statement.
     73 /// If no semicolon is found or the location is inside a macro, the returned
     74 /// source location will be invalid.
     75 SourceLocation trans::findSemiAfterLocation(SourceLocation loc,
     76                                             ASTContext &Ctx) {
     77   SourceManager &SM = Ctx.getSourceManager();
     78   if (loc.isMacroID()) {
     79     if (!Lexer::isAtEndOfMacroExpansion(loc, SM, Ctx.getLangOpts(), &loc))
     80       return SourceLocation();
     81   }
     82   loc = Lexer::getLocForEndOfToken(loc, /*Offset=*/0, SM, Ctx.getLangOpts());
     83 
     84   // Break down the source location.
     85   std::pair<FileID, unsigned> locInfo = SM.getDecomposedLoc(loc);
     86 
     87   // Try to load the file buffer.
     88   bool invalidTemp = false;
     89   StringRef file = SM.getBufferData(locInfo.first, &invalidTemp);
     90   if (invalidTemp)
     91     return SourceLocation();
     92 
     93   const char *tokenBegin = file.data() + locInfo.second;
     94 
     95   // Lex from the start of the given location.
     96   Lexer lexer(SM.getLocForStartOfFile(locInfo.first),
     97               Ctx.getLangOpts(),
     98               file.begin(), tokenBegin, file.end());
     99   Token tok;
    100   lexer.LexFromRawLexer(tok);
    101   if (tok.isNot(tok::semi))
    102     return SourceLocation();
    103 
    104   return tok.getLocation();
    105 }
    106 
    107 bool trans::hasSideEffects(Expr *E, ASTContext &Ctx) {
    108   if (!E || !E->HasSideEffects(Ctx))
    109     return false;
    110 
    111   E = E->IgnoreParenCasts();
    112   ObjCMessageExpr *ME = dyn_cast<ObjCMessageExpr>(E);
    113   if (!ME)
    114     return true;
    115   switch (ME->getMethodFamily()) {
    116   case OMF_autorelease:
    117   case OMF_dealloc:
    118   case OMF_release:
    119   case OMF_retain:
    120     switch (ME->getReceiverKind()) {
    121     case ObjCMessageExpr::SuperInstance:
    122       return false;
    123     case ObjCMessageExpr::Instance:
    124       return hasSideEffects(ME->getInstanceReceiver(), Ctx);
    125     default:
    126       break;
    127     }
    128     break;
    129   default:
    130     break;
    131   }
    132 
    133   return true;
    134 }
    135 
    136 bool trans::isGlobalVar(Expr *E) {
    137   E = E->IgnoreParenCasts();
    138   if (DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(E))
    139     return DRE->getDecl()->getDeclContext()->isFileContext() &&
    140            DRE->getDecl()->getLinkage() == ExternalLinkage;
    141   if (ConditionalOperator *condOp = dyn_cast<ConditionalOperator>(E))
    142     return isGlobalVar(condOp->getTrueExpr()) &&
    143            isGlobalVar(condOp->getFalseExpr());
    144 
    145   return false;
    146 }
    147 
    148 StringRef trans::getNilString(ASTContext &Ctx) {
    149   if (Ctx.Idents.get("nil").hasMacroDefinition())
    150     return "nil";
    151   else
    152     return "0";
    153 }
    154 
    155 namespace {
    156 
    157 class ReferenceClear : public RecursiveASTVisitor<ReferenceClear> {
    158   ExprSet &Refs;
    159 public:
    160   ReferenceClear(ExprSet &refs) : Refs(refs) { }
    161   bool VisitDeclRefExpr(DeclRefExpr *E) { Refs.erase(E); return true; }
    162 };
    163 
    164 class ReferenceCollector : public RecursiveASTVisitor<ReferenceCollector> {
    165   ValueDecl *Dcl;
    166   ExprSet &Refs;
    167 
    168 public:
    169   ReferenceCollector(ValueDecl *D, ExprSet &refs)
    170     : Dcl(D), Refs(refs) { }
    171 
    172   bool VisitDeclRefExpr(DeclRefExpr *E) {
    173     if (E->getDecl() == Dcl)
    174       Refs.insert(E);
    175     return true;
    176   }
    177 };
    178 
    179 class RemovablesCollector : public RecursiveASTVisitor<RemovablesCollector> {
    180   ExprSet &Removables;
    181 
    182 public:
    183   RemovablesCollector(ExprSet &removables)
    184   : Removables(removables) { }
    185 
    186   bool shouldWalkTypesOfTypeLocs() const { return false; }
    187 
    188   bool TraverseStmtExpr(StmtExpr *E) {
    189     CompoundStmt *S = E->getSubStmt();
    190     for (CompoundStmt::body_iterator
    191         I = S->body_begin(), E = S->body_end(); I != E; ++I) {
    192       if (I != E - 1)
    193         mark(*I);
    194       TraverseStmt(*I);
    195     }
    196     return true;
    197   }
    198 
    199   bool VisitCompoundStmt(CompoundStmt *S) {
    200     for (CompoundStmt::body_iterator
    201         I = S->body_begin(), E = S->body_end(); I != E; ++I)
    202       mark(*I);
    203     return true;
    204   }
    205 
    206   bool VisitIfStmt(IfStmt *S) {
    207     mark(S->getThen());
    208     mark(S->getElse());
    209     return true;
    210   }
    211 
    212   bool VisitWhileStmt(WhileStmt *S) {
    213     mark(S->getBody());
    214     return true;
    215   }
    216 
    217   bool VisitDoStmt(DoStmt *S) {
    218     mark(S->getBody());
    219     return true;
    220   }
    221 
    222   bool VisitForStmt(ForStmt *S) {
    223     mark(S->getInit());
    224     mark(S->getInc());
    225     mark(S->getBody());
    226     return true;
    227   }
    228 
    229 private:
    230   void mark(Stmt *S) {
    231     if (!S) return;
    232 
    233     while (LabelStmt *Label = dyn_cast<LabelStmt>(S))
    234       S = Label->getSubStmt();
    235     S = S->IgnoreImplicit();
    236     if (Expr *E = dyn_cast<Expr>(S))
    237       Removables.insert(E);
    238   }
    239 };
    240 
    241 } // end anonymous namespace
    242 
    243 void trans::clearRefsIn(Stmt *S, ExprSet &refs) {
    244   ReferenceClear(refs).TraverseStmt(S);
    245 }
    246 
    247 void trans::collectRefs(ValueDecl *D, Stmt *S, ExprSet &refs) {
    248   ReferenceCollector(D, refs).TraverseStmt(S);
    249 }
    250 
    251 void trans::collectRemovables(Stmt *S, ExprSet &exprs) {
    252   RemovablesCollector(exprs).TraverseStmt(S);
    253 }
    254 
    255 //===----------------------------------------------------------------------===//
    256 // MigrationContext
    257 //===----------------------------------------------------------------------===//
    258 
    259 namespace {
    260 
    261 class ASTTransform : public RecursiveASTVisitor<ASTTransform> {
    262   MigrationContext &MigrateCtx;
    263   typedef RecursiveASTVisitor<ASTTransform> base;
    264 
    265 public:
    266   ASTTransform(MigrationContext &MigrateCtx) : MigrateCtx(MigrateCtx) { }
    267 
    268   bool shouldWalkTypesOfTypeLocs() const { return false; }
    269 
    270   bool TraverseObjCImplementationDecl(ObjCImplementationDecl *D) {
    271     ObjCImplementationContext ImplCtx(MigrateCtx, D);
    272     for (MigrationContext::traverser_iterator
    273            I = MigrateCtx.traversers_begin(),
    274            E = MigrateCtx.traversers_end(); I != E; ++I)
    275       (*I)->traverseObjCImplementation(ImplCtx);
    276 
    277     return base::TraverseObjCImplementationDecl(D);
    278   }
    279 
    280   bool TraverseStmt(Stmt *rootS) {
    281     if (!rootS)
    282       return true;
    283 
    284     BodyContext BodyCtx(MigrateCtx, rootS);
    285     for (MigrationContext::traverser_iterator
    286            I = MigrateCtx.traversers_begin(),
    287            E = MigrateCtx.traversers_end(); I != E; ++I)
    288       (*I)->traverseBody(BodyCtx);
    289 
    290     return true;
    291   }
    292 };
    293 
    294 }
    295 
    296 MigrationContext::~MigrationContext() {
    297   for (traverser_iterator
    298          I = traversers_begin(), E = traversers_end(); I != E; ++I)
    299     delete *I;
    300 }
    301 
    302 bool MigrationContext::isGCOwnedNonObjC(QualType T) {
    303   while (!T.isNull()) {
    304     if (const AttributedType *AttrT = T->getAs<AttributedType>()) {
    305       if (AttrT->getAttrKind() == AttributedType::attr_objc_ownership)
    306         return !AttrT->getModifiedType()->isObjCRetainableType();
    307     }
    308 
    309     if (T->isArrayType())
    310       T = Pass.Ctx.getBaseElementType(T);
    311     else if (const PointerType *PT = T->getAs<PointerType>())
    312       T = PT->getPointeeType();
    313     else if (const ReferenceType *RT = T->getAs<ReferenceType>())
    314       T = RT->getPointeeType();
    315     else
    316       break;
    317   }
    318 
    319   return false;
    320 }
    321 
    322 bool MigrationContext::rewritePropertyAttribute(StringRef fromAttr,
    323                                                 StringRef toAttr,
    324                                                 SourceLocation atLoc) {
    325   if (atLoc.isMacroID())
    326     return false;
    327 
    328   SourceManager &SM = Pass.Ctx.getSourceManager();
    329 
    330   // Break down the source location.
    331   std::pair<FileID, unsigned> locInfo = SM.getDecomposedLoc(atLoc);
    332 
    333   // Try to load the file buffer.
    334   bool invalidTemp = false;
    335   StringRef file = SM.getBufferData(locInfo.first, &invalidTemp);
    336   if (invalidTemp)
    337     return false;
    338 
    339   const char *tokenBegin = file.data() + locInfo.second;
    340 
    341   // Lex from the start of the given location.
    342   Lexer lexer(SM.getLocForStartOfFile(locInfo.first),
    343               Pass.Ctx.getLangOpts(),
    344               file.begin(), tokenBegin, file.end());
    345   Token tok;
    346   lexer.LexFromRawLexer(tok);
    347   if (tok.isNot(tok::at)) return false;
    348   lexer.LexFromRawLexer(tok);
    349   if (tok.isNot(tok::raw_identifier)) return false;
    350   if (StringRef(tok.getRawIdentifierData(), tok.getLength())
    351         != "property")
    352     return false;
    353   lexer.LexFromRawLexer(tok);
    354   if (tok.isNot(tok::l_paren)) return false;
    355 
    356   Token BeforeTok = tok;
    357   Token AfterTok;
    358   AfterTok.startToken();
    359   SourceLocation AttrLoc;
    360 
    361   lexer.LexFromRawLexer(tok);
    362   if (tok.is(tok::r_paren))
    363     return false;
    364 
    365   while (1) {
    366     if (tok.isNot(tok::raw_identifier)) return false;
    367     StringRef ident(tok.getRawIdentifierData(), tok.getLength());
    368     if (ident == fromAttr) {
    369       if (!toAttr.empty()) {
    370         Pass.TA.replaceText(tok.getLocation(), fromAttr, toAttr);
    371         return true;
    372       }
    373       // We want to remove the attribute.
    374       AttrLoc = tok.getLocation();
    375     }
    376 
    377     do {
    378       lexer.LexFromRawLexer(tok);
    379       if (AttrLoc.isValid() && AfterTok.is(tok::unknown))
    380         AfterTok = tok;
    381     } while (tok.isNot(tok::comma) && tok.isNot(tok::r_paren));
    382     if (tok.is(tok::r_paren))
    383       break;
    384     if (AttrLoc.isInvalid())
    385       BeforeTok = tok;
    386     lexer.LexFromRawLexer(tok);
    387   }
    388 
    389   if (toAttr.empty() && AttrLoc.isValid() && AfterTok.isNot(tok::unknown)) {
    390     // We want to remove the attribute.
    391     if (BeforeTok.is(tok::l_paren) && AfterTok.is(tok::r_paren)) {
    392       Pass.TA.remove(SourceRange(BeforeTok.getLocation(),
    393                                  AfterTok.getLocation()));
    394     } else if (BeforeTok.is(tok::l_paren) && AfterTok.is(tok::comma)) {
    395       Pass.TA.remove(SourceRange(AttrLoc, AfterTok.getLocation()));
    396     } else {
    397       Pass.TA.remove(SourceRange(BeforeTok.getLocation(), AttrLoc));
    398     }
    399 
    400     return true;
    401   }
    402 
    403   return false;
    404 }
    405 
    406 bool MigrationContext::addPropertyAttribute(StringRef attr,
    407                                             SourceLocation atLoc) {
    408   if (atLoc.isMacroID())
    409     return false;
    410 
    411   SourceManager &SM = Pass.Ctx.getSourceManager();
    412 
    413   // Break down the source location.
    414   std::pair<FileID, unsigned> locInfo = SM.getDecomposedLoc(atLoc);
    415 
    416   // Try to load the file buffer.
    417   bool invalidTemp = false;
    418   StringRef file = SM.getBufferData(locInfo.first, &invalidTemp);
    419   if (invalidTemp)
    420     return false;
    421 
    422   const char *tokenBegin = file.data() + locInfo.second;
    423 
    424   // Lex from the start of the given location.
    425   Lexer lexer(SM.getLocForStartOfFile(locInfo.first),
    426               Pass.Ctx.getLangOpts(),
    427               file.begin(), tokenBegin, file.end());
    428   Token tok;
    429   lexer.LexFromRawLexer(tok);
    430   if (tok.isNot(tok::at)) return false;
    431   lexer.LexFromRawLexer(tok);
    432   if (tok.isNot(tok::raw_identifier)) return false;
    433   if (StringRef(tok.getRawIdentifierData(), tok.getLength())
    434         != "property")
    435     return false;
    436   lexer.LexFromRawLexer(tok);
    437 
    438   if (tok.isNot(tok::l_paren)) {
    439     Pass.TA.insert(tok.getLocation(), std::string("(") + attr.str() + ") ");
    440     return true;
    441   }
    442 
    443   lexer.LexFromRawLexer(tok);
    444   if (tok.is(tok::r_paren)) {
    445     Pass.TA.insert(tok.getLocation(), attr);
    446     return true;
    447   }
    448 
    449   if (tok.isNot(tok::raw_identifier)) return false;
    450 
    451   Pass.TA.insert(tok.getLocation(), std::string(attr) + ", ");
    452   return true;
    453 }
    454 
    455 void MigrationContext::traverse(TranslationUnitDecl *TU) {
    456   for (traverser_iterator
    457          I = traversers_begin(), E = traversers_end(); I != E; ++I)
    458     (*I)->traverseTU(*this);
    459 
    460   ASTTransform(*this).TraverseDecl(TU);
    461 }
    462 
    463 static void GCRewriteFinalize(MigrationPass &pass) {
    464   ASTContext &Ctx = pass.Ctx;
    465   TransformActions &TA = pass.TA;
    466   DeclContext *DC = Ctx.getTranslationUnitDecl();
    467   Selector FinalizeSel =
    468    Ctx.Selectors.getNullarySelector(&pass.Ctx.Idents.get("finalize"));
    469 
    470   typedef DeclContext::specific_decl_iterator<ObjCImplementationDecl>
    471   impl_iterator;
    472   for (impl_iterator I = impl_iterator(DC->decls_begin()),
    473        E = impl_iterator(DC->decls_end()); I != E; ++I) {
    474     for (ObjCImplementationDecl::instmeth_iterator
    475          MI = (*I)->instmeth_begin(),
    476          ME = (*I)->instmeth_end(); MI != ME; ++MI) {
    477       ObjCMethodDecl *MD = *MI;
    478       if (!MD->hasBody())
    479         continue;
    480 
    481       if (MD->isInstanceMethod() && MD->getSelector() == FinalizeSel) {
    482         ObjCMethodDecl *FinalizeM = MD;
    483         Transaction Trans(TA);
    484         TA.insert(FinalizeM->getSourceRange().getBegin(),
    485                   "#if !__has_feature(objc_arc)\n");
    486         CharSourceRange::getTokenRange(FinalizeM->getSourceRange());
    487         const SourceManager &SM = pass.Ctx.getSourceManager();
    488         const LangOptions &LangOpts = pass.Ctx.getLangOpts();
    489         bool Invalid;
    490         std::string str = "\n#endif\n";
    491         str += Lexer::getSourceText(
    492                   CharSourceRange::getTokenRange(FinalizeM->getSourceRange()),
    493                                     SM, LangOpts, &Invalid);
    494         TA.insertAfterToken(FinalizeM->getSourceRange().getEnd(), str);
    495 
    496         break;
    497       }
    498     }
    499   }
    500 }
    501 
    502 //===----------------------------------------------------------------------===//
    503 // getAllTransformations.
    504 //===----------------------------------------------------------------------===//
    505 
    506 static void traverseAST(MigrationPass &pass) {
    507   MigrationContext MigrateCtx(pass);
    508 
    509   if (pass.isGCMigration()) {
    510     MigrateCtx.addTraverser(new GCCollectableCallsTraverser);
    511     MigrateCtx.addTraverser(new GCAttrsTraverser());
    512   }
    513   MigrateCtx.addTraverser(new PropertyRewriteTraverser());
    514   MigrateCtx.addTraverser(new BlockObjCVariableTraverser());
    515 
    516   MigrateCtx.traverse(pass.Ctx.getTranslationUnitDecl());
    517 }
    518 
/// Runs, in the fixed order below, the migration passes that
/// getAllTransformations() bundles into a single TransformFn.
/// NOTE(review): the call order is kept exactly as written. Confirm which of
/// these passes rely on earlier ones before reordering.
static void independentTransforms(MigrationPass &pass) {
  rewriteAutoreleasePool(pass);
  removeRetainReleaseDeallocFinalize(pass);
  rewriteUnusedInitDelegate(pass);
  removeZeroOutPropsInDeallocFinalize(pass);
  makeAssignARCSafe(pass);
  rewriteUnbridgedCasts(pass);
  checkAPIUses(pass);
  // Traverser-based rewrites (see traverseAST) run last.
  traverseAST(pass);
}
    529 
    530 std::vector<TransformFn> arcmt::getAllTransformations(
    531                                                LangOptions::GCMode OrigGCMode,
    532                                                bool NoFinalizeRemoval) {
    533   std::vector<TransformFn> transforms;
    534 
    535   if (OrigGCMode ==  LangOptions::GCOnly && NoFinalizeRemoval)
    536     transforms.push_back(GCRewriteFinalize);
    537   transforms.push_back(independentTransforms);
    538   // This depends on previous transformations removing various expressions.
    539   transforms.push_back(removeEmptyStatementsAndDeallocFinalize);
    540 
    541   return transforms;
    542 }
    543