Home | History | Annotate | Download | only in slang
      1 /*
      2  * Copyright 2015, The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *     http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "slang_rs_foreach_lowering.h"
     18 
     19 #include "clang/AST/ASTContext.h"
     20 #include "clang/AST/Attr.h"
     21 #include "llvm/Support/raw_ostream.h"
     22 #include "slang_rs_context.h"
     23 #include "slang_rs_export_foreach.h"
     24 
     25 namespace slang {
     26 
     27 namespace {
     28 
     29 const char KERNEL_LAUNCH_FUNCTION_NAME[] = "rsForEach";
     30 const char KERNEL_LAUNCH_FUNCTION_NAME_WITH_OPTIONS[] = "rsForEachWithOptions";
     31 const char INTERNAL_LAUNCH_FUNCTION_NAME[] =
     32     "_Z17rsForEachInternaliP14rs_script_calliiP13rs_allocation";
     33 
     34 }  // anonymous namespace
     35 
     36 RSForEachLowering::RSForEachLowering(RSContext* ctxt)
     37     : mCtxt(ctxt), mASTCtxt(ctxt->getASTContext()) {}
     38 
     39 // Check if the passed-in expr references a kernel function in the following
     40 // pattern in the AST.
     41 //
     42 // ImplicitCastExpr 'void *' <BitCast>
     43 //  `-ImplicitCastExpr 'int (*)(int)' <FunctionToPointerDecay>
     44 //    `-DeclRefExpr 'int (int)' Function 'foo' 'int (int)'
     45 const clang::FunctionDecl* RSForEachLowering::matchFunctionDesignator(
     46     clang::Expr* expr) {
     47   clang::ImplicitCastExpr* ToVoidPtr =
     48       clang::dyn_cast<clang::ImplicitCastExpr>(expr);
     49   if (ToVoidPtr == nullptr) {
     50     return nullptr;
     51   }
     52 
     53   clang::ImplicitCastExpr* Decay =
     54       clang::dyn_cast<clang::ImplicitCastExpr>(ToVoidPtr->getSubExpr());
     55 
     56   if (Decay == nullptr) {
     57     return nullptr;
     58   }
     59 
     60   clang::DeclRefExpr* DRE =
     61       clang::dyn_cast<clang::DeclRefExpr>(Decay->getSubExpr());
     62 
     63   if (DRE == nullptr) {
     64     return nullptr;
     65   }
     66 
     67   const clang::FunctionDecl* FD =
     68       clang::dyn_cast<clang::FunctionDecl>(DRE->getDecl());
     69 
     70   if (FD == nullptr) {
     71     return nullptr;
     72   }
     73 
     74   return FD;
     75 }
     76 
     77 // Checks if the call expression is a legal rsForEach call by looking for the
     78 // following pattern in the AST. On success, returns the first argument that is
     79 // a FunctionDecl of a kernel function.
     80 //
     81 // CallExpr 'void'
     82 // |
     83 // |-ImplicitCastExpr 'void (*)(void *, ...)' <FunctionToPointerDecay>
     84 // | `-DeclRefExpr  'void (void *, ...)'  'rsForEach' 'void (void *, ...)'
     85 // |
     86 // |-ImplicitCastExpr 'void *' <BitCast>
     87 // | `-ImplicitCastExpr 'int (*)(int)' <FunctionToPointerDecay>
     88 // |   `-DeclRefExpr 'int (int)' Function 'foo' 'int (int)'
     89 // |
     90 // |-ImplicitCastExpr 'rs_allocation':'rs_allocation' <LValueToRValue>
     91 // | `-DeclRefExpr 'rs_allocation':'rs_allocation' lvalue ParmVar 'in' 'rs_allocation':'rs_allocation'
     92 // |
     93 // `-ImplicitCastExpr 'rs_allocation':'rs_allocation' <LValueToRValue>
     94 //   `-DeclRefExpr  'rs_allocation':'rs_allocation' lvalue ParmVar 'out' 'rs_allocation':'rs_allocation'
     95 const clang::FunctionDecl* RSForEachLowering::matchKernelLaunchCall(
     96     clang::CallExpr* CE, int* slot, bool* hasOptions) {
     97   const clang::Decl* D = CE->getCalleeDecl();
     98   const clang::FunctionDecl* FD = clang::dyn_cast<clang::FunctionDecl>(D);
     99 
    100   if (FD == nullptr) {
    101     return nullptr;
    102   }
    103 
    104   const clang::StringRef& funcName = FD->getName();
    105 
    106   if (funcName.equals(KERNEL_LAUNCH_FUNCTION_NAME)) {
    107     *hasOptions = false;
    108   } else if (funcName.equals(KERNEL_LAUNCH_FUNCTION_NAME_WITH_OPTIONS)) {
    109     *hasOptions = true;
    110   } else {
    111     return nullptr;
    112   }
    113 
    114   if (mInsideKernel) {
    115     mCtxt->ReportError(CE->getExprLoc(),
    116         "Invalid kernel launch call made from inside another kernel.");
    117     return nullptr;
    118   }
    119 
    120   clang::Expr* arg0 = CE->getArg(0);
    121   const clang::FunctionDecl* kernel = matchFunctionDesignator(arg0);
    122 
    123   if (kernel == nullptr) {
    124     mCtxt->ReportError(arg0->getExprLoc(),
    125                        "Invalid kernel launch call. "
    126                        "Expects a function designator for the first argument.");
    127     return nullptr;
    128   }
    129 
    130   // Verifies that kernel is indeed a "kernel" function.
    131   *slot = mCtxt->getForEachSlotNumber(kernel);
    132   if (*slot == -1) {
    133     mCtxt->ReportError(CE->getExprLoc(),
    134          "%0 applied to function %1 defined without \"kernel\" attribute")
    135          << funcName << kernel->getName();
    136     return nullptr;
    137   }
    138 
    139   return kernel;
    140 }
    141 
    142 // Create an AST node for the declaration of rsForEachInternal
    143 clang::FunctionDecl* RSForEachLowering::CreateForEachInternalFunctionDecl() {
    144   clang::DeclContext* DC = mASTCtxt.getTranslationUnitDecl();
    145   clang::SourceLocation Loc;
    146 
    147   llvm::StringRef SR(INTERNAL_LAUNCH_FUNCTION_NAME);
    148   clang::IdentifierInfo& II = mASTCtxt.Idents.get(SR);
    149   clang::DeclarationName N(&II);
    150 
    151   clang::FunctionProtoType::ExtProtoInfo EPI;
    152 
    153   const clang::QualType& AllocTy = mCtxt->getAllocationType();
    154   clang::QualType AllocPtrTy = mASTCtxt.getPointerType(AllocTy);
    155 
    156   clang::QualType ScriptCallTy = mCtxt->getScriptCallType();
    157   const clang::QualType ScriptCallPtrTy = mASTCtxt.getPointerType(ScriptCallTy);
    158 
    159   clang::QualType ParamTypes[] = {
    160     mASTCtxt.IntTy,   // int slot
    161     ScriptCallPtrTy,  // rs_script_call_t* launch_options
    162     mASTCtxt.IntTy,   // int numOutput
    163     mASTCtxt.IntTy,   // int numInputs
    164     AllocPtrTy        // rs_allocation* allocs
    165   };
    166 
    167   clang::QualType T = mASTCtxt.getFunctionType(
    168       mASTCtxt.VoidTy,  // Return type
    169       ParamTypes,       // Parameter types
    170       EPI);
    171 
    172   clang::FunctionDecl* FD = clang::FunctionDecl::Create(
    173       mASTCtxt, DC, Loc, Loc, N, T, nullptr, clang::SC_Extern);
    174 
    175   static constexpr unsigned kNumParams = sizeof(ParamTypes) / sizeof(ParamTypes[0]);
    176   clang::ParmVarDecl *ParamDecls[kNumParams];
    177   for (unsigned I = 0; I != kNumParams; ++I) {
    178     ParamDecls[I] = clang::ParmVarDecl::Create(mASTCtxt, FD, Loc,
    179         Loc, nullptr, ParamTypes[I], nullptr, clang::SC_None, nullptr);
    180     // Implicit means that this declaration was created by the compiler, and
    181     // not part of the actual source code.
    182     ParamDecls[I]->setImplicit();
    183   }
    184   FD->setParams(llvm::makeArrayRef(ParamDecls, kNumParams));
    185 
    186   // Implicit means that this declaration was created by the compiler, and
    187   // not part of the actual source code.
    188   FD->setImplicit();
    189 
    190   return FD;
    191 }
    192 
    193 // Create an expression like the following that references the rsForEachInternal to
    194 // replace the callee in the original call expression that references rsForEach.
    195 //
    196 // ImplicitCastExpr 'void (*)(int, rs_script_call_t*, int, int, rs_allocation*)' <FunctionToPointerDecay>
    197 // `-DeclRefExpr 'void' Function '_Z17rsForEachInternaliP14rs_script_calliiP13rs_allocation' 'void (int, rs_script_call_t*, int, int, rs_allocation*)'
    198 clang::Expr* RSForEachLowering::CreateCalleeExprForInternalForEach() {
    199   clang::FunctionDecl* FDNew = CreateForEachInternalFunctionDecl();
    200 
    201   const clang::QualType FDNewType = FDNew->getType();
    202 
    203   clang::DeclRefExpr* refExpr = clang::DeclRefExpr::Create(
    204       mASTCtxt, clang::NestedNameSpecifierLoc(), clang::SourceLocation(), FDNew,
    205       false, clang::SourceLocation(), FDNewType, clang::VK_RValue);
    206 
    207   clang::Expr* calleeNew = clang::ImplicitCastExpr::Create(
    208       mASTCtxt, mASTCtxt.getPointerType(FDNewType),
    209       clang::CK_FunctionToPointerDecay, refExpr, nullptr, clang::VK_RValue);
    210 
    211   return calleeNew;
    212 }
    213 
    214 // This visit method checks (via pattern matching) if the call expression is to
    215 // rsForEach, and the arguments satisfy the restrictions on the
    216 // rsForEach API. If so, replace the call with a rsForEachInternal call
    217 // with the first argument replaced by the slot number of the kernel function
    218 // referenced in the original first argument.
    219 //
    220 // See comments to the helper methods defined above for details.
    221 void RSForEachLowering::VisitCallExpr(clang::CallExpr* CE) {
    222   int slot;
    223   bool hasOptions;
    224   const clang::FunctionDecl* kernel = matchKernelLaunchCall(CE, &slot, &hasOptions);
    225   if (kernel == nullptr) {
    226     return;
    227   }
    228 
    229   slangAssert(slot >= 0);
    230 
    231   const unsigned numArgsOrig = CE->getNumArgs();
    232 
    233   clang::QualType resultType = kernel->getReturnType().getCanonicalType();
    234   const unsigned numOutputsExpected = resultType->isVoidType() ? 0 : 1;
    235 
    236   const unsigned numInputsExpected = RSExportForEach::getNumInputs(mCtxt->getTargetAPI(), kernel);
    237 
    238   // Verifies that rsForEach takes the right number of input and output allocations.
    239   // TODO: Check input/output allocation types match kernel function expectation.
    240   const unsigned numAllocations = numArgsOrig - (hasOptions ? 2 : 1);
    241   if (numInputsExpected + numOutputsExpected != numAllocations) {
    242     mCtxt->ReportError(
    243       CE->getExprLoc(),
    244       "Number of input and output allocations unexpected for kernel function %0")
    245     << kernel->getName();
    246     return;
    247   }
    248 
    249   clang::Expr* calleeNew = CreateCalleeExprForInternalForEach();
    250   CE->setCallee(calleeNew);
    251 
    252   const clang::CanQualType IntTy = mASTCtxt.IntTy;
    253   const unsigned IntTySize = mASTCtxt.getTypeSize(IntTy);
    254   const llvm::APInt APIntSlot(IntTySize, slot);
    255   const clang::Expr* arg0 = CE->getArg(0);
    256   const clang::SourceLocation Loc(arg0->getLocStart());
    257   clang::Expr* IntSlotNum =
    258       clang::IntegerLiteral::Create(mASTCtxt, APIntSlot, IntTy, Loc);
    259   CE->setArg(0, IntSlotNum);
    260 
    261   /*
    262     The last few arguments to rsForEach or rsForEachWithOptions are allocations.
    263     Creates a new compound literal of an array initialized with those values, and
    264     passes it to rsForEachInternal as the last (the 5th) argument.
    265 
    266     For example, rsForEach(foo, ain1, ain2, aout) would be translated into
    267     rsForEachInternal(
    268         1,                                   // Slot number for kernel
    269         NULL,                                // Launch options
    270         2,                                   // Number of input allocations
    271         1,                                   // Number of output allocations
    272         (rs_allocation[]){ain1, ain2, aout)  // Input and output allocations
    273     );
    274 
    275     The AST for the rs_allocation array looks like following:
    276 
    277     ImplicitCastExpr 0x99575670 'struct rs_allocation *' <ArrayToPointerDecay>
    278     `-CompoundLiteralExpr 0x99575648 'struct rs_allocation [3]' lvalue
    279       `-InitListExpr 0x99575590 'struct rs_allocation [3]'
    280       |-ImplicitCastExpr 0x99574b38 'rs_allocation':'struct rs_allocation' <LValueToRValue>
    281       | `-DeclRefExpr 0x99574a08 'rs_allocation':'struct rs_allocation' lvalue ParmVar 0x9942c408 'ain1' 'rs_allocation':'struct rs_allocation'
    282       |-ImplicitCastExpr 0x99574b50 'rs_allocation':'struct rs_allocation' <LValueToRValue>
    283       | `-DeclRefExpr 0x99574a30 'rs_allocation':'struct rs_allocation' lvalue ParmVar 0x9942c478 'ain2' 'rs_allocation':'struct rs_allocation'
    284       `-ImplicitCastExpr 0x99574b68 'rs_allocation':'struct rs_allocation' <LValueToRValue>
    285         `-DeclRefExpr 0x99574a58 'rs_allocation':'struct rs_allocation' lvalue ParmVar 0x9942c478 'aout' 'rs_allocation':'struct rs_allocation'
    286   */
    287 
    288   const clang::QualType& AllocTy = mCtxt->getAllocationType();
    289   const llvm::APInt APIntNumAllocs(IntTySize, numAllocations);
    290   clang::QualType AllocArrayTy = mASTCtxt.getConstantArrayType(
    291       AllocTy,
    292       APIntNumAllocs,
    293       clang::ArrayType::ArraySizeModifier::Normal,
    294       0  // index type qualifiers
    295   );
    296 
    297   const int allocArgIndexEnd = numArgsOrig - 1;
    298   int allocArgIndexStart = allocArgIndexEnd;
    299 
    300   clang::Expr** args = CE->getArgs();
    301 
    302   clang::SourceLocation lparenloc;
    303   clang::SourceLocation rparenloc;
    304 
    305   if (numAllocations > 0) {
    306     allocArgIndexStart = hasOptions ? 2 : 1;
    307     lparenloc = args[allocArgIndexStart]->getExprLoc();
    308     rparenloc = args[allocArgIndexEnd]->getExprLoc();
    309   }
    310 
    311   clang::InitListExpr* init = new (mASTCtxt) clang::InitListExpr(
    312       mASTCtxt,
    313       lparenloc,
    314       llvm::ArrayRef<clang::Expr*>(args + allocArgIndexStart, numAllocations),
    315       rparenloc);
    316   init->setType(AllocArrayTy);
    317 
    318   clang::TypeSourceInfo* ti = mASTCtxt.getTrivialTypeSourceInfo(AllocArrayTy);
    319   clang::CompoundLiteralExpr* CLE = new (mASTCtxt) clang::CompoundLiteralExpr(
    320       lparenloc,
    321       ti,
    322       AllocArrayTy,
    323       clang::VK_LValue,  // A compound literal is an l-value in C.
    324       init,
    325       false  // Not file scope
    326   );
    327 
    328   const clang::QualType AllocPtrTy = mASTCtxt.getPointerType(AllocTy);
    329 
    330   clang::ImplicitCastExpr* Decay = clang::ImplicitCastExpr::Create(
    331       mASTCtxt,
    332       AllocPtrTy,
    333       clang::CK_ArrayToPointerDecay,
    334       CLE,
    335       nullptr,  // C++ cast path
    336       clang::VK_RValue
    337   );
    338 
    339   CE->setNumArgs(mASTCtxt, 5);
    340 
    341   CE->setArg(4, Decay);
    342 
    343   // Sets the new arguments for NULL launch option (if the user does not set one),
    344   // the number of outputs, and the number of inputs.
    345 
    346   if (!hasOptions) {
    347     const llvm::APInt APIntZero(IntTySize, 0);
    348     clang::Expr* IntNull =
    349         clang::IntegerLiteral::Create(mASTCtxt, APIntZero, IntTy, Loc);
    350     clang::QualType ScriptCallTy = mCtxt->getScriptCallType();
    351     const clang::QualType ScriptCallPtrTy = mASTCtxt.getPointerType(ScriptCallTy);
    352     clang::CStyleCastExpr* Cast =
    353         clang::CStyleCastExpr::Create(mASTCtxt,
    354                                       ScriptCallPtrTy,
    355                                       clang::VK_RValue,
    356                                       clang::CK_NullToPointer,
    357                                       IntNull,
    358                                       nullptr,
    359                                       mASTCtxt.getTrivialTypeSourceInfo(ScriptCallPtrTy),
    360                                       clang::SourceLocation(),
    361                                       clang::SourceLocation());
    362     CE->setArg(1, Cast);
    363   }
    364 
    365   const llvm::APInt APIntNumOutput(IntTySize, numOutputsExpected);
    366   clang::Expr* IntNumOutput =
    367       clang::IntegerLiteral::Create(mASTCtxt, APIntNumOutput, IntTy, Loc);
    368   CE->setArg(2, IntNumOutput);
    369 
    370   const llvm::APInt APIntNumInputs(IntTySize, numInputsExpected);
    371   clang::Expr* IntNumInputs =
    372       clang::IntegerLiteral::Create(mASTCtxt, APIntNumInputs, IntTy, Loc);
    373   CE->setArg(3, IntNumInputs);
    374 }
    375 
    376 void RSForEachLowering::VisitStmt(clang::Stmt* S) {
    377   for (clang::Stmt* Child : S->children()) {
    378     if (Child) {
    379       Visit(Child);
    380     }
    381   }
    382 }
    383 
    384 void RSForEachLowering::handleForEachCalls(clang::FunctionDecl* FD,
    385                                            unsigned int targetAPI) {
    386   slangAssert(FD && FD->hasBody());
    387 
    388   mInsideKernel = FD->hasAttr<clang::RenderScriptKernelAttr>();
    389   VisitStmt(FD->getBody());
    390 }
    391 
    392 }  // namespace slang
    393