Home | History | Annotate | Download | only in ARCMigrate
      1 //===--- TransAutoreleasePool.cpp - Transformations to ARC mode -----------===//
      2 //
      3 //                     The LLVM Compiler Infrastructure
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 //===----------------------------------------------------------------------===//
      9 //
     10 // rewriteAutoreleasePool:
     11 //
      12 // Uses of NSAutoreleasePool will be rewritten as an @autoreleasepool scope.
      13 //
      14 //  NSAutoreleasePool *pool = [[NSAutoreleasePool alloc] init];
      15 //  ...
      16 //  [pool release];
      17 // ---->
      18 //  @autoreleasepool {
     19 //  ...
     20 //  }
     21 //
     22 // An NSAutoreleasePool will not be touched if:
     23 // - There is not a corresponding -release/-drain in the same scope
     24 // - Not all references of the NSAutoreleasePool variable can be removed
      25 // - There is a variable that is declared inside the intended @autoreleasepool
      26 //   scope which is also used outside it.
     27 //
     28 //===----------------------------------------------------------------------===//
     29 
     30 #include "Transforms.h"
     31 #include "Internals.h"
     32 #include "clang/AST/ASTContext.h"
     33 #include "clang/Basic/SourceManager.h"
     34 #include "clang/Sema/SemaDiagnostic.h"
     35 #include <map>
     36 
     37 using namespace clang;
     38 using namespace arcmt;
     39 using namespace trans;
     40 
     41 namespace {
     42 
     43 class ReleaseCollector : public RecursiveASTVisitor<ReleaseCollector> {
     44   Decl *Dcl;
     45   SmallVectorImpl<ObjCMessageExpr *> &Releases;
     46 
     47 public:
     48   ReleaseCollector(Decl *D, SmallVectorImpl<ObjCMessageExpr *> &releases)
     49     : Dcl(D), Releases(releases) { }
     50 
     51   bool VisitObjCMessageExpr(ObjCMessageExpr *E) {
     52     if (!E->isInstanceMessage())
     53       return true;
     54     if (E->getMethodFamily() != OMF_release)
     55       return true;
     56     Expr *instance = E->getInstanceReceiver()->IgnoreParenCasts();
     57     if (DeclRefExpr *DE = dyn_cast<DeclRefExpr>(instance)) {
     58       if (DE->getDecl() == Dcl)
     59         Releases.push_back(E);
     60     }
     61     return true;
     62   }
     63 };
     64 
     65 }
     66 
     67 namespace {
     68 
     69 class AutoreleasePoolRewriter
     70                          : public RecursiveASTVisitor<AutoreleasePoolRewriter> {
     71 public:
     72   AutoreleasePoolRewriter(MigrationPass &pass)
     73     : Body(0), Pass(pass) {
     74     PoolII = &pass.Ctx.Idents.get("NSAutoreleasePool");
     75     DrainSel = pass.Ctx.Selectors.getNullarySelector(
     76                                                  &pass.Ctx.Idents.get("drain"));
     77   }
     78 
     79   void transformBody(Stmt *body, Decl *ParentD) {
     80     Body = body;
     81     TraverseStmt(body);
     82   }
     83 
     84   ~AutoreleasePoolRewriter() {
     85     SmallVector<VarDecl *, 8> VarsToHandle;
     86 
     87     for (std::map<VarDecl *, PoolVarInfo>::iterator
     88            I = PoolVars.begin(), E = PoolVars.end(); I != E; ++I) {
     89       VarDecl *var = I->first;
     90       PoolVarInfo &info = I->second;
     91 
     92       // Check that we can handle/rewrite all references of the pool.
     93 
     94       clearRefsIn(info.Dcl, info.Refs);
     95       for (SmallVectorImpl<PoolScope>::iterator
     96              scpI = info.Scopes.begin(),
     97              scpE = info.Scopes.end(); scpI != scpE; ++scpI) {
     98         PoolScope &scope = *scpI;
     99         clearRefsIn(*scope.Begin, info.Refs);
    100         clearRefsIn(*scope.End, info.Refs);
    101         clearRefsIn(scope.Releases.begin(), scope.Releases.end(), info.Refs);
    102       }
    103 
    104       // Even if one reference is not handled we will not do anything about that
    105       // pool variable.
    106       if (info.Refs.empty())
    107         VarsToHandle.push_back(var);
    108     }
    109 
    110     for (unsigned i = 0, e = VarsToHandle.size(); i != e; ++i) {
    111       PoolVarInfo &info = PoolVars[VarsToHandle[i]];
    112 
    113       Transaction Trans(Pass.TA);
    114 
    115       clearUnavailableDiags(info.Dcl);
    116       Pass.TA.removeStmt(info.Dcl);
    117 
    118       // Add "@autoreleasepool { }"
    119       for (SmallVectorImpl<PoolScope>::iterator
    120              scpI = info.Scopes.begin(),
    121              scpE = info.Scopes.end(); scpI != scpE; ++scpI) {
    122         PoolScope &scope = *scpI;
    123         clearUnavailableDiags(*scope.Begin);
    124         clearUnavailableDiags(*scope.End);
    125         if (scope.IsFollowedBySimpleReturnStmt) {
    126           // Include the return in the scope.
    127           Pass.TA.replaceStmt(*scope.Begin, "@autoreleasepool {");
    128           Pass.TA.removeStmt(*scope.End);
    129           Stmt::child_iterator retI = scope.End;
    130           ++retI;
    131           SourceLocation afterSemi = findLocationAfterSemi((*retI)->getLocEnd(),
    132                                                            Pass.Ctx);
    133           assert(afterSemi.isValid() &&
    134                  "Didn't we check before setting IsFollowedBySimpleReturnStmt "
    135                  "to true?");
    136           Pass.TA.insertAfterToken(afterSemi, "\n}");
    137           Pass.TA.increaseIndentation(
    138                                 SourceRange(scope.getIndentedRange().getBegin(),
    139                                             (*retI)->getLocEnd()),
    140                                       scope.CompoundParent->getLocStart());
    141         } else {
    142           Pass.TA.replaceStmt(*scope.Begin, "@autoreleasepool {");
    143           Pass.TA.replaceStmt(*scope.End, "}");
    144           Pass.TA.increaseIndentation(scope.getIndentedRange(),
    145                                       scope.CompoundParent->getLocStart());
    146         }
    147       }
    148 
    149       // Remove rest of pool var references.
    150       for (SmallVectorImpl<PoolScope>::iterator
    151              scpI = info.Scopes.begin(),
    152              scpE = info.Scopes.end(); scpI != scpE; ++scpI) {
    153         PoolScope &scope = *scpI;
    154         for (SmallVectorImpl<ObjCMessageExpr *>::iterator
    155                relI = scope.Releases.begin(),
    156                relE = scope.Releases.end(); relI != relE; ++relI) {
    157           clearUnavailableDiags(*relI);
    158           Pass.TA.removeStmt(*relI);
    159         }
    160       }
    161     }
    162   }
    163 
    164   bool VisitCompoundStmt(CompoundStmt *S) {
    165     SmallVector<PoolScope, 4> Scopes;
    166 
    167     for (Stmt::child_iterator
    168            I = S->body_begin(), E = S->body_end(); I != E; ++I) {
    169       Stmt *child = getEssential(*I);
    170       if (DeclStmt *DclS = dyn_cast<DeclStmt>(child)) {
    171         if (DclS->isSingleDecl()) {
    172           if (VarDecl *VD = dyn_cast<VarDecl>(DclS->getSingleDecl())) {
    173             if (isNSAutoreleasePool(VD->getType())) {
    174               PoolVarInfo &info = PoolVars[VD];
    175               info.Dcl = DclS;
    176               collectRefs(VD, S, info.Refs);
    177               // Does this statement follow the pattern:
    178               // NSAutoreleasePool * pool = [NSAutoreleasePool  new];
    179               if (isPoolCreation(VD->getInit())) {
    180                 Scopes.push_back(PoolScope());
    181                 Scopes.back().PoolVar = VD;
    182                 Scopes.back().CompoundParent = S;
    183                 Scopes.back().Begin = I;
    184               }
    185             }
    186           }
    187         }
    188       } else if (BinaryOperator *bop = dyn_cast<BinaryOperator>(child)) {
    189         if (DeclRefExpr *dref = dyn_cast<DeclRefExpr>(bop->getLHS())) {
    190           if (VarDecl *VD = dyn_cast<VarDecl>(dref->getDecl())) {
    191             // Does this statement follow the pattern:
    192             // pool = [NSAutoreleasePool  new];
    193             if (isNSAutoreleasePool(VD->getType()) &&
    194                 isPoolCreation(bop->getRHS())) {
    195               Scopes.push_back(PoolScope());
    196               Scopes.back().PoolVar = VD;
    197               Scopes.back().CompoundParent = S;
    198               Scopes.back().Begin = I;
    199             }
    200           }
    201         }
    202       }
    203 
    204       if (Scopes.empty())
    205         continue;
    206 
    207       if (isPoolDrain(Scopes.back().PoolVar, child)) {
    208         PoolScope &scope = Scopes.back();
    209         scope.End = I;
    210         handlePoolScope(scope, S);
    211         Scopes.pop_back();
    212       }
    213     }
    214     return true;
    215   }
    216 
    217 private:
    218   void clearUnavailableDiags(Stmt *S) {
    219     if (S)
    220       Pass.TA.clearDiagnostic(diag::err_unavailable,
    221                               diag::err_unavailable_message,
    222                               S->getSourceRange());
    223   }
    224 
    225   struct PoolScope {
    226     VarDecl *PoolVar;
    227     CompoundStmt *CompoundParent;
    228     Stmt::child_iterator Begin;
    229     Stmt::child_iterator End;
    230     bool IsFollowedBySimpleReturnStmt;
    231     SmallVector<ObjCMessageExpr *, 4> Releases;
    232 
    233     PoolScope() : PoolVar(0), CompoundParent(0), Begin(), End(),
    234                   IsFollowedBySimpleReturnStmt(false) { }
    235 
    236     SourceRange getIndentedRange() const {
    237       Stmt::child_iterator rangeS = Begin;
    238       ++rangeS;
    239       if (rangeS == End)
    240         return SourceRange();
    241       Stmt::child_iterator rangeE = Begin;
    242       for (Stmt::child_iterator I = rangeS; I != End; ++I)
    243         ++rangeE;
    244       return SourceRange((*rangeS)->getLocStart(), (*rangeE)->getLocEnd());
    245     }
    246   };
    247 
    248   class NameReferenceChecker : public RecursiveASTVisitor<NameReferenceChecker>{
    249     ASTContext &Ctx;
    250     SourceRange ScopeRange;
    251     SourceLocation &referenceLoc, &declarationLoc;
    252 
    253   public:
    254     NameReferenceChecker(ASTContext &ctx, PoolScope &scope,
    255                          SourceLocation &referenceLoc,
    256                          SourceLocation &declarationLoc)
    257       : Ctx(ctx), referenceLoc(referenceLoc),
    258         declarationLoc(declarationLoc) {
    259       ScopeRange = SourceRange((*scope.Begin)->getLocStart(),
    260                                (*scope.End)->getLocStart());
    261     }
    262 
    263     bool VisitDeclRefExpr(DeclRefExpr *E) {
    264       return checkRef(E->getLocation(), E->getDecl()->getLocation());
    265     }
    266 
    267     bool VisitTypedefTypeLoc(TypedefTypeLoc TL) {
    268       return checkRef(TL.getBeginLoc(), TL.getTypedefNameDecl()->getLocation());
    269     }
    270 
    271     bool VisitTagTypeLoc(TagTypeLoc TL) {
    272       return checkRef(TL.getBeginLoc(), TL.getDecl()->getLocation());
    273     }
    274 
    275   private:
    276     bool checkRef(SourceLocation refLoc, SourceLocation declLoc) {
    277       if (isInScope(declLoc)) {
    278         referenceLoc = refLoc;
    279         declarationLoc = declLoc;
    280         return false;
    281       }
    282       return true;
    283     }
    284 
    285     bool isInScope(SourceLocation loc) {
    286       if (loc.isInvalid())
    287         return false;
    288 
    289       SourceManager &SM = Ctx.getSourceManager();
    290       if (SM.isBeforeInTranslationUnit(loc, ScopeRange.getBegin()))
    291         return false;
    292       return SM.isBeforeInTranslationUnit(loc, ScopeRange.getEnd());
    293     }
    294   };
    295 
    296   void handlePoolScope(PoolScope &scope, CompoundStmt *compoundS) {
    297     // Check that all names declared inside the scope are not used
    298     // outside the scope.
    299     {
    300       bool nameUsedOutsideScope = false;
    301       SourceLocation referenceLoc, declarationLoc;
    302       Stmt::child_iterator SI = scope.End, SE = compoundS->body_end();
    303       ++SI;
    304       // Check if the autoreleasepool scope is followed by a simple return
    305       // statement, in which case we will include the return in the scope.
    306       if (SI != SE)
    307         if (ReturnStmt *retS = dyn_cast<ReturnStmt>(*SI))
    308           if ((retS->getRetValue() == 0 ||
    309                isa<DeclRefExpr>(retS->getRetValue()->IgnoreParenCasts())) &&
    310               findLocationAfterSemi(retS->getLocEnd(), Pass.Ctx).isValid()) {
    311             scope.IsFollowedBySimpleReturnStmt = true;
    312             ++SI; // the return will be included in scope, don't check it.
    313           }
    314 
    315       for (; SI != SE; ++SI) {
    316         nameUsedOutsideScope = !NameReferenceChecker(Pass.Ctx, scope,
    317                                                      referenceLoc,
    318                                               declarationLoc).TraverseStmt(*SI);
    319         if (nameUsedOutsideScope)
    320           break;
    321       }
    322 
    323       // If not all references were cleared it means some variables/typenames/etc
    324       // declared inside the pool scope are used outside of it.
    325       // We won't try to rewrite the pool.
    326       if (nameUsedOutsideScope) {
    327         Pass.TA.reportError("a name is referenced outside the "
    328             "NSAutoreleasePool scope that it was declared in", referenceLoc);
    329         Pass.TA.reportNote("name declared here", declarationLoc);
    330         Pass.TA.reportNote("intended @autoreleasepool scope begins here",
    331                            (*scope.Begin)->getLocStart());
    332         Pass.TA.reportNote("intended @autoreleasepool scope ends here",
    333                            (*scope.End)->getLocStart());
    334         return;
    335       }
    336     }
    337 
    338     // Collect all releases of the pool; they will be removed.
    339     {
    340       ReleaseCollector releaseColl(scope.PoolVar, scope.Releases);
    341       Stmt::child_iterator I = scope.Begin;
    342       ++I;
    343       for (; I != scope.End; ++I)
    344         releaseColl.TraverseStmt(*I);
    345     }
    346 
    347     PoolVars[scope.PoolVar].Scopes.push_back(scope);
    348   }
    349 
    350   bool isPoolCreation(Expr *E) {
    351     if (!E) return false;
    352     E = getEssential(E);
    353     ObjCMessageExpr *ME = dyn_cast<ObjCMessageExpr>(E);
    354     if (!ME) return false;
    355     if (ME->getMethodFamily() == OMF_new &&
    356         ME->getReceiverKind() == ObjCMessageExpr::Class &&
    357         isNSAutoreleasePool(ME->getReceiverInterface()))
    358       return true;
    359     if (ME->getReceiverKind() == ObjCMessageExpr::Instance &&
    360         ME->getMethodFamily() == OMF_init) {
    361       Expr *rec = getEssential(ME->getInstanceReceiver());
    362       if (ObjCMessageExpr *recME = dyn_cast_or_null<ObjCMessageExpr>(rec)) {
    363         if (recME->getMethodFamily() == OMF_alloc &&
    364             recME->getReceiverKind() == ObjCMessageExpr::Class &&
    365             isNSAutoreleasePool(recME->getReceiverInterface()))
    366           return true;
    367       }
    368     }
    369 
    370     return false;
    371   }
    372 
    373   bool isPoolDrain(VarDecl *poolVar, Stmt *S) {
    374     if (!S) return false;
    375     S = getEssential(S);
    376     ObjCMessageExpr *ME = dyn_cast<ObjCMessageExpr>(S);
    377     if (!ME) return false;
    378     if (ME->getReceiverKind() == ObjCMessageExpr::Instance) {
    379       Expr *rec = getEssential(ME->getInstanceReceiver());
    380       if (DeclRefExpr *dref = dyn_cast<DeclRefExpr>(rec))
    381         if (dref->getDecl() == poolVar)
    382           return ME->getMethodFamily() == OMF_release ||
    383                  ME->getSelector() == DrainSel;
    384     }
    385 
    386     return false;
    387   }
    388 
    389   bool isNSAutoreleasePool(ObjCInterfaceDecl *IDecl) {
    390     return IDecl && IDecl->getIdentifier() == PoolII;
    391   }
    392 
    393   bool isNSAutoreleasePool(QualType Ty) {
    394     QualType pointee = Ty->getPointeeType();
    395     if (pointee.isNull())
    396       return false;
    397     if (const ObjCInterfaceType *interT = pointee->getAs<ObjCInterfaceType>())
    398       return isNSAutoreleasePool(interT->getDecl());
    399     return false;
    400   }
    401 
    402   static Expr *getEssential(Expr *E) {
    403     return cast<Expr>(getEssential((Stmt*)E));
    404   }
    405   static Stmt *getEssential(Stmt *S) {
    406     if (ExprWithCleanups *EWC = dyn_cast<ExprWithCleanups>(S))
    407       S = EWC->getSubExpr();
    408     if (Expr *E = dyn_cast<Expr>(S))
    409       S = E->IgnoreParenCasts();
    410     return S;
    411   }
    412 
    413   Stmt *Body;
    414   MigrationPass &Pass;
    415 
    416   IdentifierInfo *PoolII;
    417   Selector DrainSel;
    418 
    419   struct PoolVarInfo {
    420     DeclStmt *Dcl;
    421     ExprSet Refs;
    422     SmallVector<PoolScope, 2> Scopes;
    423 
    424     PoolVarInfo() : Dcl(0) { }
    425   };
    426 
    427   std::map<VarDecl *, PoolVarInfo> PoolVars;
    428 };
    429 
    430 } // anonymous namespace
    431 
    432 void trans::rewriteAutoreleasePool(MigrationPass &pass) {
    433   BodyTransform<AutoreleasePoolRewriter> trans(pass);
    434   trans.TraverseDecl(pass.Ctx.getTranslationUnitDecl());
    435 }
    436