Home | History | Annotate | Download | only in slang
      1 /*
      2  * Copyright 2011-2012, The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *     http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "slang_rs_export_foreach.h"
     18 
     19 #include <string>
     20 
     21 #include "clang/AST/ASTContext.h"
     22 #include "clang/AST/Attr.h"
     23 #include "clang/AST/Decl.h"
     24 #include "clang/AST/TypeLoc.h"
     25 
     26 #include "llvm/IR/DerivedTypes.h"
     27 
     28 #include "bcinfo/MetadataExtractor.h"
     29 
     30 #include "slang_assert.h"
     31 #include "slang_rs_context.h"
     32 #include "slang_rs_export_type.h"
     33 #include "slang_version.h"
     34 
     35 namespace {
     36 
     37 const size_t RS_KERNEL_INPUT_LIMIT = 8; // see frameworks/base/libs/rs/cpu_ref/rsCpuCoreRuntime.h
     38 
     39 enum SpecialParameterKind {
     40   SPK_LOCATION, // 'int' or 'unsigned int'
     41   SPK_CONTEXT,  // rs_kernel_context
     42 };
     43 
     44 struct SpecialParameter {
     45   const char *name;
     46   bcinfo::MetadataSignatureBitval bitval;
     47   SpecialParameterKind kind;
     48   SlangTargetAPI minAPI;
     49 };
     50 
     51 // Table entries are in the order parameters must occur in a kernel parameter list.
     52 const SpecialParameter specialParameterTable[] = {
     53   { "context", bcinfo::MD_SIG_Ctxt, SPK_CONTEXT, SLANG_M_TARGET_API },
     54   { "x", bcinfo::MD_SIG_X, SPK_LOCATION, SLANG_MINIMUM_TARGET_API },
     55   { "y", bcinfo::MD_SIG_Y, SPK_LOCATION, SLANG_MINIMUM_TARGET_API },
     56   { "z", bcinfo::MD_SIG_Z, SPK_LOCATION, SLANG_M_TARGET_API },
     57   { nullptr, bcinfo::MD_SIG_None, SPK_LOCATION, SLANG_MINIMUM_TARGET_API }, // marks end of table
     58 };
     59 
     60 // If the specified name matches the name of an entry in
     61 // specialParameterTable, return the corresponding table index.
     62 // Return -1 if not found.
     63 int lookupSpecialParameter(const llvm::StringRef name) {
     64   for (int i = 0; specialParameterTable[i].name != nullptr; ++i) {
     65     if (name.equals(specialParameterTable[i].name)) {
     66       return i;
     67     }
     68   }
     69 
     70   return -1;
     71 }
     72 
     73 // Return a comma-separated list of names in specialParameterTable
     74 // that are available at the specified API level.
     75 std::string listSpecialParameters(unsigned int api) {
     76   std::string ret;
     77   bool first = true;
     78   for (int i = 0; specialParameterTable[i].name != nullptr; ++i) {
     79     if (specialParameterTable[i].minAPI > api)
     80       continue;
     81     if (first)
     82       first = false;
     83     else
     84       ret += ", ";
     85     ret += "'";
     86     ret += specialParameterTable[i].name;
     87     ret += "'";
     88   }
     89   return ret;
     90 }
     91 
     92 }
     93 
     94 namespace slang {
     95 
     96 // This function takes care of additional validation and construction of
     97 // parameters related to forEach_* reflection.
     98 bool RSExportForEach::validateAndConstructParams(
     99     RSContext *Context, const clang::FunctionDecl *FD) {
    100   slangAssert(Context && FD);
    101   bool valid = true;
    102 
    103   numParams = FD->getNumParams();
    104 
    105   if (Context->getTargetAPI() < SLANG_JB_TARGET_API) {
    106     // Before JellyBean, we allowed only one kernel per file.  It must be called "root".
    107     if (!isRootRSFunc(FD)) {
    108       Context->ReportError(FD->getLocation(),
    109                            "Non-root compute kernel %0() is "
    110                            "not supported in SDK levels %1-%2")
    111           << FD->getName() << SLANG_MINIMUM_TARGET_API
    112           << (SLANG_JB_TARGET_API - 1);
    113       return false;
    114     }
    115   }
    116 
    117   mResultType = FD->getReturnType().getCanonicalType();
    118   // Compute kernel functions are defined differently when the
    119   // "__attribute__((kernel))" is set.
    120   if (FD->hasAttr<clang::KernelAttr>()) {
    121     valid |= validateAndConstructKernelParams(Context, FD);
    122   } else {
    123     valid |= validateAndConstructOldStyleParams(Context, FD);
    124   }
    125 
    126   valid |= setSignatureMetadata(Context, FD);
    127   return valid;
    128 }
    129 
    130 bool RSExportForEach::validateAndConstructOldStyleParams(
    131     RSContext *Context, const clang::FunctionDecl *FD) {
    132   slangAssert(Context && FD);
    133   // If numParams is 0, we already marked this as a graphics root().
    134   slangAssert(numParams > 0);
    135 
    136   bool valid = true;
    137 
    138   // Compute kernel functions of this style are required to return a void type.
    139   clang::ASTContext &C = Context->getASTContext();
    140   if (mResultType != C.VoidTy) {
    141     Context->ReportError(FD->getLocation(),
    142                          "Compute kernel %0() is required to return a "
    143                          "void type")
    144         << FD->getName();
    145     valid = false;
    146   }
    147 
    148   // Validate remaining parameter types
    149 
    150   size_t IndexOfFirstSpecialParameter = numParams;
    151   valid |= processSpecialParameters(Context, FD, &IndexOfFirstSpecialParameter);
    152 
    153   // Validate the non-special parameters, which should all be found before the
    154   // first special parameter.
    155   for (size_t i = 0; i < IndexOfFirstSpecialParameter; i++) {
    156     const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
    157     clang::QualType QT = PVD->getType().getCanonicalType();
    158 
    159     if (!QT->isPointerType()) {
    160       Context->ReportError(PVD->getLocation(),
    161                            "Compute kernel %0() cannot have non-pointer "
    162                            "parameters besides special parameters (%1). Parameter '%2' is "
    163                            "of type: '%3'")
    164           << FD->getName() << listSpecialParameters(Context->getTargetAPI())
    165           << PVD->getName() << PVD->getType().getAsString();
    166       valid = false;
    167       continue;
    168     }
    169 
    170     // The only non-const pointer should be out.
    171     if (!QT->getPointeeType().isConstQualified()) {
    172       if (mOut == nullptr) {
    173         mOut = PVD;
    174       } else {
    175         Context->ReportError(PVD->getLocation(),
    176                              "Compute kernel %0() can only have one non-const "
    177                              "pointer parameter. Parameters '%1' and '%2' are "
    178                              "both non-const.")
    179             << FD->getName() << mOut->getName() << PVD->getName();
    180         valid = false;
    181       }
    182     } else {
    183       if (mIns.empty() && mOut == nullptr) {
    184         mIns.push_back(PVD);
    185       } else if (mUsrData == nullptr) {
    186         mUsrData = PVD;
    187       } else {
    188         Context->ReportError(
    189             PVD->getLocation(),
    190             "Unexpected parameter '%0' for compute kernel %1()")
    191             << PVD->getName() << FD->getName();
    192         valid = false;
    193       }
    194     }
    195   }
    196 
    197   if (mIns.empty() && !mOut) {
    198     Context->ReportError(FD->getLocation(),
    199                          "Compute kernel %0() must have at least one "
    200                          "parameter for in or out")
    201         << FD->getName();
    202     valid = false;
    203   }
    204 
    205   return valid;
    206 }
    207 
    208 bool RSExportForEach::validateAndConstructKernelParams(
    209     RSContext *Context, const clang::FunctionDecl *FD) {
    210   slangAssert(Context && FD);
    211   bool valid = true;
    212   clang::ASTContext &C = Context->getASTContext();
    213 
    214   if (Context->getTargetAPI() < SLANG_JB_MR1_TARGET_API) {
    215     Context->ReportError(FD->getLocation(),
    216                          "Compute kernel %0() targeting SDK levels "
    217                          "%1-%2 may not use pass-by-value with "
    218                          "__attribute__((kernel))")
    219         << FD->getName() << SLANG_MINIMUM_TARGET_API
    220         << (SLANG_JB_MR1_TARGET_API - 1);
    221     return false;
    222   }
    223 
    224   // Denote that we are indeed a pass-by-value kernel.
    225   mIsKernelStyle = true;
    226   mHasReturnType = (mResultType != C.VoidTy);
    227 
    228   if (mResultType->isPointerType()) {
    229     Context->ReportError(
    230         FD->getTypeSpecStartLoc(),
    231         "Compute kernel %0() cannot return a pointer type: '%1'")
    232         << FD->getName() << mResultType.getAsString();
    233     valid = false;
    234   }
    235 
    236   // Validate remaining parameter types
    237 
    238   size_t IndexOfFirstSpecialParameter = numParams;
    239   valid |= processSpecialParameters(Context, FD, &IndexOfFirstSpecialParameter);
    240 
    241   // Validate the non-special parameters, which should all be found before the
    242   // first special.
    243   for (size_t i = 0; i < IndexOfFirstSpecialParameter; i++) {
    244     const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
    245 
    246     if (Context->getTargetAPI() >= SLANG_M_TARGET_API || i == 0) {
    247       if (i >= RS_KERNEL_INPUT_LIMIT) {
    248         Context->ReportError(PVD->getLocation(),
    249                              "Invalid parameter '%0' for compute kernel %1(). "
    250                              "Kernels targeting SDK levels %2+ may not use "
    251                              "more than %3 input parameters.") << PVD->getName() <<
    252                              FD->getName() << SLANG_M_TARGET_API <<
    253                              int(RS_KERNEL_INPUT_LIMIT);
    254 
    255       } else {
    256         mIns.push_back(PVD);
    257       }
    258     } else {
    259       Context->ReportError(PVD->getLocation(),
    260                            "Invalid parameter '%0' for compute kernel %1(). "
    261                            "Kernels targeting SDK levels %2-%3 may not use "
    262                            "multiple input parameters.") << PVD->getName() <<
    263                            FD->getName() << SLANG_MINIMUM_TARGET_API <<
    264                            (SLANG_M_TARGET_API - 1);
    265       valid = false;
    266     }
    267     clang::QualType QT = PVD->getType().getCanonicalType();
    268     if (QT->isPointerType()) {
    269       Context->ReportError(PVD->getLocation(),
    270                            "Compute kernel %0() cannot have "
    271                            "parameter '%1' of pointer type: '%2'")
    272           << FD->getName() << PVD->getName() << PVD->getType().getAsString();
    273       valid = false;
    274     }
    275   }
    276 
    277   // Check that we have at least one allocation to use for dimensions.
    278   if (valid && mIns.empty() && !mHasReturnType && Context->getTargetAPI() < SLANG_M_TARGET_API) {
    279     Context->ReportError(FD->getLocation(),
    280                          "Compute kernel %0() targeting SDK levels "
    281                          "%1-%2 must have at least one "
    282                          "input parameter or a non-void return "
    283                          "type")
    284         << FD->getName() << SLANG_MINIMUM_TARGET_API
    285         << (SLANG_M_TARGET_API - 1);
    286     valid = false;
    287   }
    288 
    289   return valid;
    290 }
    291 
    292 // Process the optional special parameters:
    293 // - Sets *IndexOfFirstSpecialParameter to the index of the first special parameter, or
    294 //     FD->getNumParams() if none are found.
    295 // - Sets mSpecialParameterSignatureMetadata for the found special parameters.
    296 // Returns true if no errors.
    297 bool RSExportForEach::processSpecialParameters(
    298     RSContext *Context, const clang::FunctionDecl *FD,
    299     size_t *IndexOfFirstSpecialParameter) {
    300   slangAssert(IndexOfFirstSpecialParameter != nullptr);
    301   slangAssert(mSpecialParameterSignatureMetadata == 0);
    302   clang::ASTContext &C = Context->getASTContext();
    303 
    304   // Find all special parameters if present.
    305   int LastSpecialParameterIdx = -1;     // index into specialParameterTable
    306   int FirstLocationSpecialParameterIdx = -1; // index into specialParameterTable
    307   clang::QualType FirstLocationSpecialParameterType;
    308   size_t NumParams = FD->getNumParams();
    309   *IndexOfFirstSpecialParameter = NumParams;
    310   bool valid = true;
    311   for (size_t i = 0; i < NumParams; i++) {
    312     const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
    313     const llvm::StringRef ParamName = PVD->getName();
    314     const clang::QualType Type = PVD->getType();
    315     const clang::QualType QT = Type.getCanonicalType();
    316     const clang::QualType UT = QT.getUnqualifiedType();
    317     int SpecialParameterIdx = lookupSpecialParameter(ParamName);
    318 
    319     static const char KernelContextUnqualifiedTypeName[] =
    320         "const struct rs_kernel_context_t *";
    321     static const char KernelContextTypeName[] = "rs_kernel_context";
    322 
    323     // If the type is rs_context, it should have been named "context" and classified
    324     // as a special parameter.
    325     if (SpecialParameterIdx < 0 && UT.getAsString() == KernelContextUnqualifiedTypeName) {
    326       Context->ReportError(
    327           PVD->getLocation(),
    328           "The special parameter of type '%0' must be called "
    329           "'context' instead of '%1'.")
    330           << KernelContextTypeName << ParamName;
    331       SpecialParameterIdx = lookupSpecialParameter("context");
    332     }
    333 
    334     // If it's not a special parameter, check that it appears before any special
    335     // parameter.
    336     if (SpecialParameterIdx < 0) {
    337       if (*IndexOfFirstSpecialParameter < NumParams) {
    338         Context->ReportError(PVD->getLocation(),
    339                              "In compute kernel %0(), parameter '%1' cannot "
    340                              "appear after any of the special parameters (%2).")
    341             << FD->getName() << ParamName << listSpecialParameters(Context->getTargetAPI());
    342         valid = false;
    343       }
    344       continue;
    345     }
    346 
    347     const SpecialParameter &SP = specialParameterTable[SpecialParameterIdx];
    348 
    349     // Verify that this special parameter is OK for the current API level.
    350     if (Context->getTargetAPI() < SP.minAPI) {
    351       Context->ReportError(PVD->getLocation(),
    352                            "Compute kernel %0() targeting SDK levels "
    353                            "%1-%2 may not use special parameter '%3'.")
    354           << FD->getName() << SLANG_MINIMUM_TARGET_API << (SP.minAPI - 1)
    355           << SP.name;
    356       valid = false;
    357     }
    358 
    359     // Check that the order of the special parameters is correct.
    360     if (SpecialParameterIdx < LastSpecialParameterIdx) {
    361       Context->ReportError(
    362           PVD->getLocation(),
    363           "In compute kernel %0(), special parameter '%1' must "
    364           "be defined before special parameter '%2'.")
    365           << FD->getName() << SP.name
    366           << specialParameterTable[LastSpecialParameterIdx].name;
    367       valid = false;
    368     }
    369 
    370     // Validate the data type of the special parameter.
    371     switch (SP.kind) {
    372     case SPK_LOCATION: {
    373       // Location special parameters can only be int or uint.
    374       if (UT != C.UnsignedIntTy && UT != C.IntTy) {
    375         Context->ReportError(PVD->getLocation(),
    376                              "Special parameter '%0' must be of type 'int' or "
    377                              "'unsigned int'. It is of type '%1'.")
    378             << ParamName << Type.getAsString();
    379         valid = false;
    380       }
    381 
    382       // Ensure that all location special parameters have the same type.
    383       if (FirstLocationSpecialParameterIdx >= 0) {
    384         if (Type != FirstLocationSpecialParameterType) {
    385           Context->ReportError(
    386               PVD->getLocation(),
    387               "Special parameters '%0' and '%1' must be of the same type. "
    388               "'%0' is of type '%2' while '%1' is of type '%3'.")
    389               << specialParameterTable[FirstLocationSpecialParameterIdx].name
    390               << SP.name << FirstLocationSpecialParameterType.getAsString()
    391               << Type.getAsString();
    392           valid = false;
    393         }
    394       } else {
    395         FirstLocationSpecialParameterIdx = SpecialParameterIdx;
    396         FirstLocationSpecialParameterType = Type;
    397       }
    398     } break;
    399     case SPK_CONTEXT: {
    400       // Check that variables named "context" are of type rs_context.
    401       if (UT.getAsString() != KernelContextUnqualifiedTypeName) {
    402         Context->ReportError(PVD->getLocation(),
    403                              "Special parameter '%0' must be of type '%1'. "
    404                              "It is of type '%2'.")
    405             << ParamName << KernelContextTypeName
    406             << Type.getAsString();
    407         valid = false;
    408       }
    409     } break;
    410     default:
    411       slangAssert(!"Unexpected special parameter type");
    412     }
    413 
    414     // We should not be invoked if two parameters of the same name are present.
    415     slangAssert(!(mSpecialParameterSignatureMetadata & SP.bitval));
    416     mSpecialParameterSignatureMetadata |= SP.bitval;
    417 
    418     LastSpecialParameterIdx = SpecialParameterIdx;
    419     // If this is the first time we find a special parameter, save it.
    420     if (*IndexOfFirstSpecialParameter >= NumParams) {
    421       *IndexOfFirstSpecialParameter = i;
    422     }
    423   }
    424   return valid;
    425 }
    426 
    427 bool RSExportForEach::setSignatureMetadata(RSContext *Context,
    428                                            const clang::FunctionDecl *FD) {
    429   mSignatureMetadata = 0;
    430   bool valid = true;
    431 
    432   if (mIsKernelStyle) {
    433     slangAssert(mOut == nullptr);
    434     slangAssert(mUsrData == nullptr);
    435   } else {
    436     slangAssert(!mHasReturnType);
    437   }
    438 
    439   // Set up the bitwise metadata encoding for runtime argument passing.
    440   const bool HasOut = mOut || mHasReturnType;
    441   mSignatureMetadata |= (hasIns() ?       bcinfo::MD_SIG_In     : 0);
    442   mSignatureMetadata |= (HasOut ?         bcinfo::MD_SIG_Out    : 0);
    443   mSignatureMetadata |= (mUsrData ?       bcinfo::MD_SIG_Usr    : 0);
    444   mSignatureMetadata |= (mIsKernelStyle ? bcinfo::MD_SIG_Kernel : 0);  // pass-by-value
    445   mSignatureMetadata |= mSpecialParameterSignatureMetadata;
    446 
    447   if (Context->getTargetAPI() < SLANG_ICS_TARGET_API) {
    448     // APIs before ICS cannot skip between parameters. It is ok, however, for
    449     // them to omit further parameters (i.e. skipping X is ok if you skip Y).
    450     if (mSignatureMetadata != (bcinfo::MD_SIG_In | bcinfo::MD_SIG_Out | bcinfo::MD_SIG_Usr |
    451                                bcinfo::MD_SIG_X | bcinfo::MD_SIG_Y) &&
    452         mSignatureMetadata != (bcinfo::MD_SIG_In | bcinfo::MD_SIG_Out | bcinfo::MD_SIG_Usr |
    453                                bcinfo::MD_SIG_X) &&
    454         mSignatureMetadata != (bcinfo::MD_SIG_In | bcinfo::MD_SIG_Out | bcinfo::MD_SIG_Usr) &&
    455         mSignatureMetadata != (bcinfo::MD_SIG_In | bcinfo::MD_SIG_Out) &&
    456         mSignatureMetadata != (bcinfo::MD_SIG_In)) {
    457       Context->ReportError(FD->getLocation(),
    458                            "Compute kernel %0() targeting SDK levels "
    459                            "%1-%2 may not skip parameters")
    460           << FD->getName() << SLANG_MINIMUM_TARGET_API
    461           << (SLANG_ICS_TARGET_API - 1);
    462       valid = false;
    463     }
    464   }
    465   return valid;
    466 }
    467 
    468 RSExportForEach *RSExportForEach::Create(RSContext *Context,
    469                                          const clang::FunctionDecl *FD) {
    470   slangAssert(Context && FD);
    471   llvm::StringRef Name = FD->getName();
    472   RSExportForEach *FE;
    473 
    474   slangAssert(!Name.empty() && "Function must have a name");
    475 
    476   FE = new RSExportForEach(Context, Name);
    477 
    478   if (!FE->validateAndConstructParams(Context, FD)) {
    479     return nullptr;
    480   }
    481 
    482   clang::ASTContext &Ctx = Context->getASTContext();
    483 
    484   std::string Id = CreateDummyName("helper_foreach_param", FE->getName());
    485 
    486   // Extract the usrData parameter (if we have one)
    487   if (FE->mUsrData) {
    488     const clang::ParmVarDecl *PVD = FE->mUsrData;
    489     clang::QualType QT = PVD->getType().getCanonicalType();
    490     slangAssert(QT->isPointerType() &&
    491                 QT->getPointeeType().isConstQualified());
    492 
    493     const clang::ASTContext &C = Context->getASTContext();
    494     if (QT->getPointeeType().getCanonicalType().getUnqualifiedType() ==
    495         C.VoidTy) {
    496       // In the case of using const void*, we can't reflect an appopriate
    497       // Java type, so we fall back to just reflecting the ain/aout parameters
    498       FE->mUsrData = nullptr;
    499     } else {
    500       clang::RecordDecl *RD =
    501           clang::RecordDecl::Create(Ctx, clang::TTK_Struct,
    502                                     Ctx.getTranslationUnitDecl(),
    503                                     clang::SourceLocation(),
    504                                     clang::SourceLocation(),
    505                                     &Ctx.Idents.get(Id));
    506 
    507       clang::FieldDecl *FD =
    508           clang::FieldDecl::Create(Ctx,
    509                                    RD,
    510                                    clang::SourceLocation(),
    511                                    clang::SourceLocation(),
    512                                    PVD->getIdentifier(),
    513                                    QT->getPointeeType(),
    514                                    nullptr,
    515                                    /* BitWidth = */ nullptr,
    516                                    /* Mutable = */ false,
    517                                    /* HasInit = */ clang::ICIS_NoInit);
    518       RD->addDecl(FD);
    519       RD->completeDefinition();
    520 
    521       // Create an export type iff we have a valid usrData type
    522       clang::QualType T = Ctx.getTagDeclType(RD);
    523       slangAssert(!T.isNull());
    524 
    525       RSExportType *ET = RSExportType::Create(Context, T.getTypePtr());
    526 
    527       slangAssert(ET && "Failed to export a kernel");
    528 
    529       slangAssert((ET->getClass() == RSExportType::ExportClassRecord) &&
    530                   "Parameter packet must be a record");
    531 
    532       FE->mParamPacketType = static_cast<RSExportRecordType *>(ET);
    533     }
    534   }
    535 
    536   if (FE->hasIns()) {
    537 
    538     for (InIter BI = FE->mIns.begin(), EI = FE->mIns.end(); BI != EI; BI++) {
    539       const clang::Type *T = (*BI)->getType().getCanonicalType().getTypePtr();
    540       RSExportType *InExportType = RSExportType::Create(Context, T);
    541 
    542       if (FE->mIsKernelStyle) {
    543         slangAssert(InExportType != nullptr);
    544       }
    545 
    546       FE->mInTypes.push_back(InExportType);
    547     }
    548   }
    549 
    550   if (FE->mIsKernelStyle && FE->mHasReturnType) {
    551     const clang::Type *T = FE->mResultType.getTypePtr();
    552     FE->mOutType = RSExportType::Create(Context, T);
    553     slangAssert(FE->mOutType);
    554   } else if (FE->mOut) {
    555     const clang::Type *T = FE->mOut->getType().getCanonicalType().getTypePtr();
    556     FE->mOutType = RSExportType::Create(Context, T);
    557   }
    558 
    559   return FE;
    560 }
    561 
    562 RSExportForEach *RSExportForEach::CreateDummyRoot(RSContext *Context) {
    563   slangAssert(Context);
    564   llvm::StringRef Name = "root";
    565   RSExportForEach *FE = new RSExportForEach(Context, Name);
    566   FE->mDummyRoot = true;
    567   return FE;
    568 }
    569 
    570 bool RSExportForEach::isGraphicsRootRSFunc(unsigned int targetAPI,
    571                                            const clang::FunctionDecl *FD) {
    572   if (FD->hasAttr<clang::KernelAttr>()) {
    573     return false;
    574   }
    575 
    576   if (!isRootRSFunc(FD)) {
    577     return false;
    578   }
    579 
    580   if (FD->getNumParams() == 0) {
    581     // Graphics root function
    582     return true;
    583   }
    584 
    585   // Check for legacy graphics root function (with single parameter).
    586   if ((targetAPI < SLANG_ICS_TARGET_API) && (FD->getNumParams() == 1)) {
    587     const clang::QualType &IntType = FD->getASTContext().IntTy;
    588     if (FD->getReturnType().getCanonicalType() == IntType) {
    589       return true;
    590     }
    591   }
    592 
    593   return false;
    594 }
    595 
    596 bool RSExportForEach::isRSForEachFunc(unsigned int targetAPI,
    597                                       slang::RSContext* Context,
    598                                       const clang::FunctionDecl *FD) {
    599   slangAssert(Context && FD);
    600   bool hasKernelAttr = FD->hasAttr<clang::KernelAttr>();
    601 
    602   if (FD->getStorageClass() == clang::SC_Static) {
    603     if (hasKernelAttr) {
    604       Context->ReportError(FD->getLocation(),
    605                            "Invalid use of attribute kernel with "
    606                            "static function declaration: %0")
    607           << FD->getName();
    608     }
    609     return false;
    610   }
    611 
    612   // Anything tagged as a kernel is definitely used with ForEach.
    613   if (hasKernelAttr) {
    614     return true;
    615   }
    616 
    617   if (isGraphicsRootRSFunc(targetAPI, FD)) {
    618     return false;
    619   }
    620 
    621   // Check if first parameter is a pointer (which is required for ForEach).
    622   unsigned int numParams = FD->getNumParams();
    623 
    624   if (numParams > 0) {
    625     const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
    626     clang::QualType QT = PVD->getType().getCanonicalType();
    627 
    628     if (QT->isPointerType()) {
    629       return true;
    630     }
    631 
    632     // Any non-graphics root() is automatically a ForEach candidate.
    633     // At this point, however, we know that it is not going to be a valid
    634     // compute root() function (due to not having a pointer parameter). We
    635     // still want to return true here, so that we can issue appropriate
    636     // diagnostics.
    637     if (isRootRSFunc(FD)) {
    638       return true;
    639     }
    640   }
    641 
    642   return false;
    643 }
    644 
    645 bool
    646 RSExportForEach::validateSpecialFuncDecl(unsigned int targetAPI,
    647                                          slang::RSContext *Context,
    648                                          clang::FunctionDecl const *FD) {
    649   slangAssert(Context && FD);
    650   bool valid = true;
    651   const clang::ASTContext &C = FD->getASTContext();
    652   const clang::QualType &IntType = FD->getASTContext().IntTy;
    653 
    654   if (isGraphicsRootRSFunc(targetAPI, FD)) {
    655     if ((targetAPI < SLANG_ICS_TARGET_API) && (FD->getNumParams() == 1)) {
    656       // Legacy graphics root function
    657       const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
    658       clang::QualType QT = PVD->getType().getCanonicalType();
    659       if (QT != IntType) {
    660         Context->ReportError(PVD->getLocation(),
    661                              "invalid parameter type for legacy "
    662                              "graphics root() function: %0")
    663             << PVD->getType();
    664         valid = false;
    665       }
    666     }
    667 
    668     // Graphics root function, so verify that it returns an int
    669     if (FD->getReturnType().getCanonicalType() != IntType) {
    670       Context->ReportError(FD->getLocation(),
    671                            "root() is required to return "
    672                            "an int for graphics usage");
    673       valid = false;
    674     }
    675   } else if (isInitRSFunc(FD) || isDtorRSFunc(FD)) {
    676     if (FD->getNumParams() != 0) {
    677       Context->ReportError(FD->getLocation(),
    678                            "%0(void) is required to have no "
    679                            "parameters")
    680           << FD->getName();
    681       valid = false;
    682     }
    683 
    684     if (FD->getReturnType().getCanonicalType() != C.VoidTy) {
    685       Context->ReportError(FD->getLocation(),
    686                            "%0(void) is required to have a void "
    687                            "return type")
    688           << FD->getName();
    689       valid = false;
    690     }
    691   } else {
    692     slangAssert(false && "must be called on root, init or .rs.dtor function!");
    693   }
    694 
    695   return valid;
    696 }
    697 
    698 }  // namespace slang
    699