Home | History | Annotate | Download | only in ARCMigrate
//===--- TransAutoreleasePool.cpp - Transformations to ARC mode -----------===//
      2 //
      3 //                     The LLVM Compiler Infrastructure
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 //===----------------------------------------------------------------------===//
      9 //
     10 // rewriteAutoreleasePool:
     11 //
// Uses of NSAutoreleasePool objects will be rewritten as an @autoreleasepool scope.
     13 //
     14 //  NSAutoreleasePool *pool = [[NSAutoreleasePool alloc] init];
     15 //  ...
     16 //  [pool release];
     17 // ---->
//  @autoreleasepool {
     19 //  ...
     20 //  }
     21 //
     22 // An NSAutoreleasePool will not be touched if:
     23 // - There is not a corresponding -release/-drain in the same scope
     24 // - Not all references of the NSAutoreleasePool variable can be removed
// - A name (variable, typedef, or tag) that is declared inside the intended
//   @autoreleasepool scope is also used outside it.
     27 //
     28 //===----------------------------------------------------------------------===//
     29 
     30 #include "Transforms.h"
     31 #include "Internals.h"
     32 #include "clang/Sema/SemaDiagnostic.h"
     33 #include "clang/Basic/SourceManager.h"
     34 #include <map>
     35 
     36 using namespace clang;
     37 using namespace arcmt;
     38 using namespace trans;
     39 using llvm::StringRef;
     40 
     41 namespace {
     42 
     43 class ReleaseCollector : public RecursiveASTVisitor<ReleaseCollector> {
     44   Decl *Dcl;
     45   llvm::SmallVectorImpl<ObjCMessageExpr *> &Releases;
     46 
     47 public:
     48   ReleaseCollector(Decl *D, llvm::SmallVectorImpl<ObjCMessageExpr *> &releases)
     49     : Dcl(D), Releases(releases) { }
     50 
     51   bool VisitObjCMessageExpr(ObjCMessageExpr *E) {
     52     if (!E->isInstanceMessage())
     53       return true;
     54     if (E->getMethodFamily() != OMF_release)
     55       return true;
     56     Expr *instance = E->getInstanceReceiver()->IgnoreParenCasts();
     57     if (DeclRefExpr *DE = dyn_cast<DeclRefExpr>(instance)) {
     58       if (DE->getDecl() == Dcl)
     59         Releases.push_back(E);
     60     }
     61     return true;
     62   }
     63 };
     64 
     65 }
     66 
     67 namespace {
     68 
     69 class AutoreleasePoolRewriter
     70                          : public RecursiveASTVisitor<AutoreleasePoolRewriter> {
     71 public:
     72   AutoreleasePoolRewriter(MigrationPass &pass)
     73     : Body(0), Pass(pass) {
     74     PoolII = &pass.Ctx.Idents.get("NSAutoreleasePool");
     75     DrainSel = pass.Ctx.Selectors.getNullarySelector(
     76                                                  &pass.Ctx.Idents.get("drain"));
     77   }
     78 
     79   void transformBody(Stmt *body) {
     80     Body = body;
     81     TraverseStmt(body);
     82   }
     83 
     84   ~AutoreleasePoolRewriter() {
     85     llvm::SmallVector<VarDecl *, 8> VarsToHandle;
     86 
     87     for (std::map<VarDecl *, PoolVarInfo>::iterator
     88            I = PoolVars.begin(), E = PoolVars.end(); I != E; ++I) {
     89       VarDecl *var = I->first;
     90       PoolVarInfo &info = I->second;
     91 
     92       // Check that we can handle/rewrite all references of the pool.
     93 
     94       clearRefsIn(info.Dcl, info.Refs);
     95       for (llvm::SmallVectorImpl<PoolScope>::iterator
     96              scpI = info.Scopes.begin(),
     97              scpE = info.Scopes.end(); scpI != scpE; ++scpI) {
     98         PoolScope &scope = *scpI;
     99         clearRefsIn(*scope.Begin, info.Refs);
    100         clearRefsIn(*scope.End, info.Refs);
    101         clearRefsIn(scope.Releases.begin(), scope.Releases.end(), info.Refs);
    102       }
    103 
    104       // Even if one reference is not handled we will not do anything about that
    105       // pool variable.
    106       if (info.Refs.empty())
    107         VarsToHandle.push_back(var);
    108     }
    109 
    110     for (unsigned i = 0, e = VarsToHandle.size(); i != e; ++i) {
    111       PoolVarInfo &info = PoolVars[VarsToHandle[i]];
    112 
    113       Transaction Trans(Pass.TA);
    114 
    115       clearUnavailableDiags(info.Dcl);
    116       Pass.TA.removeStmt(info.Dcl);
    117 
    118       // Add "@autoreleasepool { }"
    119       for (llvm::SmallVectorImpl<PoolScope>::iterator
    120              scpI = info.Scopes.begin(),
    121              scpE = info.Scopes.end(); scpI != scpE; ++scpI) {
    122         PoolScope &scope = *scpI;
    123         clearUnavailableDiags(*scope.Begin);
    124         clearUnavailableDiags(*scope.End);
    125         if (scope.IsFollowedBySimpleReturnStmt) {
    126           // Include the return in the scope.
    127           Pass.TA.replaceStmt(*scope.Begin, "@autoreleasepool {");
    128           Pass.TA.removeStmt(*scope.End);
    129           Stmt::child_iterator retI = scope.End;
    130           ++retI;
    131           SourceLocation afterSemi = findLocationAfterSemi((*retI)->getLocEnd(),
    132                                                            Pass.Ctx);
    133           assert(afterSemi.isValid() &&
    134                  "Didn't we check before setting IsFollowedBySimpleReturnStmt "
    135                  "to true?");
    136           Pass.TA.insertAfterToken(afterSemi, "\n}");
    137           Pass.TA.increaseIndentation(
    138                                 SourceRange(scope.getIndentedRange().getBegin(),
    139                                             (*retI)->getLocEnd()),
    140                                       scope.CompoundParent->getLocStart());
    141         } else {
    142           Pass.TA.replaceStmt(*scope.Begin, "@autoreleasepool {");
    143           Pass.TA.replaceStmt(*scope.End, "}");
    144           Pass.TA.increaseIndentation(scope.getIndentedRange(),
    145                                       scope.CompoundParent->getLocStart());
    146         }
    147       }
    148 
    149       // Remove rest of pool var references.
    150       for (llvm::SmallVectorImpl<PoolScope>::iterator
    151              scpI = info.Scopes.begin(),
    152              scpE = info.Scopes.end(); scpI != scpE; ++scpI) {
    153         PoolScope &scope = *scpI;
    154         for (llvm::SmallVectorImpl<ObjCMessageExpr *>::iterator
    155                relI = scope.Releases.begin(),
    156                relE = scope.Releases.end(); relI != relE; ++relI) {
    157           clearUnavailableDiags(*relI);
    158           Pass.TA.removeStmt(*relI);
    159         }
    160       }
    161     }
    162   }
    163 
    164   bool VisitCompoundStmt(CompoundStmt *S) {
    165     llvm::SmallVector<PoolScope, 4> Scopes;
    166 
    167     for (Stmt::child_iterator
    168            I = S->body_begin(), E = S->body_end(); I != E; ++I) {
    169       Stmt *child = getEssential(*I);
    170       if (DeclStmt *DclS = dyn_cast<DeclStmt>(child)) {
    171         if (DclS->isSingleDecl()) {
    172           if (VarDecl *VD = dyn_cast<VarDecl>(DclS->getSingleDecl())) {
    173             if (isNSAutoreleasePool(VD->getType())) {
    174               PoolVarInfo &info = PoolVars[VD];
    175               info.Dcl = DclS;
    176               collectRefs(VD, S, info.Refs);
    177               // Does this statement follow the pattern:
    178               // NSAutoreleasePool * pool = [NSAutoreleasePool  new];
    179               if (isPoolCreation(VD->getInit())) {
    180                 Scopes.push_back(PoolScope());
    181                 Scopes.back().PoolVar = VD;
    182                 Scopes.back().CompoundParent = S;
    183                 Scopes.back().Begin = I;
    184               }
    185             }
    186           }
    187         }
    188       } else if (BinaryOperator *bop = dyn_cast<BinaryOperator>(child)) {
    189         if (DeclRefExpr *dref = dyn_cast<DeclRefExpr>(bop->getLHS())) {
    190           if (VarDecl *VD = dyn_cast<VarDecl>(dref->getDecl())) {
    191             // Does this statement follow the pattern:
    192             // pool = [NSAutoreleasePool  new];
    193             if (isNSAutoreleasePool(VD->getType()) &&
    194                 isPoolCreation(bop->getRHS())) {
    195               Scopes.push_back(PoolScope());
    196               Scopes.back().PoolVar = VD;
    197               Scopes.back().CompoundParent = S;
    198               Scopes.back().Begin = I;
    199             }
    200           }
    201         }
    202       }
    203 
    204       if (Scopes.empty())
    205         continue;
    206 
    207       if (isPoolDrain(Scopes.back().PoolVar, child)) {
    208         PoolScope &scope = Scopes.back();
    209         scope.End = I;
    210         handlePoolScope(scope, S);
    211         Scopes.pop_back();
    212       }
    213     }
    214     return true;
    215   }
    216 
    217 private:
    218   void clearUnavailableDiags(Stmt *S) {
    219     if (S)
    220       Pass.TA.clearDiagnostic(diag::err_unavailable,
    221                               diag::err_unavailable_message,
    222                               S->getSourceRange());
    223   }
    224 
    225   struct PoolScope {
    226     VarDecl *PoolVar;
    227     CompoundStmt *CompoundParent;
    228     Stmt::child_iterator Begin;
    229     Stmt::child_iterator End;
    230     bool IsFollowedBySimpleReturnStmt;
    231     llvm::SmallVector<ObjCMessageExpr *, 4> Releases;
    232 
    233     PoolScope() : PoolVar(0), CompoundParent(0), Begin(), End(),
    234                   IsFollowedBySimpleReturnStmt(false) { }
    235 
    236     SourceRange getIndentedRange() const {
    237       Stmt::child_iterator rangeS = Begin;
    238       ++rangeS;
    239       if (rangeS == End)
    240         return SourceRange();
    241       Stmt::child_iterator rangeE = Begin;
    242       for (Stmt::child_iterator I = rangeS; I != End; ++I)
    243         ++rangeE;
    244       return SourceRange((*rangeS)->getLocStart(), (*rangeE)->getLocEnd());
    245     }
    246   };
    247 
    248   class NameReferenceChecker : public RecursiveASTVisitor<NameReferenceChecker>{
    249     ASTContext &Ctx;
    250     SourceRange ScopeRange;
    251     SourceLocation &referenceLoc, &declarationLoc;
    252 
    253   public:
    254     NameReferenceChecker(ASTContext &ctx, PoolScope &scope,
    255                          SourceLocation &referenceLoc,
    256                          SourceLocation &declarationLoc)
    257       : Ctx(ctx), referenceLoc(referenceLoc),
    258         declarationLoc(declarationLoc) {
    259       ScopeRange = SourceRange((*scope.Begin)->getLocStart(),
    260                                (*scope.End)->getLocStart());
    261     }
    262 
    263     bool VisitDeclRefExpr(DeclRefExpr *E) {
    264       return checkRef(E->getLocation(), E->getDecl()->getLocation());
    265     }
    266 
    267     bool VisitBlockDeclRefExpr(BlockDeclRefExpr *E) {
    268       return checkRef(E->getLocation(), E->getDecl()->getLocation());
    269     }
    270 
    271     bool VisitTypedefTypeLoc(TypedefTypeLoc TL) {
    272       return checkRef(TL.getBeginLoc(), TL.getTypedefNameDecl()->getLocation());
    273     }
    274 
    275     bool VisitTagTypeLoc(TagTypeLoc TL) {
    276       return checkRef(TL.getBeginLoc(), TL.getDecl()->getLocation());
    277     }
    278 
    279   private:
    280     bool checkRef(SourceLocation refLoc, SourceLocation declLoc) {
    281       if (isInScope(declLoc)) {
    282         referenceLoc = refLoc;
    283         declarationLoc = declLoc;
    284         return false;
    285       }
    286       return true;
    287     }
    288 
    289     bool isInScope(SourceLocation loc) {
    290       SourceManager &SM = Ctx.getSourceManager();
    291       if (SM.isBeforeInTranslationUnit(loc, ScopeRange.getBegin()))
    292         return false;
    293       return SM.isBeforeInTranslationUnit(loc, ScopeRange.getEnd());
    294     }
    295   };
    296 
    297   void handlePoolScope(PoolScope &scope, CompoundStmt *compoundS) {
    298     // Check that all names declared inside the scope are not used
    299     // outside the scope.
    300     {
    301       bool nameUsedOutsideScope = false;
    302       SourceLocation referenceLoc, declarationLoc;
    303       Stmt::child_iterator SI = scope.End, SE = compoundS->body_end();
    304       ++SI;
    305       // Check if the autoreleasepool scope is followed by a simple return
    306       // statement, in which case we will include the return in the scope.
    307       if (SI != SE)
    308         if (ReturnStmt *retS = dyn_cast<ReturnStmt>(*SI))
    309           if ((retS->getRetValue() == 0 ||
    310                isa<DeclRefExpr>(retS->getRetValue()->IgnoreParenCasts())) &&
    311               findLocationAfterSemi(retS->getLocEnd(), Pass.Ctx).isValid()) {
    312             scope.IsFollowedBySimpleReturnStmt = true;
    313             ++SI; // the return will be included in scope, don't check it.
    314           }
    315 
    316       for (; SI != SE; ++SI) {
    317         nameUsedOutsideScope = !NameReferenceChecker(Pass.Ctx, scope,
    318                                                      referenceLoc,
    319                                               declarationLoc).TraverseStmt(*SI);
    320         if (nameUsedOutsideScope)
    321           break;
    322       }
    323 
    324       // If not all references were cleared it means some variables/typenames/etc
    325       // declared inside the pool scope are used outside of it.
    326       // We won't try to rewrite the pool.
    327       if (nameUsedOutsideScope) {
    328         Pass.TA.reportError("a name is referenced outside the "
    329             "NSAutoreleasePool scope that it was declared in", referenceLoc);
    330         Pass.TA.reportNote("name declared here", declarationLoc);
    331         Pass.TA.reportNote("intended @autoreleasepool scope begins here",
    332                            (*scope.Begin)->getLocStart());
    333         Pass.TA.reportNote("intended @autoreleasepool scope ends here",
    334                            (*scope.End)->getLocStart());
    335         return;
    336       }
    337     }
    338 
    339     // Collect all releases of the pool; they will be removed.
    340     {
    341       ReleaseCollector releaseColl(scope.PoolVar, scope.Releases);
    342       Stmt::child_iterator I = scope.Begin;
    343       ++I;
    344       for (; I != scope.End; ++I)
    345         releaseColl.TraverseStmt(*I);
    346     }
    347 
    348     PoolVars[scope.PoolVar].Scopes.push_back(scope);
    349   }
    350 
    351   bool isPoolCreation(Expr *E) {
    352     if (!E) return false;
    353     E = getEssential(E);
    354     ObjCMessageExpr *ME = dyn_cast<ObjCMessageExpr>(E);
    355     if (!ME) return false;
    356     if (ME->getMethodFamily() == OMF_new &&
    357         ME->getReceiverKind() == ObjCMessageExpr::Class &&
    358         isNSAutoreleasePool(ME->getReceiverInterface()))
    359       return true;
    360     if (ME->getReceiverKind() == ObjCMessageExpr::Instance &&
    361         ME->getMethodFamily() == OMF_init) {
    362       Expr *rec = getEssential(ME->getInstanceReceiver());
    363       if (ObjCMessageExpr *recME = dyn_cast_or_null<ObjCMessageExpr>(rec)) {
    364         if (recME->getMethodFamily() == OMF_alloc &&
    365             recME->getReceiverKind() == ObjCMessageExpr::Class &&
    366             isNSAutoreleasePool(recME->getReceiverInterface()))
    367           return true;
    368       }
    369     }
    370 
    371     return false;
    372   }
    373 
    374   bool isPoolDrain(VarDecl *poolVar, Stmt *S) {
    375     if (!S) return false;
    376     S = getEssential(S);
    377     ObjCMessageExpr *ME = dyn_cast<ObjCMessageExpr>(S);
    378     if (!ME) return false;
    379     if (ME->getReceiverKind() == ObjCMessageExpr::Instance) {
    380       Expr *rec = getEssential(ME->getInstanceReceiver());
    381       if (DeclRefExpr *dref = dyn_cast<DeclRefExpr>(rec))
    382         if (dref->getDecl() == poolVar)
    383           return ME->getMethodFamily() == OMF_release ||
    384                  ME->getSelector() == DrainSel;
    385     }
    386 
    387     return false;
    388   }
    389 
    390   bool isNSAutoreleasePool(ObjCInterfaceDecl *IDecl) {
    391     return IDecl && IDecl->getIdentifier() == PoolII;
    392   }
    393 
    394   bool isNSAutoreleasePool(QualType Ty) {
    395     QualType pointee = Ty->getPointeeType();
    396     if (pointee.isNull())
    397       return false;
    398     if (const ObjCInterfaceType *interT = pointee->getAs<ObjCInterfaceType>())
    399       return isNSAutoreleasePool(interT->getDecl());
    400     return false;
    401   }
    402 
    403   static Expr *getEssential(Expr *E) {
    404     return cast<Expr>(getEssential((Stmt*)E));
    405   }
    406   static Stmt *getEssential(Stmt *S) {
    407     if (ExprWithCleanups *EWC = dyn_cast<ExprWithCleanups>(S))
    408       S = EWC->getSubExpr();
    409     if (Expr *E = dyn_cast<Expr>(S))
    410       S = E->IgnoreParenCasts();
    411     return S;
    412   }
    413 
    414   Stmt *Body;
    415   MigrationPass &Pass;
    416 
    417   IdentifierInfo *PoolII;
    418   Selector DrainSel;
    419 
    420   struct PoolVarInfo {
    421     DeclStmt *Dcl;
    422     ExprSet Refs;
    423     llvm::SmallVector<PoolScope, 2> Scopes;
    424 
    425     PoolVarInfo() : Dcl(0) { }
    426   };
    427 
    428   std::map<VarDecl *, PoolVarInfo> PoolVars;
    429 };
    430 
    431 } // anonymous namespace
    432 
    433 void trans::rewriteAutoreleasePool(MigrationPass &pass) {
    434   BodyTransform<AutoreleasePoolRewriter> trans(pass);
    435   trans.TraverseDecl(pass.Ctx.getTranslationUnitDecl());
    436 }
    437