Home | History | Annotate | Download | only in slang
      1 /*
      2  * Copyright 2010, The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *     http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "slang_rs_object_ref_count.h"
     18 
     19 #include <list>
     20 
     21 #include "clang/AST/DeclGroup.h"
     22 #include "clang/AST/Expr.h"
     23 #include "clang/AST/NestedNameSpecifier.h"
     24 #include "clang/AST/OperationKinds.h"
     25 #include "clang/AST/Stmt.h"
     26 #include "clang/AST/StmtVisitor.h"
     27 
     28 #include "slang_assert.h"
     29 #include "slang.h"
     30 #include "slang_rs_ast_replace.h"
     31 #include "slang_rs_export_type.h"
     32 
     33 namespace slang {
     34 
     35 /* Even though those two arrays are of size DataTypeMax, only entries that
     36  * correspond to object types will be set.
     37  */
     38 clang::FunctionDecl *
     39 RSObjectRefCount::RSSetObjectFD[DataTypeMax];
     40 clang::FunctionDecl *
     41 RSObjectRefCount::RSClearObjectFD[DataTypeMax];
     42 
     43 void RSObjectRefCount::GetRSRefCountingFunctions(clang::ASTContext &C) {
     44   for (unsigned i = 0; i < DataTypeMax; i++) {
     45     RSSetObjectFD[i] = nullptr;
     46     RSClearObjectFD[i] = nullptr;
     47   }
     48 
     49   clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
     50 
     51   for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(),
     52           E = TUDecl->decls_end(); I != E; I++) {
     53     if ((I->getKind() >= clang::Decl::firstFunction) &&
     54         (I->getKind() <= clang::Decl::lastFunction)) {
     55       clang::FunctionDecl *FD = static_cast<clang::FunctionDecl*>(*I);
     56 
     57       // points to RSSetObjectFD or RSClearObjectFD
     58       clang::FunctionDecl **RSObjectFD;
     59 
     60       if (FD->getName() == "rsSetObject") {
     61         slangAssert((FD->getNumParams() == 2) &&
     62                     "Invalid rsSetObject function prototype (# params)");
     63         RSObjectFD = RSSetObjectFD;
     64       } else if (FD->getName() == "rsClearObject") {
     65         slangAssert((FD->getNumParams() == 1) &&
     66                     "Invalid rsClearObject function prototype (# params)");
     67         RSObjectFD = RSClearObjectFD;
     68       } else {
     69         continue;
     70       }
     71 
     72       const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
     73       clang::QualType PVT = PVD->getOriginalType();
     74       // The first parameter must be a pointer like rs_allocation*
     75       slangAssert(PVT->isPointerType() &&
     76           "Invalid rs{Set,Clear}Object function prototype (pointer param)");
     77 
     78       // The rs object type passed to the FD
     79       clang::QualType RST = PVT->getPointeeType();
     80       DataType DT = RSExportPrimitiveType::GetRSSpecificType(RST.getTypePtr());
     81       slangAssert(RSExportPrimitiveType::IsRSObjectType(DT)
     82              && "must be RS object type");
     83 
     84       if (DT >= 0 && DT < DataTypeMax) {
     85           RSObjectFD[DT] = FD;
     86       } else {
     87           slangAssert(false && "incorrect type");
     88       }
     89     }
     90   }
     91 }
     92 
     93 namespace {
     94 
     95 // This function constructs a new CompoundStmt from the input StmtList.
     96 static clang::CompoundStmt* BuildCompoundStmt(clang::ASTContext &C,
     97       std::list<clang::Stmt*> &StmtList, clang::SourceLocation Loc) {
     98   unsigned NewStmtCount = StmtList.size();
     99   unsigned CompoundStmtCount = 0;
    100 
    101   clang::Stmt **CompoundStmtList;
    102   CompoundStmtList = new clang::Stmt*[NewStmtCount];
    103 
    104   std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
    105   std::list<clang::Stmt*>::const_iterator E = StmtList.end();
    106   for ( ; I != E; I++) {
    107     CompoundStmtList[CompoundStmtCount++] = *I;
    108   }
    109   slangAssert(CompoundStmtCount == NewStmtCount);
    110 
    111   clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
    112       C, llvm::makeArrayRef(CompoundStmtList, CompoundStmtCount), Loc, Loc);
    113 
    114   delete [] CompoundStmtList;
    115 
    116   return CS;
    117 }
    118 
    119 static void AppendAfterStmt(clang::ASTContext &C,
    120                             clang::CompoundStmt *CS,
    121                             clang::Stmt *S,
    122                             std::list<clang::Stmt*> &StmtList) {
    123   slangAssert(CS);
    124   clang::CompoundStmt::body_iterator bI = CS->body_begin();
    125   clang::CompoundStmt::body_iterator bE = CS->body_end();
    126   clang::Stmt **UpdatedStmtList =
    127       new clang::Stmt*[CS->size() + StmtList.size()];
    128 
    129   unsigned UpdatedStmtCount = 0;
    130   unsigned Once = 0;
    131   for ( ; bI != bE; bI++) {
    132     if (!S && ((*bI)->getStmtClass() == clang::Stmt::ReturnStmtClass)) {
    133       // If we come across a return here, we don't have anything we can
    134       // reasonably replace. We should have already inserted our destructor
    135       // code in the proper spot, so we just clean up and return.
    136       delete [] UpdatedStmtList;
    137 
    138       return;
    139     }
    140 
    141     UpdatedStmtList[UpdatedStmtCount++] = *bI;
    142 
    143     if ((*bI == S) && !Once) {
    144       Once++;
    145       std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
    146       std::list<clang::Stmt*>::const_iterator E = StmtList.end();
    147       for ( ; I != E; I++) {
    148         UpdatedStmtList[UpdatedStmtCount++] = *I;
    149       }
    150     }
    151   }
    152   slangAssert(Once <= 1);
    153 
    154   // When S is nullptr, we are appending to the end of the CompoundStmt.
    155   if (!S) {
    156     slangAssert(Once == 0);
    157     std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
    158     std::list<clang::Stmt*>::const_iterator E = StmtList.end();
    159     for ( ; I != E; I++) {
    160       UpdatedStmtList[UpdatedStmtCount++] = *I;
    161     }
    162   }
    163 
    164   CS->setStmts(C, UpdatedStmtList, UpdatedStmtCount);
    165 
    166   delete [] UpdatedStmtList;
    167 }
    168 
    169 // This class visits a compound statement and inserts DtorStmt
    170 // in proper locations. This includes inserting it before any
    171 // return statement in any sub-block, at the end of the logical enclosing
    172 // scope (compound statement), and/or before any break/continue statement that
    173 // would resume outside the declared scope. We will not handle the case for
    174 // goto statements that leave a local scope.
    175 //
    176 // To accomplish these goals, it collects a list of sub-Stmt's that
    177 // correspond to scope exit points. It then uses an RSASTReplace visitor to
    178 // transform the AST, inserting appropriate destructors before each of those
    179 // sub-Stmt's (and also before the exit of the outermost containing Stmt for
    180 // the scope).
    181 class DestructorVisitor : public clang::StmtVisitor<DestructorVisitor> {
    182  private:
    183   clang::ASTContext &mCtx;
    184 
    185   // The loop depth of the currently visited node.
    186   int mLoopDepth;
    187 
    188   // The switch statement depth of the currently visited node.
    189   // Note that this is tracked separately from the loop depth because
    190   // SwitchStmt-contained ContinueStmt's should have destructors for the
    191   // corresponding loop scope.
    192   int mSwitchDepth;
    193 
    194   // The outermost statement block that we are currently visiting.
    195   // This should always be a CompoundStmt.
    196   clang::Stmt *mOuterStmt;
    197 
    198   // The destructor to execute for this scope/variable.
    199   clang::Stmt* mDtorStmt;
    200 
    201   // The stack of statements which should be replaced by a compound statement
    202   // containing the new destructor call followed by the original Stmt.
    203   std::stack<clang::Stmt*> mReplaceStmtStack;
    204 
    205   // The source location for the variable declaration that we are trying to
    206   // insert destructors for. Note that InsertDestructors() will not generate
    207   // destructor calls for source locations that occur lexically before this
    208   // location.
    209   clang::SourceLocation mVarLoc;
    210 
    211  public:
    212   DestructorVisitor(clang::ASTContext &C,
    213                     clang::Stmt* OuterStmt,
    214                     clang::Stmt* DtorStmt,
    215                     clang::SourceLocation VarLoc);
    216 
    217   // This code walks the collected list of Stmts to replace and actually does
    218   // the replacement. It also finishes up by appending the destructor to the
    219   // current outermost CompoundStmt.
    220   void InsertDestructors() {
    221     clang::Stmt *S = nullptr;
    222     clang::SourceManager &SM = mCtx.getSourceManager();
    223     std::list<clang::Stmt *> StmtList;
    224     StmtList.push_back(mDtorStmt);
    225 
    226     while (!mReplaceStmtStack.empty()) {
    227       S = mReplaceStmtStack.top();
    228       mReplaceStmtStack.pop();
    229 
    230       // Skip all source locations that occur before the variable's
    231       // declaration, since it won't have been initialized yet.
    232       if (SM.isBeforeInTranslationUnit(S->getLocStart(), mVarLoc)) {
    233         continue;
    234       }
    235 
    236       StmtList.push_back(S);
    237       clang::CompoundStmt *CS =
    238           BuildCompoundStmt(mCtx, StmtList, S->getLocEnd());
    239       StmtList.pop_back();
    240 
    241       RSASTReplace R(mCtx);
    242       R.ReplaceStmt(mOuterStmt, S, CS);
    243     }
    244     clang::CompoundStmt *CS =
    245       llvm::dyn_cast<clang::CompoundStmt>(mOuterStmt);
    246     slangAssert(CS);
    247     AppendAfterStmt(mCtx, CS, nullptr, StmtList);
    248   }
    249 
    250   void VisitStmt(clang::Stmt *S);
    251   void VisitCompoundStmt(clang::CompoundStmt *CS);
    252 
    253   void VisitBreakStmt(clang::BreakStmt *BS);
    254   void VisitCaseStmt(clang::CaseStmt *CS);
    255   void VisitContinueStmt(clang::ContinueStmt *CS);
    256   void VisitDefaultStmt(clang::DefaultStmt *DS);
    257   void VisitDoStmt(clang::DoStmt *DS);
    258   void VisitForStmt(clang::ForStmt *FS);
    259   void VisitIfStmt(clang::IfStmt *IS);
    260   void VisitReturnStmt(clang::ReturnStmt *RS);
    261   void VisitSwitchCase(clang::SwitchCase *SC);
    262   void VisitSwitchStmt(clang::SwitchStmt *SS);
    263   void VisitWhileStmt(clang::WhileStmt *WS);
    264 };
    265 
    266 DestructorVisitor::DestructorVisitor(clang::ASTContext &C,
    267                          clang::Stmt *OuterStmt,
    268                          clang::Stmt *DtorStmt,
    269                          clang::SourceLocation VarLoc)
    270   : mCtx(C),
    271     mLoopDepth(0),
    272     mSwitchDepth(0),
    273     mOuterStmt(OuterStmt),
    274     mDtorStmt(DtorStmt),
    275     mVarLoc(VarLoc) {
    276 }
    277 
    278 void DestructorVisitor::VisitStmt(clang::Stmt *S) {
    279   for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
    280        I != E;
    281        I++) {
    282     if (clang::Stmt *Child = *I) {
    283       Visit(Child);
    284     }
    285   }
    286 }
    287 
    288 void DestructorVisitor::VisitCompoundStmt(clang::CompoundStmt *CS) {
    289   VisitStmt(CS);
    290 }
    291 
    292 void DestructorVisitor::VisitBreakStmt(clang::BreakStmt *BS) {
    293   VisitStmt(BS);
    294   if ((mLoopDepth == 0) && (mSwitchDepth == 0)) {
    295     mReplaceStmtStack.push(BS);
    296   }
    297 }
    298 
    299 void DestructorVisitor::VisitCaseStmt(clang::CaseStmt *CS) {
    300   VisitStmt(CS);
    301 }
    302 
    303 void DestructorVisitor::VisitContinueStmt(clang::ContinueStmt *CS) {
    304   VisitStmt(CS);
    305   if (mLoopDepth == 0) {
    306     // Switch statements can have nested continues.
    307     mReplaceStmtStack.push(CS);
    308   }
    309 }
    310 
    311 void DestructorVisitor::VisitDefaultStmt(clang::DefaultStmt *DS) {
    312   VisitStmt(DS);
    313 }
    314 
    315 void DestructorVisitor::VisitDoStmt(clang::DoStmt *DS) {
    316   mLoopDepth++;
    317   VisitStmt(DS);
    318   mLoopDepth--;
    319 }
    320 
    321 void DestructorVisitor::VisitForStmt(clang::ForStmt *FS) {
    322   mLoopDepth++;
    323   VisitStmt(FS);
    324   mLoopDepth--;
    325 }
    326 
    327 void DestructorVisitor::VisitIfStmt(clang::IfStmt *IS) {
    328   VisitStmt(IS);
    329 }
    330 
    331 void DestructorVisitor::VisitReturnStmt(clang::ReturnStmt *RS) {
    332   mReplaceStmtStack.push(RS);
    333 }
    334 
    335 void DestructorVisitor::VisitSwitchCase(clang::SwitchCase *SC) {
    336   slangAssert(false && "Both case and default have specialized handlers");
    337   VisitStmt(SC);
    338 }
    339 
    340 void DestructorVisitor::VisitSwitchStmt(clang::SwitchStmt *SS) {
    341   mSwitchDepth++;
    342   VisitStmt(SS);
    343   mSwitchDepth--;
    344 }
    345 
    346 void DestructorVisitor::VisitWhileStmt(clang::WhileStmt *WS) {
    347   mLoopDepth++;
    348   VisitStmt(WS);
    349   mLoopDepth--;
    350 }
    351 
    352 clang::Expr *ClearSingleRSObject(clang::ASTContext &C,
    353                                  clang::Expr *RefRSVar,
    354                                  clang::SourceLocation Loc) {
    355   slangAssert(RefRSVar);
    356   const clang::Type *T = RefRSVar->getType().getTypePtr();
    357   slangAssert(!T->isArrayType() &&
    358               "Should not be destroying arrays with this function");
    359 
    360   clang::FunctionDecl *ClearObjectFD = RSObjectRefCount::GetRSClearObjectFD(T);
    361   slangAssert((ClearObjectFD != nullptr) &&
    362               "rsClearObject doesn't cover all RS object types");
    363 
    364   clang::QualType ClearObjectFDType = ClearObjectFD->getType();
    365   clang::QualType ClearObjectFDArgType =
    366       ClearObjectFD->getParamDecl(0)->getOriginalType();
    367 
    368   // Example destructor for "rs_font localFont;"
    369   //
    370   // (CallExpr 'void'
    371   //   (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
    372   //     (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
    373   //   (UnaryOperator 'rs_font *' prefix '&'
    374   //     (DeclRefExpr 'rs_font':'rs_font' Var='localFont')))
    375 
    376   // Get address of targeted RS object
    377   clang::Expr *AddrRefRSVar =
    378       new(C) clang::UnaryOperator(RefRSVar,
    379                                   clang::UO_AddrOf,
    380                                   ClearObjectFDArgType,
    381                                   clang::VK_RValue,
    382                                   clang::OK_Ordinary,
    383                                   Loc);
    384 
    385   clang::Expr *RefRSClearObjectFD =
    386       clang::DeclRefExpr::Create(C,
    387                                  clang::NestedNameSpecifierLoc(),
    388                                  clang::SourceLocation(),
    389                                  ClearObjectFD,
    390                                  false,
    391                                  ClearObjectFD->getLocation(),
    392                                  ClearObjectFDType,
    393                                  clang::VK_RValue,
    394                                  nullptr);
    395 
    396   clang::Expr *RSClearObjectFP =
    397       clang::ImplicitCastExpr::Create(C,
    398                                       C.getPointerType(ClearObjectFDType),
    399                                       clang::CK_FunctionToPointerDecay,
    400                                       RefRSClearObjectFD,
    401                                       nullptr,
    402                                       clang::VK_RValue);
    403 
    404   llvm::SmallVector<clang::Expr*, 1> ArgList;
    405   ArgList.push_back(AddrRefRSVar);
    406 
    407   clang::CallExpr *RSClearObjectCall =
    408       new(C) clang::CallExpr(C,
    409                              RSClearObjectFP,
    410                              ArgList,
    411                              ClearObjectFD->getCallResultType(),
    412                              clang::VK_RValue,
    413                              Loc);
    414 
    415   return RSClearObjectCall;
    416 }
    417 
    418 static int ArrayDim(const clang::Type *T) {
    419   if (!T || !T->isArrayType()) {
    420     return 0;
    421   }
    422 
    423   const clang::ConstantArrayType *CAT =
    424     static_cast<const clang::ConstantArrayType *>(T);
    425   return static_cast<int>(CAT->getSize().getSExtValue());
    426 }
    427 
    428 static clang::Stmt *ClearStructRSObject(
    429     clang::ASTContext &C,
    430     clang::DeclContext *DC,
    431     clang::Expr *RefRSStruct,
    432     clang::SourceLocation StartLoc,
    433     clang::SourceLocation Loc);
    434 
    435 static clang::Stmt *ClearArrayRSObject(
    436     clang::ASTContext &C,
    437     clang::DeclContext *DC,
    438     clang::Expr *RefRSArr,
    439     clang::SourceLocation StartLoc,
    440     clang::SourceLocation Loc) {
    441   const clang::Type *BaseType = RefRSArr->getType().getTypePtr();
    442   slangAssert(BaseType->isArrayType());
    443 
    444   int NumArrayElements = ArrayDim(BaseType);
    445   // Actually extract out the base RS object type for use later
    446   BaseType = BaseType->getArrayElementTypeNoTypeQual();
    447 
    448   clang::Stmt *StmtArray[2] = {nullptr};
    449   int StmtCtr = 0;
    450 
    451   if (NumArrayElements <= 0) {
    452     return nullptr;
    453   }
    454 
    455   // Example destructor loop for "rs_font fontArr[10];"
    456   //
    457   // (CompoundStmt
    458   //   (DeclStmt "int rsIntIter")
    459   //   (ForStmt
    460   //     (BinaryOperator 'int' '='
    461   //       (DeclRefExpr 'int' Var='rsIntIter')
    462   //       (IntegerLiteral 'int' 0))
    463   //     (BinaryOperator 'int' '<'
    464   //       (DeclRefExpr 'int' Var='rsIntIter')
    465   //       (IntegerLiteral 'int' 10)
    466   //     nullptr << CondVar >>
    467   //     (UnaryOperator 'int' postfix '++'
    468   //       (DeclRefExpr 'int' Var='rsIntIter'))
    469   //     (CallExpr 'void'
    470   //       (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
    471   //         (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
    472   //       (UnaryOperator 'rs_font *' prefix '&'
    473   //         (ArraySubscriptExpr 'rs_font':'rs_font'
    474   //           (ImplicitCastExpr 'rs_font *' <ArrayToPointerDecay>
    475   //             (DeclRefExpr 'rs_font [10]' Var='fontArr'))
    476   //           (DeclRefExpr 'int' Var='rsIntIter')))))))
    477 
    478   // Create helper variable for iterating through elements
    479   clang::IdentifierInfo& II = C.Idents.get("rsIntIter");
    480   clang::VarDecl *IIVD =
    481       clang::VarDecl::Create(C,
    482                              DC,
    483                              StartLoc,
    484                              Loc,
    485                              &II,
    486                              C.IntTy,
    487                              C.getTrivialTypeSourceInfo(C.IntTy),
    488                              clang::SC_None);
    489   // Mark "rsIntIter" as used
    490   IIVD->markUsed(C);
    491   clang::Decl *IID = (clang::Decl *)IIVD;
    492 
    493   clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(C, &IID, 1);
    494   StmtArray[StmtCtr++] = new(C) clang::DeclStmt(DGR, Loc, Loc);
    495 
    496   // Form the actual destructor loop
    497   // for (Init; Cond; Inc)
    498   //   RSClearObjectCall;
    499 
    500   // Init -> "rsIntIter = 0"
    501   clang::DeclRefExpr *RefrsIntIter =
    502       clang::DeclRefExpr::Create(C,
    503                                  clang::NestedNameSpecifierLoc(),
    504                                  clang::SourceLocation(),
    505                                  IIVD,
    506                                  false,
    507                                  Loc,
    508                                  C.IntTy,
    509                                  clang::VK_RValue,
    510                                  nullptr);
    511 
    512   clang::Expr *Int0 = clang::IntegerLiteral::Create(C,
    513       llvm::APInt(C.getTypeSize(C.IntTy), 0), C.IntTy, Loc);
    514 
    515   clang::BinaryOperator *Init =
    516       new(C) clang::BinaryOperator(RefrsIntIter,
    517                                    Int0,
    518                                    clang::BO_Assign,
    519                                    C.IntTy,
    520                                    clang::VK_RValue,
    521                                    clang::OK_Ordinary,
    522                                    Loc,
    523                                    false);
    524 
    525   // Cond -> "rsIntIter < NumArrayElements"
    526   clang::Expr *NumArrayElementsExpr = clang::IntegerLiteral::Create(C,
    527       llvm::APInt(C.getTypeSize(C.IntTy), NumArrayElements), C.IntTy, Loc);
    528 
    529   clang::BinaryOperator *Cond =
    530       new(C) clang::BinaryOperator(RefrsIntIter,
    531                                    NumArrayElementsExpr,
    532                                    clang::BO_LT,
    533                                    C.IntTy,
    534                                    clang::VK_RValue,
    535                                    clang::OK_Ordinary,
    536                                    Loc,
    537                                    false);
    538 
    539   // Inc -> "rsIntIter++"
    540   clang::UnaryOperator *Inc =
    541       new(C) clang::UnaryOperator(RefrsIntIter,
    542                                   clang::UO_PostInc,
    543                                   C.IntTy,
    544                                   clang::VK_RValue,
    545                                   clang::OK_Ordinary,
    546                                   Loc);
    547 
    548   // Body -> "rsClearObject(&VD[rsIntIter]);"
    549   // Destructor loop operates on individual array elements
    550 
    551   clang::Expr *RefRSArrPtr =
    552       clang::ImplicitCastExpr::Create(C,
    553           C.getPointerType(BaseType->getCanonicalTypeInternal()),
    554           clang::CK_ArrayToPointerDecay,
    555           RefRSArr,
    556           nullptr,
    557           clang::VK_RValue);
    558 
    559   clang::Expr *RefRSArrPtrSubscript =
    560       new(C) clang::ArraySubscriptExpr(RefRSArrPtr,
    561                                        RefrsIntIter,
    562                                        BaseType->getCanonicalTypeInternal(),
    563                                        clang::VK_RValue,
    564                                        clang::OK_Ordinary,
    565                                        Loc);
    566 
    567   DataType DT = RSExportPrimitiveType::GetRSSpecificType(BaseType);
    568 
    569   clang::Stmt *RSClearObjectCall = nullptr;
    570   if (BaseType->isArrayType()) {
    571     RSClearObjectCall =
    572         ClearArrayRSObject(C, DC, RefRSArrPtrSubscript, StartLoc, Loc);
    573   } else if (DT == DataTypeUnknown) {
    574     RSClearObjectCall =
    575         ClearStructRSObject(C, DC, RefRSArrPtrSubscript, StartLoc, Loc);
    576   } else {
    577     RSClearObjectCall = ClearSingleRSObject(C, RefRSArrPtrSubscript, Loc);
    578   }
    579 
    580   clang::ForStmt *DestructorLoop =
    581       new(C) clang::ForStmt(C,
    582                             Init,
    583                             Cond,
    584                             nullptr,  // no condVar
    585                             Inc,
    586                             RSClearObjectCall,
    587                             Loc,
    588                             Loc,
    589                             Loc);
    590 
    591   StmtArray[StmtCtr++] = DestructorLoop;
    592   slangAssert(StmtCtr == 2);
    593 
    594   clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
    595       C, llvm::makeArrayRef(StmtArray, StmtCtr), Loc, Loc);
    596 
    597   return CS;
    598 }
    599 
    600 static unsigned CountRSObjectTypes(clang::ASTContext &C,
    601                                    const clang::Type *T,
    602                                    clang::SourceLocation Loc) {
    603   slangAssert(T);
    604   unsigned RSObjectCount = 0;
    605 
    606   if (T->isArrayType()) {
    607     return CountRSObjectTypes(C, T->getArrayElementTypeNoTypeQual(), Loc);
    608   }
    609 
    610   DataType DT = RSExportPrimitiveType::GetRSSpecificType(T);
    611   if (DT != DataTypeUnknown) {
    612     return (RSExportPrimitiveType::IsRSObjectType(DT) ? 1 : 0);
    613   }
    614 
    615   if (T->isUnionType()) {
    616     clang::RecordDecl *RD = T->getAsUnionType()->getDecl();
    617     RD = RD->getDefinition();
    618     for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
    619            FE = RD->field_end();
    620          FI != FE;
    621          FI++) {
    622       const clang::FieldDecl *FD = *FI;
    623       const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
    624       if (CountRSObjectTypes(C, FT, Loc)) {
    625         slangAssert(false && "can't have unions with RS object types!");
    626         return 0;
    627       }
    628     }
    629   }
    630 
    631   if (!T->isStructureType()) {
    632     return 0;
    633   }
    634 
    635   clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
    636   RD = RD->getDefinition();
    637   for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
    638          FE = RD->field_end();
    639        FI != FE;
    640        FI++) {
    641     const clang::FieldDecl *FD = *FI;
    642     const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
    643     if (CountRSObjectTypes(C, FT, Loc)) {
    644       // Sub-structs should only count once (as should arrays, etc.)
    645       RSObjectCount++;
    646     }
    647   }
    648 
    649   return RSObjectCount;
    650 }
    651 
    652 static clang::Stmt *ClearStructRSObject(
    653     clang::ASTContext &C,
    654     clang::DeclContext *DC,
    655     clang::Expr *RefRSStruct,
    656     clang::SourceLocation StartLoc,
    657     clang::SourceLocation Loc) {
    658   const clang::Type *BaseType = RefRSStruct->getType().getTypePtr();
    659 
    660   slangAssert(!BaseType->isArrayType());
    661 
    662   // Structs should show up as unknown primitive types
    663   slangAssert(RSExportPrimitiveType::GetRSSpecificType(BaseType) ==
    664               DataTypeUnknown);
    665 
    666   unsigned FieldsToDestroy = CountRSObjectTypes(C, BaseType, Loc);
    667   slangAssert(FieldsToDestroy != 0);
    668 
    669   unsigned StmtCount = 0;
    670   clang::Stmt **StmtArray = new clang::Stmt*[FieldsToDestroy];
    671   for (unsigned i = 0; i < FieldsToDestroy; i++) {
    672     StmtArray[i] = nullptr;
    673   }
    674 
    675   // Populate StmtArray by creating a destructor for each RS object field
    676   clang::RecordDecl *RD = BaseType->getAsStructureType()->getDecl();
    677   RD = RD->getDefinition();
    678   for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
    679          FE = RD->field_end();
    680        FI != FE;
    681        FI++) {
    682     // We just look through all field declarations to see if we find a
    683     // declaration for an RS object type (or an array of one).
    684     bool IsArrayType = false;
    685     clang::FieldDecl *FD = *FI;
    686     const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
    687     const clang::Type *OrigType = FT;
    688     while (FT && FT->isArrayType()) {
    689       FT = FT->getArrayElementTypeNoTypeQual();
    690       IsArrayType = true;
    691     }
    692 
    693     // Pass a DeclarationNameInfo with a valid DeclName, since name equality
    694     // gets asserted during CodeGen.
    695     clang::DeclarationNameInfo FDDeclNameInfo(FD->getDeclName(),
    696                                               FD->getLocation());
    697 
    698     if (RSExportPrimitiveType::IsRSObjectType(FT)) {
    699       clang::DeclAccessPair FoundDecl =
    700           clang::DeclAccessPair::make(FD, clang::AS_none);
    701       clang::MemberExpr *RSObjectMember =
    702           clang::MemberExpr::Create(C,
    703                                     RefRSStruct,
    704                                     false,
    705                                     clang::SourceLocation(),
    706                                     clang::NestedNameSpecifierLoc(),
    707                                     clang::SourceLocation(),
    708                                     FD,
    709                                     FoundDecl,
    710                                     FDDeclNameInfo,
    711                                     nullptr,
    712                                     OrigType->getCanonicalTypeInternal(),
    713                                     clang::VK_RValue,
    714                                     clang::OK_Ordinary);
    715 
    716       slangAssert(StmtCount < FieldsToDestroy);
    717 
    718       if (IsArrayType) {
    719         StmtArray[StmtCount++] = ClearArrayRSObject(C,
    720                                                     DC,
    721                                                     RSObjectMember,
    722                                                     StartLoc,
    723                                                     Loc);
    724       } else {
    725         StmtArray[StmtCount++] = ClearSingleRSObject(C,
    726                                                      RSObjectMember,
    727                                                      Loc);
    728       }
    729     } else if (FT->isStructureType() && CountRSObjectTypes(C, FT, Loc)) {
    730       // In this case, we have a nested struct. We may not end up filling all
    731       // of the spaces in StmtArray (sub-structs should handle themselves
    732       // with separate compound statements).
    733       clang::DeclAccessPair FoundDecl =
    734           clang::DeclAccessPair::make(FD, clang::AS_none);
    735       clang::MemberExpr *RSObjectMember =
    736           clang::MemberExpr::Create(C,
    737                                     RefRSStruct,
    738                                     false,
    739                                     clang::SourceLocation(),
    740                                     clang::NestedNameSpecifierLoc(),
    741                                     clang::SourceLocation(),
    742                                     FD,
    743                                     FoundDecl,
    744                                     clang::DeclarationNameInfo(),
    745                                     nullptr,
    746                                     OrigType->getCanonicalTypeInternal(),
    747                                     clang::VK_RValue,
    748                                     clang::OK_Ordinary);
    749 
    750       if (IsArrayType) {
    751         StmtArray[StmtCount++] = ClearArrayRSObject(C,
    752                                                     DC,
    753                                                     RSObjectMember,
    754                                                     StartLoc,
    755                                                     Loc);
    756       } else {
    757         StmtArray[StmtCount++] = ClearStructRSObject(C,
    758                                                      DC,
    759                                                      RSObjectMember,
    760                                                      StartLoc,
    761                                                      Loc);
    762       }
    763     }
    764   }
    765 
    766   slangAssert(StmtCount > 0);
    767   clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
    768       C, llvm::makeArrayRef(StmtArray, StmtCount), Loc, Loc);
    769 
    770   delete [] StmtArray;
    771 
    772   return CS;
    773 }
    774 
    775 static clang::Stmt *CreateSingleRSSetObject(clang::ASTContext &C,
    776                                             clang::Expr *DstExpr,
    777                                             clang::Expr *SrcExpr,
    778                                             clang::SourceLocation StartLoc,
    779                                             clang::SourceLocation Loc) {
    780   const clang::Type *T = DstExpr->getType().getTypePtr();
    781   clang::FunctionDecl *SetObjectFD = RSObjectRefCount::GetRSSetObjectFD(T);
    782   slangAssert((SetObjectFD != nullptr) &&
    783               "rsSetObject doesn't cover all RS object types");
    784 
    785   clang::QualType SetObjectFDType = SetObjectFD->getType();
    786   clang::QualType SetObjectFDArgType[2];
    787   SetObjectFDArgType[0] = SetObjectFD->getParamDecl(0)->getOriginalType();
    788   SetObjectFDArgType[1] = SetObjectFD->getParamDecl(1)->getOriginalType();
    789 
    790   clang::Expr *RefRSSetObjectFD =
    791       clang::DeclRefExpr::Create(C,
    792                                  clang::NestedNameSpecifierLoc(),
    793                                  clang::SourceLocation(),
    794                                  SetObjectFD,
    795                                  false,
    796                                  Loc,
    797                                  SetObjectFDType,
    798                                  clang::VK_RValue,
    799                                  nullptr);
    800 
    801   clang::Expr *RSSetObjectFP =
    802       clang::ImplicitCastExpr::Create(C,
    803                                       C.getPointerType(SetObjectFDType),
    804                                       clang::CK_FunctionToPointerDecay,
    805                                       RefRSSetObjectFD,
    806                                       nullptr,
    807                                       clang::VK_RValue);
    808 
    809   llvm::SmallVector<clang::Expr*, 2> ArgList;
    810   ArgList.push_back(new(C) clang::UnaryOperator(DstExpr,
    811                                                 clang::UO_AddrOf,
    812                                                 SetObjectFDArgType[0],
    813                                                 clang::VK_RValue,
    814                                                 clang::OK_Ordinary,
    815                                                 Loc));
    816   ArgList.push_back(SrcExpr);
    817 
    818   clang::CallExpr *RSSetObjectCall =
    819       new(C) clang::CallExpr(C,
    820                              RSSetObjectFP,
    821                              ArgList,
    822                              SetObjectFD->getCallResultType(),
    823                              clang::VK_RValue,
    824                              Loc);
    825 
    826   return RSSetObjectCall;
    827 }
    828 
    829 static clang::Stmt *CreateStructRSSetObject(clang::ASTContext &C,
    830                                             clang::Expr *LHS,
    831                                             clang::Expr *RHS,
    832                                             clang::SourceLocation StartLoc,
    833                                             clang::SourceLocation Loc);
    834 
    835 /*static clang::Stmt *CreateArrayRSSetObject(clang::ASTContext &C,
    836                                            clang::Expr *DstArr,
    837                                            clang::Expr *SrcArr,
    838                                            clang::SourceLocation StartLoc,
    839                                            clang::SourceLocation Loc) {
    840   clang::DeclContext *DC = nullptr;
    841   const clang::Type *BaseType = DstArr->getType().getTypePtr();
    842   slangAssert(BaseType->isArrayType());
    843 
    844   int NumArrayElements = ArrayDim(BaseType);
    845   // Actually extract out the base RS object type for use later
    846   BaseType = BaseType->getArrayElementTypeNoTypeQual();
    847 
    848   clang::Stmt *StmtArray[2] = {nullptr};
    849   int StmtCtr = 0;
    850 
    851   if (NumArrayElements <= 0) {
    852     return nullptr;
    853   }
    854 
    855   // Create helper variable for iterating through elements
    856   clang::IdentifierInfo& II = C.Idents.get("rsIntIter");
    857   clang::VarDecl *IIVD =
    858       clang::VarDecl::Create(C,
    859                              DC,
    860                              StartLoc,
    861                              Loc,
    862                              &II,
    863                              C.IntTy,
    864                              C.getTrivialTypeSourceInfo(C.IntTy),
    865                              clang::SC_None,
    866                              clang::SC_None);
    867   clang::Decl *IID = (clang::Decl *)IIVD;
    868 
    869   clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(C, &IID, 1);
    870   StmtArray[StmtCtr++] = new(C) clang::DeclStmt(DGR, Loc, Loc);
    871 
    872   // Form the actual loop
    873   // for (Init; Cond; Inc)
    874   //   RSSetObjectCall;
    875 
    876   // Init -> "rsIntIter = 0"
    877   clang::DeclRefExpr *RefrsIntIter =
    878       clang::DeclRefExpr::Create(C,
    879                                  clang::NestedNameSpecifierLoc(),
    880                                  IIVD,
    881                                  Loc,
    882                                  C.IntTy,
    883                                  clang::VK_RValue,
    884                                  nullptr);
    885 
    886   clang::Expr *Int0 = clang::IntegerLiteral::Create(C,
    887       llvm::APInt(C.getTypeSize(C.IntTy), 0), C.IntTy, Loc);
    888 
    889   clang::BinaryOperator *Init =
    890       new(C) clang::BinaryOperator(RefrsIntIter,
    891                                    Int0,
    892                                    clang::BO_Assign,
    893                                    C.IntTy,
    894                                    clang::VK_RValue,
    895                                    clang::OK_Ordinary,
    896                                    Loc);
    897 
    898   // Cond -> "rsIntIter < NumArrayElements"
    899   clang::Expr *NumArrayElementsExpr = clang::IntegerLiteral::Create(C,
    900       llvm::APInt(C.getTypeSize(C.IntTy), NumArrayElements), C.IntTy, Loc);
    901 
    902   clang::BinaryOperator *Cond =
    903       new(C) clang::BinaryOperator(RefrsIntIter,
    904                                    NumArrayElementsExpr,
    905                                    clang::BO_LT,
    906                                    C.IntTy,
    907                                    clang::VK_RValue,
    908                                    clang::OK_Ordinary,
    909                                    Loc);
    910 
    911   // Inc -> "rsIntIter++"
    912   clang::UnaryOperator *Inc =
    913       new(C) clang::UnaryOperator(RefrsIntIter,
    914                                   clang::UO_PostInc,
    915                                   C.IntTy,
    916                                   clang::VK_RValue,
    917                                   clang::OK_Ordinary,
    918                                   Loc);
    919 
    920   // Body -> "rsSetObject(&Dst[rsIntIter], Src[rsIntIter]);"
    921   // Loop operates on individual array elements
    922 
    923   clang::Expr *DstArrPtr =
    924       clang::ImplicitCastExpr::Create(C,
    925           C.getPointerType(BaseType->getCanonicalTypeInternal()),
    926           clang::CK_ArrayToPointerDecay,
    927           DstArr,
    928           nullptr,
    929           clang::VK_RValue);
    930 
    931   clang::Expr *DstArrPtrSubscript =
    932       new(C) clang::ArraySubscriptExpr(DstArrPtr,
    933                                        RefrsIntIter,
    934                                        BaseType->getCanonicalTypeInternal(),
    935                                        clang::VK_RValue,
    936                                        clang::OK_Ordinary,
    937                                        Loc);
    938 
    939   clang::Expr *SrcArrPtr =
    940       clang::ImplicitCastExpr::Create(C,
    941           C.getPointerType(BaseType->getCanonicalTypeInternal()),
    942           clang::CK_ArrayToPointerDecay,
    943           SrcArr,
    944           nullptr,
    945           clang::VK_RValue);
    946 
    947   clang::Expr *SrcArrPtrSubscript =
    948       new(C) clang::ArraySubscriptExpr(SrcArrPtr,
    949                                        RefrsIntIter,
    950                                        BaseType->getCanonicalTypeInternal(),
    951                                        clang::VK_RValue,
    952                                        clang::OK_Ordinary,
    953                                        Loc);
    954 
    955   DataType DT = RSExportPrimitiveType::GetRSSpecificType(BaseType);
    956 
    957   clang::Stmt *RSSetObjectCall = nullptr;
    958   if (BaseType->isArrayType()) {
    959     RSSetObjectCall = CreateArrayRSSetObject(C, DstArrPtrSubscript,
    960                                              SrcArrPtrSubscript,
    961                                              StartLoc, Loc);
    962   } else if (DT == DataTypeUnknown) {
    963     RSSetObjectCall = CreateStructRSSetObject(C, DstArrPtrSubscript,
    964                                               SrcArrPtrSubscript,
    965                                               StartLoc, Loc);
    966   } else {
    967     RSSetObjectCall = CreateSingleRSSetObject(C, DstArrPtrSubscript,
    968                                               SrcArrPtrSubscript,
    969                                               StartLoc, Loc);
    970   }
    971 
    972   clang::ForStmt *DestructorLoop =
    973       new(C) clang::ForStmt(C,
    974                             Init,
    975                             Cond,
    976                             nullptr,  // no condVar
    977                             Inc,
    978                             RSSetObjectCall,
    979                             Loc,
    980                             Loc,
    981                             Loc);
    982 
    983   StmtArray[StmtCtr++] = DestructorLoop;
    984   slangAssert(StmtCtr == 2);
    985 
    986   clang::CompoundStmt *CS =
    987       new(C) clang::CompoundStmt(C, StmtArray, StmtCtr, Loc, Loc);
    988 
    989   return CS;
    990 } */
    991 
    992 static clang::Stmt *CreateStructRSSetObject(clang::ASTContext &C,
    993                                             clang::Expr *LHS,
    994                                             clang::Expr *RHS,
    995                                             clang::SourceLocation StartLoc,
    996                                             clang::SourceLocation Loc) {
    997   clang::QualType QT = LHS->getType();
    998   const clang::Type *T = QT.getTypePtr();
    999   slangAssert(T->isStructureType());
   1000   slangAssert(!RSExportPrimitiveType::IsRSObjectType(T));
   1001 
   1002   // Keep an extra slot for the original copy (memcpy)
   1003   unsigned FieldsToSet = CountRSObjectTypes(C, T, Loc) + 1;
   1004 
   1005   unsigned StmtCount = 0;
   1006   clang::Stmt **StmtArray = new clang::Stmt*[FieldsToSet];
   1007   for (unsigned i = 0; i < FieldsToSet; i++) {
   1008     StmtArray[i] = nullptr;
   1009   }
   1010 
   1011   clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
   1012   RD = RD->getDefinition();
   1013   for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
   1014          FE = RD->field_end();
   1015        FI != FE;
   1016        FI++) {
   1017     bool IsArrayType = false;
   1018     clang::FieldDecl *FD = *FI;
   1019     const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
   1020     const clang::Type *OrigType = FT;
   1021 
   1022     if (!CountRSObjectTypes(C, FT, Loc)) {
   1023       // Skip to next if we don't have any viable RS object types
   1024       continue;
   1025     }
   1026 
   1027     clang::DeclAccessPair FoundDecl =
   1028         clang::DeclAccessPair::make(FD, clang::AS_none);
   1029     clang::MemberExpr *DstMember =
   1030         clang::MemberExpr::Create(C,
   1031                                   LHS,
   1032                                   false,
   1033                                   clang::SourceLocation(),
   1034                                   clang::NestedNameSpecifierLoc(),
   1035                                   clang::SourceLocation(),
   1036                                   FD,
   1037                                   FoundDecl,
   1038                                   clang::DeclarationNameInfo(),
   1039                                   nullptr,
   1040                                   OrigType->getCanonicalTypeInternal(),
   1041                                   clang::VK_RValue,
   1042                                   clang::OK_Ordinary);
   1043 
   1044     clang::MemberExpr *SrcMember =
   1045         clang::MemberExpr::Create(C,
   1046                                   RHS,
   1047                                   false,
   1048                                   clang::SourceLocation(),
   1049                                   clang::NestedNameSpecifierLoc(),
   1050                                   clang::SourceLocation(),
   1051                                   FD,
   1052                                   FoundDecl,
   1053                                   clang::DeclarationNameInfo(),
   1054                                   nullptr,
   1055                                   OrigType->getCanonicalTypeInternal(),
   1056                                   clang::VK_RValue,
   1057                                   clang::OK_Ordinary);
   1058 
   1059     if (FT->isArrayType()) {
   1060       FT = FT->getArrayElementTypeNoTypeQual();
   1061       IsArrayType = true;
   1062     }
   1063 
   1064     DataType DT = RSExportPrimitiveType::GetRSSpecificType(FT);
   1065 
   1066     if (IsArrayType) {
   1067       clang::DiagnosticsEngine &DiagEngine = C.getDiagnostics();
   1068       DiagEngine.Report(
   1069         clang::FullSourceLoc(Loc, C.getSourceManager()),
   1070         DiagEngine.getCustomDiagID(
   1071           clang::DiagnosticsEngine::Error,
   1072           "Arrays of RS object types within structures cannot be copied"));
   1073       // TODO(srhines): Support setting arrays of RS objects
   1074       // StmtArray[StmtCount++] =
   1075       //    CreateArrayRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
   1076     } else if (DT == DataTypeUnknown) {
   1077       StmtArray[StmtCount++] =
   1078           CreateStructRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
   1079     } else if (RSExportPrimitiveType::IsRSObjectType(DT)) {
   1080       StmtArray[StmtCount++] =
   1081           CreateSingleRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
   1082     } else {
   1083       slangAssert(false);
   1084     }
   1085   }
   1086 
   1087   slangAssert(StmtCount < FieldsToSet);
   1088 
   1089   // We still need to actually do the overall struct copy. For simplicity,
   1090   // we just do a straight-up assignment (which will still preserve all
   1091   // the proper RS object reference counts).
   1092   clang::BinaryOperator *CopyStruct =
   1093       new(C) clang::BinaryOperator(LHS, RHS, clang::BO_Assign, QT,
   1094                                    clang::VK_RValue, clang::OK_Ordinary, Loc,
   1095                                    false);
   1096   StmtArray[StmtCount++] = CopyStruct;
   1097 
   1098   clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
   1099       C, llvm::makeArrayRef(StmtArray, StmtCount), Loc, Loc);
   1100 
   1101   delete [] StmtArray;
   1102 
   1103   return CS;
   1104 }
   1105 
   1106 }  // namespace
   1107 
   1108 void RSObjectRefCount::Scope::ReplaceRSObjectAssignment(
   1109     clang::BinaryOperator *AS) {
   1110 
   1111   clang::QualType QT = AS->getType();
   1112 
   1113   clang::ASTContext &C = RSObjectRefCount::GetRSSetObjectFD(
   1114       DataTypeRSAllocation)->getASTContext();
   1115 
   1116   clang::SourceLocation Loc = AS->getExprLoc();
   1117   clang::SourceLocation StartLoc = AS->getLHS()->getExprLoc();
   1118   clang::Stmt *UpdatedStmt = nullptr;
   1119 
   1120   if (!RSExportPrimitiveType::IsRSObjectType(QT.getTypePtr())) {
   1121     // By definition, this is a struct assignment if we get here
   1122     UpdatedStmt =
   1123         CreateStructRSSetObject(C, AS->getLHS(), AS->getRHS(), StartLoc, Loc);
   1124   } else {
   1125     UpdatedStmt =
   1126         CreateSingleRSSetObject(C, AS->getLHS(), AS->getRHS(), StartLoc, Loc);
   1127   }
   1128 
   1129   RSASTReplace R(C);
   1130   R.ReplaceStmt(mCS, AS, UpdatedStmt);
   1131 }
   1132 
   1133 void RSObjectRefCount::Scope::AppendRSObjectInit(
   1134     clang::VarDecl *VD,
   1135     clang::DeclStmt *DS,
   1136     DataType DT,
   1137     clang::Expr *InitExpr) {
   1138   slangAssert(VD);
   1139 
   1140   if (!InitExpr) {
   1141     return;
   1142   }
   1143 
   1144   clang::ASTContext &C = RSObjectRefCount::GetRSSetObjectFD(
   1145       DataTypeRSAllocation)->getASTContext();
   1146   clang::SourceLocation Loc = RSObjectRefCount::GetRSSetObjectFD(
   1147       DataTypeRSAllocation)->getLocation();
   1148   clang::SourceLocation StartLoc = RSObjectRefCount::GetRSSetObjectFD(
   1149       DataTypeRSAllocation)->getInnerLocStart();
   1150 
   1151   if (DT == DataTypeIsStruct) {
   1152     const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
   1153     clang::DeclRefExpr *RefRSVar =
   1154         clang::DeclRefExpr::Create(C,
   1155                                    clang::NestedNameSpecifierLoc(),
   1156                                    clang::SourceLocation(),
   1157                                    VD,
   1158                                    false,
   1159                                    Loc,
   1160                                    T->getCanonicalTypeInternal(),
   1161                                    clang::VK_RValue,
   1162                                    nullptr);
   1163 
   1164     clang::Stmt *RSSetObjectOps =
   1165         CreateStructRSSetObject(C, RefRSVar, InitExpr, StartLoc, Loc);
   1166 
   1167     std::list<clang::Stmt*> StmtList;
   1168     StmtList.push_back(RSSetObjectOps);
   1169     AppendAfterStmt(C, mCS, DS, StmtList);
   1170     return;
   1171   }
   1172 
   1173   clang::FunctionDecl *SetObjectFD = RSObjectRefCount::GetRSSetObjectFD(DT);
   1174   slangAssert((SetObjectFD != nullptr) &&
   1175               "rsSetObject doesn't cover all RS object types");
   1176 
   1177   clang::QualType SetObjectFDType = SetObjectFD->getType();
   1178   clang::QualType SetObjectFDArgType[2];
   1179   SetObjectFDArgType[0] = SetObjectFD->getParamDecl(0)->getOriginalType();
   1180   SetObjectFDArgType[1] = SetObjectFD->getParamDecl(1)->getOriginalType();
   1181 
   1182   clang::Expr *RefRSSetObjectFD =
   1183       clang::DeclRefExpr::Create(C,
   1184                                  clang::NestedNameSpecifierLoc(),
   1185                                  clang::SourceLocation(),
   1186                                  SetObjectFD,
   1187                                  false,
   1188                                  Loc,
   1189                                  SetObjectFDType,
   1190                                  clang::VK_RValue,
   1191                                  nullptr);
   1192 
   1193   clang::Expr *RSSetObjectFP =
   1194       clang::ImplicitCastExpr::Create(C,
   1195                                       C.getPointerType(SetObjectFDType),
   1196                                       clang::CK_FunctionToPointerDecay,
   1197                                       RefRSSetObjectFD,
   1198                                       nullptr,
   1199                                       clang::VK_RValue);
   1200 
   1201   const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
   1202   clang::DeclRefExpr *RefRSVar =
   1203       clang::DeclRefExpr::Create(C,
   1204                                  clang::NestedNameSpecifierLoc(),
   1205                                  clang::SourceLocation(),
   1206                                  VD,
   1207                                  false,
   1208                                  Loc,
   1209                                  T->getCanonicalTypeInternal(),
   1210                                  clang::VK_RValue,
   1211                                  nullptr);
   1212 
   1213   llvm::SmallVector<clang::Expr*, 2> ArgList;
   1214   ArgList.push_back(new(C) clang::UnaryOperator(RefRSVar,
   1215                                                 clang::UO_AddrOf,
   1216                                                 SetObjectFDArgType[0],
   1217                                                 clang::VK_RValue,
   1218                                                 clang::OK_Ordinary,
   1219                                                 Loc));
   1220   ArgList.push_back(InitExpr);
   1221 
   1222   clang::CallExpr *RSSetObjectCall =
   1223       new(C) clang::CallExpr(C,
   1224                              RSSetObjectFP,
   1225                              ArgList,
   1226                              SetObjectFD->getCallResultType(),
   1227                              clang::VK_RValue,
   1228                              Loc);
   1229 
   1230   std::list<clang::Stmt*> StmtList;
   1231   StmtList.push_back(RSSetObjectCall);
   1232   AppendAfterStmt(C, mCS, DS, StmtList);
   1233 }
   1234 
   1235 void RSObjectRefCount::Scope::InsertLocalVarDestructors() {
   1236   for (std::list<clang::VarDecl*>::const_iterator I = mRSO.begin(),
   1237           E = mRSO.end();
   1238         I != E;
   1239         I++) {
   1240     clang::VarDecl *VD = *I;
   1241     clang::Stmt *RSClearObjectCall = ClearRSObject(VD, VD->getDeclContext());
   1242     if (RSClearObjectCall) {
   1243       clang::ASTContext &C = (*mRSO.begin())->getASTContext();
   1244       // Mark VD as used.  It might be unused, except for the destructor.
   1245       // 'markUsed' has side-effects that are caused only if VD is not already
   1246       // used.  Hence no need for an extra check here.
   1247       VD->markUsed(C);
   1248       DestructorVisitor DV(C,
   1249                            mCS,
   1250                            RSClearObjectCall,
   1251                            VD->getSourceRange().getBegin());
   1252       DV.Visit(mCS);
   1253       DV.InsertDestructors();
   1254     }
   1255   }
   1256 }
   1257 
   1258 clang::Stmt *RSObjectRefCount::Scope::ClearRSObject(
   1259     clang::VarDecl *VD,
   1260     clang::DeclContext *DC) {
   1261   slangAssert(VD);
   1262   clang::ASTContext &C = VD->getASTContext();
   1263   clang::SourceLocation Loc = VD->getLocation();
   1264   clang::SourceLocation StartLoc = VD->getInnerLocStart();
   1265   const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
   1266 
   1267   // Reference expr to target RS object variable
   1268   clang::DeclRefExpr *RefRSVar =
   1269       clang::DeclRefExpr::Create(C,
   1270                                  clang::NestedNameSpecifierLoc(),
   1271                                  clang::SourceLocation(),
   1272                                  VD,
   1273                                  false,
   1274                                  Loc,
   1275                                  T->getCanonicalTypeInternal(),
   1276                                  clang::VK_RValue,
   1277                                  nullptr);
   1278 
   1279   if (T->isArrayType()) {
   1280     return ClearArrayRSObject(C, DC, RefRSVar, StartLoc, Loc);
   1281   }
   1282 
   1283   DataType DT = RSExportPrimitiveType::GetRSSpecificType(T);
   1284 
   1285   if (DT == DataTypeUnknown ||
   1286       DT == DataTypeIsStruct) {
   1287     return ClearStructRSObject(C, DC, RefRSVar, StartLoc, Loc);
   1288   }
   1289 
   1290   slangAssert((RSExportPrimitiveType::IsRSObjectType(DT)) &&
   1291               "Should be RS object");
   1292 
   1293   return ClearSingleRSObject(C, RefRSVar, Loc);
   1294 }
   1295 
   1296 bool RSObjectRefCount::InitializeRSObject(clang::VarDecl *VD,
   1297                                           DataType *DT,
   1298                                           clang::Expr **InitExpr) {
   1299   slangAssert(VD && DT && InitExpr);
   1300   const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
   1301 
   1302   // Loop through array types to get to base type
   1303   while (T && T->isArrayType()) {
   1304     T = T->getArrayElementTypeNoTypeQual();
   1305   }
   1306 
   1307   bool DataTypeIsStructWithRSObject = false;
   1308   *DT = RSExportPrimitiveType::GetRSSpecificType(T);
   1309 
   1310   if (*DT == DataTypeUnknown) {
   1311     if (RSExportPrimitiveType::IsStructureTypeWithRSObject(T)) {
   1312       *DT = DataTypeIsStruct;
   1313       DataTypeIsStructWithRSObject = true;
   1314     } else {
   1315       return false;
   1316     }
   1317   }
   1318 
   1319   bool DataTypeIsRSObject = false;
   1320   if (DataTypeIsStructWithRSObject) {
   1321     DataTypeIsRSObject = true;
   1322   } else {
   1323     DataTypeIsRSObject = RSExportPrimitiveType::IsRSObjectType(*DT);
   1324   }
   1325   *InitExpr = VD->getInit();
   1326 
   1327   if (!DataTypeIsRSObject && *InitExpr) {
   1328     // If we already have an initializer for a matrix type, we are done.
   1329     return DataTypeIsRSObject;
   1330   }
   1331 
   1332   clang::Expr *ZeroInitializer =
   1333       CreateZeroInitializerForRSSpecificType(*DT,
   1334                                              VD->getASTContext(),
   1335                                              VD->getLocation());
   1336 
   1337   if (ZeroInitializer) {
   1338     ZeroInitializer->setType(T->getCanonicalTypeInternal());
   1339     VD->setInit(ZeroInitializer);
   1340   }
   1341 
   1342   return DataTypeIsRSObject;
   1343 }
   1344 
   1345 clang::Expr *RSObjectRefCount::CreateZeroInitializerForRSSpecificType(
   1346     DataType DT,
   1347     clang::ASTContext &C,
   1348     const clang::SourceLocation &Loc) {
   1349   clang::Expr *Res = nullptr;
   1350   switch (DT) {
   1351     case DataTypeIsStruct:
   1352     case DataTypeRSElement:
   1353     case DataTypeRSType:
   1354     case DataTypeRSAllocation:
   1355     case DataTypeRSSampler:
   1356     case DataTypeRSScript:
   1357     case DataTypeRSMesh:
   1358     case DataTypeRSPath:
   1359     case DataTypeRSProgramFragment:
   1360     case DataTypeRSProgramVertex:
   1361     case DataTypeRSProgramRaster:
   1362     case DataTypeRSProgramStore:
   1363     case DataTypeRSFont: {
   1364       //    (ImplicitCastExpr 'nullptr_t'
   1365       //      (IntegerLiteral 0)))
   1366       llvm::APInt Zero(C.getTypeSize(C.IntTy), 0);
   1367       clang::Expr *Int0 = clang::IntegerLiteral::Create(C, Zero, C.IntTy, Loc);
   1368       clang::Expr *CastToNull =
   1369           clang::ImplicitCastExpr::Create(C,
   1370                                           C.NullPtrTy,
   1371                                           clang::CK_IntegralToPointer,
   1372                                           Int0,
   1373                                           nullptr,
   1374                                           clang::VK_RValue);
   1375 
   1376       llvm::SmallVector<clang::Expr*, 1>InitList;
   1377       InitList.push_back(CastToNull);
   1378 
   1379       Res = new(C) clang::InitListExpr(C, Loc, InitList, Loc);
   1380       break;
   1381     }
   1382     case DataTypeRSMatrix2x2:
   1383     case DataTypeRSMatrix3x3:
   1384     case DataTypeRSMatrix4x4: {
   1385       // RS matrix is not completely an RS object. They hold data by themselves.
   1386       // (InitListExpr rs_matrix2x2
   1387       //   (InitListExpr float[4]
   1388       //     (FloatingLiteral 0)
   1389       //     (FloatingLiteral 0)
   1390       //     (FloatingLiteral 0)
   1391       //     (FloatingLiteral 0)))
   1392       clang::QualType FloatTy = C.FloatTy;
   1393       // Constructor sets value to 0.0f by default
   1394       llvm::APFloat Val(C.getFloatTypeSemantics(FloatTy));
   1395       clang::FloatingLiteral *Float0Val =
   1396           clang::FloatingLiteral::Create(C,
   1397                                          Val,
   1398                                          /* isExact = */true,
   1399                                          FloatTy,
   1400                                          Loc);
   1401 
   1402       unsigned N = 0;
   1403       if (DT == DataTypeRSMatrix2x2)
   1404         N = 2;
   1405       else if (DT == DataTypeRSMatrix3x3)
   1406         N = 3;
   1407       else if (DT == DataTypeRSMatrix4x4)
   1408         N = 4;
   1409       unsigned N_2 = N * N;
   1410 
   1411       // Assume we are going to be allocating 16 elements, since 4x4 is max.
   1412       llvm::SmallVector<clang::Expr*, 16> InitVals;
   1413       for (unsigned i = 0; i < N_2; i++)
   1414         InitVals.push_back(Float0Val);
   1415       clang::Expr *InitExpr =
   1416           new(C) clang::InitListExpr(C, Loc, InitVals, Loc);
   1417       InitExpr->setType(C.getConstantArrayType(FloatTy,
   1418                                                llvm::APInt(32, N_2),
   1419                                                clang::ArrayType::Normal,
   1420                                                /* EltTypeQuals = */0));
   1421       llvm::SmallVector<clang::Expr*, 1> InitExprVec;
   1422       InitExprVec.push_back(InitExpr);
   1423 
   1424       Res = new(C) clang::InitListExpr(C, Loc, InitExprVec, Loc);
   1425       break;
   1426     }
   1427     case DataTypeUnknown:
   1428     case DataTypeFloat16:
   1429     case DataTypeFloat32:
   1430     case DataTypeFloat64:
   1431     case DataTypeSigned8:
   1432     case DataTypeSigned16:
   1433     case DataTypeSigned32:
   1434     case DataTypeSigned64:
   1435     case DataTypeUnsigned8:
   1436     case DataTypeUnsigned16:
   1437     case DataTypeUnsigned32:
   1438     case DataTypeUnsigned64:
   1439     case DataTypeBoolean:
   1440     case DataTypeUnsigned565:
   1441     case DataTypeUnsigned5551:
   1442     case DataTypeUnsigned4444:
   1443     case DataTypeMax: {
   1444       slangAssert(false && "Not RS object type!");
   1445     }
   1446     // No default case will enable compiler detecting the missing cases
   1447   }
   1448 
   1449   return Res;
   1450 }
   1451 
   1452 void RSObjectRefCount::VisitDeclStmt(clang::DeclStmt *DS) {
   1453   for (clang::DeclStmt::decl_iterator I = DS->decl_begin(), E = DS->decl_end();
   1454        I != E;
   1455        I++) {
   1456     clang::Decl *D = *I;
   1457     if (D->getKind() == clang::Decl::Var) {
   1458       clang::VarDecl *VD = static_cast<clang::VarDecl*>(D);
   1459       DataType DT = DataTypeUnknown;
   1460       clang::Expr *InitExpr = nullptr;
   1461       if (InitializeRSObject(VD, &DT, &InitExpr)) {
   1462         // We need to zero-init all RS object types (including matrices), ...
   1463         getCurrentScope()->AppendRSObjectInit(VD, DS, DT, InitExpr);
   1464         // ... but, only add to the list of RS objects if we have some
   1465         // non-matrix RS object fields.
   1466         if (CountRSObjectTypes(mCtx, VD->getType().getTypePtr(),
   1467                                VD->getLocation())) {
   1468           getCurrentScope()->addRSObject(VD);
   1469         }
   1470       }
   1471     }
   1472   }
   1473 }
   1474 
   1475 void RSObjectRefCount::VisitCompoundStmt(clang::CompoundStmt *CS) {
   1476   if (!CS->body_empty()) {
   1477     // Push a new scope
   1478     Scope *S = new Scope(CS);
   1479     mScopeStack.push(S);
   1480 
   1481     VisitStmt(CS);
   1482 
   1483     // Destroy the scope
   1484     slangAssert((getCurrentScope() == S) && "Corrupted scope stack!");
   1485     S->InsertLocalVarDestructors();
   1486     mScopeStack.pop();
   1487     delete S;
   1488   }
   1489 }
   1490 
   1491 void RSObjectRefCount::VisitBinAssign(clang::BinaryOperator *AS) {
   1492   clang::QualType QT = AS->getType();
   1493 
   1494   if (CountRSObjectTypes(mCtx, QT.getTypePtr(), AS->getExprLoc())) {
   1495     getCurrentScope()->ReplaceRSObjectAssignment(AS);
   1496   }
   1497 }
   1498 
   1499 void RSObjectRefCount::VisitStmt(clang::Stmt *S) {
   1500   for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
   1501        I != E;
   1502        I++) {
   1503     if (clang::Stmt *Child = *I) {
   1504       Visit(Child);
   1505     }
   1506   }
   1507 }
   1508 
   1509 // This function walks the list of global variables and (potentially) creates
   1510 // a single global static destructor function that properly decrements
   1511 // reference counts on the contained RS object types.
   1512 clang::FunctionDecl *RSObjectRefCount::CreateStaticGlobalDtor() {
   1513   Init();
   1514 
   1515   clang::DeclContext *DC = mCtx.getTranslationUnitDecl();
   1516   clang::SourceLocation loc;
   1517 
   1518   llvm::StringRef SR(".rs.dtor");
   1519   clang::IdentifierInfo &II = mCtx.Idents.get(SR);
   1520   clang::DeclarationName N(&II);
   1521   clang::FunctionProtoType::ExtProtoInfo EPI;
   1522   clang::QualType T = mCtx.getFunctionType(mCtx.VoidTy,
   1523       llvm::ArrayRef<clang::QualType>(), EPI);
   1524   clang::FunctionDecl *FD = nullptr;
   1525 
   1526   // Generate rsClearObject() call chains for every global variable
   1527   // (whether static or extern).
   1528   std::list<clang::Stmt *> StmtList;
   1529   for (clang::DeclContext::decl_iterator I = DC->decls_begin(),
   1530           E = DC->decls_end(); I != E; I++) {
   1531     clang::VarDecl *VD = llvm::dyn_cast<clang::VarDecl>(*I);
   1532     if (VD) {
   1533       if (CountRSObjectTypes(mCtx, VD->getType().getTypePtr(), loc)) {
   1534         if (!FD) {
   1535           // Only create FD if we are going to use it.
   1536           FD = clang::FunctionDecl::Create(mCtx, DC, loc, loc, N, T, nullptr,
   1537                                            clang::SC_None);
   1538         }
   1539         // Mark VD as used.  It might be unused, except for the destructor.
   1540         // 'markUsed' has side-effects that are caused only if VD is not already
   1541         // used.  Hence no need for an extra check here.
   1542         VD->markUsed(mCtx);
   1543         // Make sure to create any helpers within the function's DeclContext,
   1544         // not the one associated with the global translation unit.
   1545         clang::Stmt *RSClearObjectCall = Scope::ClearRSObject(VD, FD);
   1546         StmtList.push_back(RSClearObjectCall);
   1547       }
   1548     }
   1549   }
   1550 
   1551   // Nothing needs to be destroyed, so don't emit a dtor.
   1552   if (StmtList.empty()) {
   1553     return nullptr;
   1554   }
   1555 
   1556   clang::CompoundStmt *CS = BuildCompoundStmt(mCtx, StmtList, loc);
   1557 
   1558   FD->setBody(CS);
   1559 
   1560   return FD;
   1561 }
   1562 
   1563 }  // namespace slang
   1564