Home | History | Annotate | Download | only in ARCMigrate
      1 //===--- TransUnbridgedCasts.cpp - Transformations to ARC mode ------------===//
      2 //
      3 //                     The LLVM Compiler Infrastructure
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 //===----------------------------------------------------------------------===//
      9 //
     10 // rewriteUnbridgedCasts:
     11 //
     12 // A cast of non-objc pointer to an objc one is checked. If the non-objc pointer
     13 // is from a file-level variable, __bridge cast is used to convert it.
     14 // For the result of a function call that we know is +1/+0,
     15 // __bridge/__bridge_transfer is used.
     16 //
     17 //  NSString *str = (NSString *)kUTTypePlainText;
     18 //  str = b ? kUTTypeRTF : kUTTypePlainText;
     19 //  NSString *_uuidString = (NSString *)CFUUIDCreateString(kCFAllocatorDefault,
     20 //                                                         _uuid);
     21 // ---->
     22 //  NSString *str = (__bridge NSString *)kUTTypePlainText;
     23 //  str = (__bridge NSString *)(b ? kUTTypeRTF : kUTTypePlainText);
     24 // NSString *_uuidString = (__bridge_transfer NSString *)
     25 //                               CFUUIDCreateString(kCFAllocatorDefault, _uuid);
     26 //
     27 // For an ObjC to C pointer cast of 'self', __bridge is used.
     28 //
     29 //  CFStringRef str = (CFStringRef)self;
     30 // ---->
     31 //  CFStringRef str = (__bridge CFStringRef)self;
     32 //
     33 //===----------------------------------------------------------------------===//
     34 
     35 #include "Transforms.h"
     36 #include "Internals.h"
     37 #include "clang/Analysis/DomainSpecific/CocoaConventions.h"
     38 #include "clang/Sema/SemaDiagnostic.h"
     39 #include "clang/AST/ParentMap.h"
     40 #include "clang/Basic/SourceManager.h"
     41 #include "llvm/ADT/SmallString.h"
     42 
     43 using namespace clang;
     44 using namespace arcmt;
     45 using namespace trans;
     46 
     47 namespace {
     48 
     49 class UnbridgedCastRewriter : public RecursiveASTVisitor<UnbridgedCastRewriter>{
     50   MigrationPass &Pass;
     51   IdentifierInfo *SelfII;
     52   OwningPtr<ParentMap> StmtMap;
     53 
     54 public:
     55   UnbridgedCastRewriter(MigrationPass &pass) : Pass(pass) {
     56     SelfII = &Pass.Ctx.Idents.get("self");
     57   }
     58 
     59   void transformBody(Stmt *body) {
     60     StmtMap.reset(new ParentMap(body));
     61     TraverseStmt(body);
     62   }
     63 
     64   bool VisitCastExpr(CastExpr *E) {
     65     if (E->getCastKind() != CK_CPointerToObjCPointerCast
     66         && E->getCastKind() != CK_BitCast)
     67       return true;
     68 
     69     QualType castType = E->getType();
     70     Expr *castExpr = E->getSubExpr();
     71     QualType castExprType = castExpr->getType();
     72 
     73     if (castType->isObjCObjectPointerType() &&
     74         castExprType->isObjCObjectPointerType())
     75       return true;
     76     if (!castType->isObjCObjectPointerType() &&
     77         !castExprType->isObjCObjectPointerType())
     78       return true;
     79 
     80     bool exprRetainable = castExprType->isObjCIndirectLifetimeType();
     81     bool castRetainable = castType->isObjCIndirectLifetimeType();
     82     if (exprRetainable == castRetainable) return true;
     83 
     84     if (castExpr->isNullPointerConstant(Pass.Ctx,
     85                                         Expr::NPC_ValueDependentIsNull))
     86       return true;
     87 
     88     SourceLocation loc = castExpr->getExprLoc();
     89     if (loc.isValid() && Pass.Ctx.getSourceManager().isInSystemHeader(loc))
     90       return true;
     91 
     92     if (castType->isObjCObjectPointerType())
     93       transformNonObjCToObjCCast(E);
     94     else
     95       transformObjCToNonObjCCast(E);
     96 
     97     return true;
     98   }
     99 
    100 private:
    101   void transformNonObjCToObjCCast(CastExpr *E) {
    102     if (!E) return;
    103 
    104     // Global vars are assumed that are cast as unretained.
    105     if (isGlobalVar(E))
    106       if (E->getSubExpr()->getType()->isPointerType()) {
    107         castToObjCObject(E, /*retained=*/false);
    108         return;
    109       }
    110 
    111     // If the cast is directly over the result of a Core Foundation function
    112     // try to figure out whether it should be cast as retained or unretained.
    113     Expr *inner = E->IgnoreParenCasts();
    114     if (CallExpr *callE = dyn_cast<CallExpr>(inner)) {
    115       if (FunctionDecl *FD = callE->getDirectCallee()) {
    116         if (FD->getAttr<CFReturnsRetainedAttr>()) {
    117           castToObjCObject(E, /*retained=*/true);
    118           return;
    119         }
    120         if (FD->getAttr<CFReturnsNotRetainedAttr>()) {
    121           castToObjCObject(E, /*retained=*/false);
    122           return;
    123         }
    124         if (FD->isGlobal() &&
    125             FD->getIdentifier() &&
    126             ento::cocoa::isRefType(E->getSubExpr()->getType(), "CF",
    127                                    FD->getIdentifier()->getName())) {
    128           StringRef fname = FD->getIdentifier()->getName();
    129           if (fname.endswith("Retain") ||
    130               fname.find("Create") != StringRef::npos ||
    131               fname.find("Copy") != StringRef::npos) {
    132             // Do not migrate to couple of bridge transfer casts which
    133             // cancel each other out. Leave it unchanged so error gets user
    134             // attention instead.
    135             if (FD->getName() == "CFRetain" &&
    136                 FD->getNumParams() == 1 &&
    137                 FD->getParent()->isTranslationUnit() &&
    138                 FD->getLinkage() == ExternalLinkage) {
    139               Expr *Arg = callE->getArg(0);
    140               if (const ImplicitCastExpr *ICE = dyn_cast<ImplicitCastExpr>(Arg)) {
    141                 const Expr *sub = ICE->getSubExpr();
    142                 QualType T = sub->getType();
    143                 if (T->isObjCObjectPointerType())
    144                   return;
    145               }
    146             }
    147             castToObjCObject(E, /*retained=*/true);
    148             return;
    149           }
    150 
    151           if (fname.find("Get") != StringRef::npos) {
    152             castToObjCObject(E, /*retained=*/false);
    153             return;
    154           }
    155         }
    156       }
    157     }
    158   }
    159 
    160   void castToObjCObject(CastExpr *E, bool retained) {
    161     rewriteToBridgedCast(E, retained ? OBC_BridgeTransfer : OBC_Bridge);
    162   }
    163 
    164   void rewriteToBridgedCast(CastExpr *E, ObjCBridgeCastKind Kind) {
    165     Transaction Trans(Pass.TA);
    166     rewriteToBridgedCast(E, Kind, Trans);
    167   }
    168 
    169   void rewriteToBridgedCast(CastExpr *E, ObjCBridgeCastKind Kind,
    170                             Transaction &Trans) {
    171     TransformActions &TA = Pass.TA;
    172 
    173     // We will remove the compiler diagnostic.
    174     if (!TA.hasDiagnostic(diag::err_arc_mismatched_cast,
    175                           diag::err_arc_cast_requires_bridge,
    176                           E->getLocStart())) {
    177       Trans.abort();
    178       return;
    179     }
    180 
    181     StringRef bridge;
    182     switch(Kind) {
    183     case OBC_Bridge:
    184       bridge = "__bridge "; break;
    185     case OBC_BridgeTransfer:
    186       bridge = "__bridge_transfer "; break;
    187     case OBC_BridgeRetained:
    188       bridge = "__bridge_retained "; break;
    189     }
    190 
    191     TA.clearDiagnostic(diag::err_arc_mismatched_cast,
    192                        diag::err_arc_cast_requires_bridge,
    193                        E->getLocStart());
    194     if (CStyleCastExpr *CCE = dyn_cast<CStyleCastExpr>(E)) {
    195       TA.insertAfterToken(CCE->getLParenLoc(), bridge);
    196     } else {
    197       SourceLocation insertLoc = E->getSubExpr()->getLocStart();
    198       SmallString<128> newCast;
    199       newCast += '(';
    200       newCast += bridge;
    201       newCast += E->getType().getAsString(Pass.Ctx.getPrintingPolicy());
    202       newCast += ')';
    203 
    204       if (isa<ParenExpr>(E->getSubExpr())) {
    205         TA.insert(insertLoc, newCast.str());
    206       } else {
    207         newCast += '(';
    208         TA.insert(insertLoc, newCast.str());
    209         TA.insertAfterToken(E->getLocEnd(), ")");
    210       }
    211     }
    212   }
    213 
    214   void rewriteCastForCFRetain(CastExpr *castE, CallExpr *callE) {
    215     Transaction Trans(Pass.TA);
    216     Pass.TA.replace(callE->getSourceRange(), callE->getArg(0)->getSourceRange());
    217     rewriteToBridgedCast(castE, OBC_BridgeRetained, Trans);
    218   }
    219 
    220   void transformObjCToNonObjCCast(CastExpr *E) {
    221     if (isSelf(E->getSubExpr()))
    222       return rewriteToBridgedCast(E, OBC_Bridge);
    223 
    224     CallExpr *callE;
    225     if (isPassedToCFRetain(E, callE))
    226       return rewriteCastForCFRetain(E, callE);
    227 
    228     ObjCMethodFamily family = getFamilyOfMessage(E->getSubExpr());
    229     if (family == OMF_retain)
    230       return rewriteToBridgedCast(E, OBC_BridgeRetained);
    231 
    232     if (family == OMF_autorelease || family == OMF_release) {
    233       std::string err = "it is not safe to cast to '";
    234       err += E->getType().getAsString(Pass.Ctx.getPrintingPolicy());
    235       err += "' the result of '";
    236       err += family == OMF_autorelease ? "autorelease" : "release";
    237       err += "' message; a __bridge cast may result in a pointer to a "
    238           "destroyed object and a __bridge_retained may leak the object";
    239       Pass.TA.reportError(err, E->getLocStart(),
    240                           E->getSubExpr()->getSourceRange());
    241       Stmt *parent = E;
    242       do {
    243         parent = StmtMap->getParentIgnoreParenImpCasts(parent);
    244       } while (parent && isa<ExprWithCleanups>(parent));
    245 
    246       if (ReturnStmt *retS = dyn_cast_or_null<ReturnStmt>(parent)) {
    247         std::string note = "remove the cast and change return type of function "
    248             "to '";
    249         note += E->getSubExpr()->getType().getAsString(Pass.Ctx.getPrintingPolicy());
    250         note += "' to have the object automatically autoreleased";
    251         Pass.TA.reportNote(note, retS->getLocStart());
    252       }
    253     }
    254 
    255     Expr *subExpr = E->getSubExpr();
    256 
    257     // Look through pseudo-object expressions.
    258     if (PseudoObjectExpr *pseudo = dyn_cast<PseudoObjectExpr>(subExpr)) {
    259       subExpr = pseudo->getResultExpr();
    260       assert(subExpr && "no result for pseudo-object of non-void type?");
    261     }
    262 
    263     if (ImplicitCastExpr *implCE = dyn_cast<ImplicitCastExpr>(subExpr)) {
    264       if (implCE->getCastKind() == CK_ARCConsumeObject)
    265         return rewriteToBridgedCast(E, OBC_BridgeRetained);
    266       if (implCE->getCastKind() == CK_ARCReclaimReturnedObject)
    267         return rewriteToBridgedCast(E, OBC_Bridge);
    268     }
    269 
    270     bool isConsumed = false;
    271     if (isPassedToCParamWithKnownOwnership(E, isConsumed))
    272       return rewriteToBridgedCast(E, isConsumed ? OBC_BridgeRetained
    273                                                 : OBC_Bridge);
    274   }
    275 
    276   static ObjCMethodFamily getFamilyOfMessage(Expr *E) {
    277     E = E->IgnoreParenCasts();
    278     if (ObjCMessageExpr *ME = dyn_cast<ObjCMessageExpr>(E))
    279       return ME->getMethodFamily();
    280 
    281     return OMF_None;
    282   }
    283 
    284   bool isPassedToCFRetain(Expr *E, CallExpr *&callE) const {
    285     if ((callE = dyn_cast_or_null<CallExpr>(
    286                                      StmtMap->getParentIgnoreParenImpCasts(E))))
    287       if (FunctionDecl *
    288             FD = dyn_cast_or_null<FunctionDecl>(callE->getCalleeDecl()))
    289         if (FD->getName() == "CFRetain" && FD->getNumParams() == 1 &&
    290             FD->getParent()->isTranslationUnit() &&
    291             FD->getLinkage() == ExternalLinkage)
    292           return true;
    293 
    294     return false;
    295   }
    296 
    297   bool isPassedToCParamWithKnownOwnership(Expr *E, bool &isConsumed) const {
    298     if (CallExpr *callE = dyn_cast_or_null<CallExpr>(
    299                                      StmtMap->getParentIgnoreParenImpCasts(E)))
    300       if (FunctionDecl *
    301             FD = dyn_cast_or_null<FunctionDecl>(callE->getCalleeDecl())) {
    302         unsigned i = 0;
    303         for (unsigned e = callE->getNumArgs(); i != e; ++i) {
    304           Expr *arg = callE->getArg(i);
    305           if (arg == E || arg->IgnoreParenImpCasts() == E)
    306             break;
    307         }
    308         if (i < callE->getNumArgs()) {
    309           ParmVarDecl *PD = FD->getParamDecl(i);
    310           if (PD->getAttr<CFConsumedAttr>()) {
    311             isConsumed = true;
    312             return true;
    313           }
    314         }
    315       }
    316 
    317     return false;
    318   }
    319 
    320   bool isSelf(Expr *E) const {
    321     E = E->IgnoreParenLValueCasts();
    322     if (DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(E))
    323       if (ImplicitParamDecl *IPD = dyn_cast<ImplicitParamDecl>(DRE->getDecl()))
    324         if (IPD->getIdentifier() == SelfII)
    325           return true;
    326 
    327     return false;
    328   }
    329 };
    330 
    331 } // end anonymous namespace
    332 
    333 void trans::rewriteUnbridgedCasts(MigrationPass &pass) {
    334   BodyTransform<UnbridgedCastRewriter> trans(pass);
    335   trans.TraverseDecl(pass.Ctx.getTranslationUnitDecl());
    336 }
    337