Home | History | Annotate | Download | only in slang
      1 /*
      2  * Copyright 2011-2012, The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *     http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "slang_rs_export_foreach.h"
     18 
     19 #include <string>
     20 
     21 #include "clang/AST/ASTContext.h"
     22 #include "clang/AST/Attr.h"
     23 #include "clang/AST/Decl.h"
     24 #include "clang/AST/TypeLoc.h"
     25 
     26 #include "llvm/IR/DerivedTypes.h"
     27 
     28 #include "slang_assert.h"
     29 #include "slang_rs_context.h"
     30 #include "slang_rs_export_type.h"
     31 #include "slang_version.h"
     32 
     33 namespace slang {
     34 
     35 // This function takes care of additional validation and construction of
     36 // parameters related to forEach_* reflection.
     37 bool RSExportForEach::validateAndConstructParams(
     38     RSContext *Context, const clang::FunctionDecl *FD) {
     39   slangAssert(Context && FD);
     40   bool valid = true;
     41 
     42   numParams = FD->getNumParams();
     43 
     44   if (Context->getTargetAPI() < SLANG_JB_TARGET_API) {
     45     // Before JellyBean, we allowed only one kernel per file.  It must be called "root".
     46     if (!isRootRSFunc(FD)) {
     47       Context->ReportError(FD->getLocation(),
     48                            "Non-root compute kernel %0() is "
     49                            "not supported in SDK levels %1-%2")
     50           << FD->getName() << SLANG_MINIMUM_TARGET_API
     51           << (SLANG_JB_TARGET_API - 1);
     52       return false;
     53     }
     54   }
     55 
     56   mResultType = FD->getReturnType().getCanonicalType();
     57   // Compute kernel functions are defined differently when the
     58   // "__attribute__((kernel))" is set.
     59   if (FD->hasAttr<clang::KernelAttr>()) {
     60     valid |= validateAndConstructKernelParams(Context, FD);
     61   } else {
     62     valid |= validateAndConstructOldStyleParams(Context, FD);
     63   }
     64 
     65   valid |= setSignatureMetadata(Context, FD);
     66   return valid;
     67 }
     68 
     69 bool RSExportForEach::validateAndConstructOldStyleParams(
     70     RSContext *Context, const clang::FunctionDecl *FD) {
     71   slangAssert(Context && FD);
     72   // If numParams is 0, we already marked this as a graphics root().
     73   slangAssert(numParams > 0);
     74 
     75   bool valid = true;
     76 
     77   // Compute kernel functions of this style are required to return a void type.
     78   clang::ASTContext &C = Context->getASTContext();
     79   if (mResultType != C.VoidTy) {
     80     Context->ReportError(FD->getLocation(),
     81                          "Compute kernel %0() is required to return a "
     82                          "void type")
     83         << FD->getName();
     84     valid = false;
     85   }
     86 
     87   // Validate remaining parameter types
     88   // TODO(all): Add support for LOD/face when we have them
     89 
     90   size_t IndexOfFirstIterator = numParams;
     91   valid |= validateIterationParameters(Context, FD, &IndexOfFirstIterator);
     92 
     93   // Validate the non-iterator parameters, which should all be found before the
     94   // first iterator.
     95   for (size_t i = 0; i < IndexOfFirstIterator; i++) {
     96     const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
     97     clang::QualType QT = PVD->getType().getCanonicalType();
     98 
     99     if (!QT->isPointerType()) {
    100       Context->ReportError(PVD->getLocation(),
    101                            "Compute kernel %0() cannot have non-pointer "
    102                            "parameters besides 'x' and 'y'. Parameter '%1' is "
    103                            "of type: '%2'")
    104           << FD->getName() << PVD->getName() << PVD->getType().getAsString();
    105       valid = false;
    106       continue;
    107     }
    108 
    109     // The only non-const pointer should be out.
    110     if (!QT->getPointeeType().isConstQualified()) {
    111       if (mOut == NULL) {
    112         mOut = PVD;
    113       } else {
    114         Context->ReportError(PVD->getLocation(),
    115                              "Compute kernel %0() can only have one non-const "
    116                              "pointer parameter. Parameters '%1' and '%2' are "
    117                              "both non-const.")
    118             << FD->getName() << mOut->getName() << PVD->getName();
    119         valid = false;
    120       }
    121     } else {
    122       if (mIns.empty() && mOut == NULL) {
    123         mIns.push_back(PVD);
    124       } else if (mUsrData == NULL) {
    125         mUsrData = PVD;
    126       } else {
    127         Context->ReportError(
    128             PVD->getLocation(),
    129             "Unexpected parameter '%0' for compute kernel %1()")
    130             << PVD->getName() << FD->getName();
    131         valid = false;
    132       }
    133     }
    134   }
    135 
    136   if (mIns.empty() && !mOut) {
    137     Context->ReportError(FD->getLocation(),
    138                          "Compute kernel %0() must have at least one "
    139                          "parameter for in or out")
    140         << FD->getName();
    141     valid = false;
    142   }
    143 
    144   return valid;
    145 }
    146 
    147 bool RSExportForEach::validateAndConstructKernelParams(
    148     RSContext *Context, const clang::FunctionDecl *FD) {
    149   slangAssert(Context && FD);
    150   bool valid = true;
    151   clang::ASTContext &C = Context->getASTContext();
    152 
    153   if (Context->getTargetAPI() < SLANG_JB_MR1_TARGET_API) {
    154     Context->ReportError(FD->getLocation(),
    155                          "Compute kernel %0() targeting SDK levels "
    156                          "%1-%2 may not use pass-by-value with "
    157                          "__attribute__((kernel))")
    158         << FD->getName() << SLANG_MINIMUM_TARGET_API
    159         << (SLANG_JB_MR1_TARGET_API - 1);
    160     return false;
    161   }
    162 
    163   // Denote that we are indeed a pass-by-value kernel.
    164   mIsKernelStyle = true;
    165   mHasReturnType = (mResultType != C.VoidTy);
    166 
    167   if (mResultType->isPointerType()) {
    168     Context->ReportError(
    169         FD->getTypeSpecStartLoc(),
    170         "Compute kernel %0() cannot return a pointer type: '%1'")
    171         << FD->getName() << mResultType.getAsString();
    172     valid = false;
    173   }
    174 
    175   // Validate remaining parameter types
    176   // TODO(all): Add support for LOD/face when we have them
    177 
    178   size_t IndexOfFirstIterator = numParams;
    179   valid |= validateIterationParameters(Context, FD, &IndexOfFirstIterator);
    180 
    181   // Validate the non-iterator parameters, which should all be found before the
    182   // first iterator.
    183   for (size_t i = 0; i < IndexOfFirstIterator; i++) {
    184     const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
    185 
    186     /*
    187      * FIXME: Change this to a test against an actual API version when the
    188      *        multi-input feature is officially supported.
    189      */
    190     if (Context->getTargetAPI() == SLANG_DEVELOPMENT_TARGET_API || i == 0) {
    191       mIns.push_back(PVD);
    192     } else {
    193       Context->ReportError(PVD->getLocation(),
    194                            "Invalid parameter '%0' for compute kernel %1(). "
    195                            "Kernels targeting SDK levels %2-%3 may not use "
    196                            "multiple input parameters.") << PVD->getName() <<
    197                            FD->getName() << SLANG_MINIMUM_TARGET_API <<
    198                            SLANG_MAXIMUM_TARGET_API;
    199       valid = false;
    200     }
    201     clang::QualType QT = PVD->getType().getCanonicalType();
    202     if (QT->isPointerType()) {
    203       Context->ReportError(PVD->getLocation(),
    204                            "Compute kernel %0() cannot have "
    205                            "parameter '%1' of pointer type: '%2'")
    206           << FD->getName() << PVD->getName() << PVD->getType().getAsString();
    207       valid = false;
    208     }
    209   }
    210 
    211   // Check that we have at least one allocation to use for dimensions.
    212   if (valid && mIns.empty() && !mHasReturnType) {
    213     Context->ReportError(FD->getLocation(),
    214                          "Compute kernel %0() must have at least one "
    215                          "input parameter or a non-void return "
    216                          "type")
    217         << FD->getName();
    218     valid = false;
    219   }
    220 
    221   return valid;
    222 }
    223 
    224 // Search for the optional x and y parameters.  Returns true if valid.   Also
    225 // sets *IndexOfFirstIterator to the index of the first iterator parameter, or
    226 // FD->getNumParams() if none are found.
    227 bool RSExportForEach::validateIterationParameters(
    228     RSContext *Context, const clang::FunctionDecl *FD,
    229     size_t *IndexOfFirstIterator) {
    230   slangAssert(IndexOfFirstIterator != NULL);
    231   slangAssert(mX == NULL && mY == NULL);
    232   clang::ASTContext &C = Context->getASTContext();
    233 
    234   // Find the x and y parameters if present.
    235   size_t NumParams = FD->getNumParams();
    236   *IndexOfFirstIterator = NumParams;
    237   bool valid = true;
    238   for (size_t i = 0; i < NumParams; i++) {
    239     const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
    240     llvm::StringRef ParamName = PVD->getName();
    241     if (ParamName.equals("x")) {
    242       slangAssert(mX == NULL);  // We won't be invoked if two 'x' are present.
    243       mX = PVD;
    244       if (mY != NULL) {
    245         Context->ReportError(PVD->getLocation(),
    246                              "In compute kernel %0(), parameter 'x' should "
    247                              "be defined before parameter 'y'")
    248             << FD->getName();
    249         valid = false;
    250       }
    251     } else if (ParamName.equals("y")) {
    252       slangAssert(mY == NULL);  // We won't be invoked if two 'y' are present.
    253       mY = PVD;
    254     } else {
    255       // It's neither x nor y.
    256       if (*IndexOfFirstIterator < NumParams) {
    257         Context->ReportError(PVD->getLocation(),
    258                              "In compute kernel %0(), parameter '%1' cannot "
    259                              "appear after the 'x' and 'y' parameters")
    260             << FD->getName() << ParamName;
    261         valid = false;
    262       }
    263       continue;
    264     }
    265     // Validate the data type of x and y.
    266     clang::QualType QT = PVD->getType().getCanonicalType();
    267     clang::QualType UT = QT.getUnqualifiedType();
    268     if (UT != C.UnsignedIntTy && UT != C.IntTy) {
    269       Context->ReportError(PVD->getLocation(),
    270                            "Parameter '%0' must be of type 'int' or "
    271                            "'unsigned int'. It is of type '%1'")
    272           << ParamName << PVD->getType().getAsString();
    273       valid = false;
    274     }
    275     // If this is the first time we find an iterator, save it.
    276     if (*IndexOfFirstIterator >= NumParams) {
    277       *IndexOfFirstIterator = i;
    278     }
    279   }
    280   // Check that x and y have the same type.
    281   if (mX != NULL and mY != NULL) {
    282     clang::QualType XType = mX->getType();
    283     clang::QualType YType = mY->getType();
    284 
    285     if (XType != YType) {
    286       Context->ReportError(mY->getLocation(),
    287                            "Parameter 'x' and 'y' must be of the same type. "
    288                            "'x' is of type '%0' while 'y' is of type '%1'")
    289           << XType.getAsString() << YType.getAsString();
    290       valid = false;
    291     }
    292   }
    293   return valid;
    294 }
    295 
    296 bool RSExportForEach::setSignatureMetadata(RSContext *Context,
    297                                            const clang::FunctionDecl *FD) {
    298   mSignatureMetadata = 0;
    299   bool valid = true;
    300 
    301   if (mIsKernelStyle) {
    302     slangAssert(mOut == NULL);
    303     slangAssert(mUsrData == NULL);
    304   } else {
    305     slangAssert(!mHasReturnType);
    306   }
    307 
    308   // Set up the bitwise metadata encoding for runtime argument passing.
    309   // TODO: If this bit field is re-used from C++ code, define the values in a header.
    310   const bool HasOut = mOut || mHasReturnType;
    311   mSignatureMetadata |= (hasIns() ?       0x01 : 0);
    312   mSignatureMetadata |= (HasOut ?         0x02 : 0);
    313   mSignatureMetadata |= (mUsrData ?       0x04 : 0);
    314   mSignatureMetadata |= (mX ?             0x08 : 0);
    315   mSignatureMetadata |= (mY ?             0x10 : 0);
    316   mSignatureMetadata |= (mIsKernelStyle ? 0x20 : 0);  // pass-by-value
    317 
    318   if (Context->getTargetAPI() < SLANG_ICS_TARGET_API) {
    319     // APIs before ICS cannot skip between parameters. It is ok, however, for
    320     // them to omit further parameters (i.e. skipping X is ok if you skip Y).
    321     if (mSignatureMetadata != 0x1f &&  // In, Out, UsrData, X, Y
    322         mSignatureMetadata != 0x0f &&  // In, Out, UsrData, X
    323         mSignatureMetadata != 0x07 &&  // In, Out, UsrData
    324         mSignatureMetadata != 0x03 &&  // In, Out
    325         mSignatureMetadata != 0x01) {  // In
    326       Context->ReportError(FD->getLocation(),
    327                            "Compute kernel %0() targeting SDK levels "
    328                            "%1-%2 may not skip parameters")
    329           << FD->getName() << SLANG_MINIMUM_TARGET_API
    330           << (SLANG_ICS_TARGET_API - 1);
    331       valid = false;
    332     }
    333   }
    334   return valid;
    335 }
    336 
    337 RSExportForEach *RSExportForEach::Create(RSContext *Context,
    338                                          const clang::FunctionDecl *FD) {
    339   slangAssert(Context && FD);
    340   llvm::StringRef Name = FD->getName();
    341   RSExportForEach *FE;
    342 
    343   slangAssert(!Name.empty() && "Function must have a name");
    344 
    345   FE = new RSExportForEach(Context, Name);
    346 
    347   if (!FE->validateAndConstructParams(Context, FD)) {
    348     return NULL;
    349   }
    350 
    351   clang::ASTContext &Ctx = Context->getASTContext();
    352 
    353   std::string Id = CreateDummyName("helper_foreach_param", FE->getName());
    354 
    355   // Extract the usrData parameter (if we have one)
    356   if (FE->mUsrData) {
    357     const clang::ParmVarDecl *PVD = FE->mUsrData;
    358     clang::QualType QT = PVD->getType().getCanonicalType();
    359     slangAssert(QT->isPointerType() &&
    360                 QT->getPointeeType().isConstQualified());
    361 
    362     const clang::ASTContext &C = Context->getASTContext();
    363     if (QT->getPointeeType().getCanonicalType().getUnqualifiedType() ==
    364         C.VoidTy) {
    365       // In the case of using const void*, we can't reflect an appopriate
    366       // Java type, so we fall back to just reflecting the ain/aout parameters
    367       FE->mUsrData = NULL;
    368     } else {
    369       clang::RecordDecl *RD =
    370           clang::RecordDecl::Create(Ctx, clang::TTK_Struct,
    371                                     Ctx.getTranslationUnitDecl(),
    372                                     clang::SourceLocation(),
    373                                     clang::SourceLocation(),
    374                                     &Ctx.Idents.get(Id));
    375 
    376       clang::FieldDecl *FD =
    377           clang::FieldDecl::Create(Ctx,
    378                                    RD,
    379                                    clang::SourceLocation(),
    380                                    clang::SourceLocation(),
    381                                    PVD->getIdentifier(),
    382                                    QT->getPointeeType(),
    383                                    NULL,
    384                                    /* BitWidth = */ NULL,
    385                                    /* Mutable = */ false,
    386                                    /* HasInit = */ clang::ICIS_NoInit);
    387       RD->addDecl(FD);
    388       RD->completeDefinition();
    389 
    390       // Create an export type iff we have a valid usrData type
    391       clang::QualType T = Ctx.getTagDeclType(RD);
    392       slangAssert(!T.isNull());
    393 
    394       RSExportType *ET = RSExportType::Create(Context, T.getTypePtr());
    395 
    396       if (ET == NULL) {
    397         fprintf(stderr, "Failed to export the function %s. There's at least "
    398                         "one parameter whose type is not supported by the "
    399                         "reflection\n", FE->getName().c_str());
    400         return NULL;
    401       }
    402 
    403       slangAssert((ET->getClass() == RSExportType::ExportClassRecord) &&
    404                   "Parameter packet must be a record");
    405 
    406       FE->mParamPacketType = static_cast<RSExportRecordType *>(ET);
    407     }
    408   }
    409 
    410   if (FE->hasIns()) {
    411 
    412     for (InIter BI = FE->mIns.begin(), EI = FE->mIns.end(); BI != EI; BI++) {
    413       const clang::Type *T = (*BI)->getType().getCanonicalType().getTypePtr();
    414       RSExportType *InExportType = RSExportType::Create(Context, T);
    415 
    416       if (FE->mIsKernelStyle) {
    417         slangAssert(InExportType != NULL);
    418       }
    419 
    420       FE->mInTypes.push_back(InExportType);
    421     }
    422   }
    423 
    424   if (FE->mIsKernelStyle && FE->mHasReturnType) {
    425     const clang::Type *T = FE->mResultType.getTypePtr();
    426     FE->mOutType = RSExportType::Create(Context, T);
    427     slangAssert(FE->mOutType);
    428   } else if (FE->mOut) {
    429     const clang::Type *T = FE->mOut->getType().getCanonicalType().getTypePtr();
    430     FE->mOutType = RSExportType::Create(Context, T);
    431   }
    432 
    433   return FE;
    434 }
    435 
    436 RSExportForEach *RSExportForEach::CreateDummyRoot(RSContext *Context) {
    437   slangAssert(Context);
    438   llvm::StringRef Name = "root";
    439   RSExportForEach *FE = new RSExportForEach(Context, Name);
    440   FE->mDummyRoot = true;
    441   return FE;
    442 }
    443 
    444 bool RSExportForEach::isGraphicsRootRSFunc(unsigned int targetAPI,
    445                                            const clang::FunctionDecl *FD) {
    446   if (FD->hasAttr<clang::KernelAttr>()) {
    447     return false;
    448   }
    449 
    450   if (!isRootRSFunc(FD)) {
    451     return false;
    452   }
    453 
    454   if (FD->getNumParams() == 0) {
    455     // Graphics root function
    456     return true;
    457   }
    458 
    459   // Check for legacy graphics root function (with single parameter).
    460   if ((targetAPI < SLANG_ICS_TARGET_API) && (FD->getNumParams() == 1)) {
    461     const clang::QualType &IntType = FD->getASTContext().IntTy;
    462     if (FD->getReturnType().getCanonicalType() == IntType) {
    463       return true;
    464     }
    465   }
    466 
    467   return false;
    468 }
    469 
    470 bool RSExportForEach::isRSForEachFunc(unsigned int targetAPI,
    471                                       slang::RSContext* Context,
    472                                       const clang::FunctionDecl *FD) {
    473   slangAssert(Context && FD);
    474   bool hasKernelAttr = FD->hasAttr<clang::KernelAttr>();
    475 
    476   if (FD->getStorageClass() == clang::SC_Static) {
    477     if (hasKernelAttr) {
    478       Context->ReportError(FD->getLocation(),
    479                            "Invalid use of attribute kernel with "
    480                            "static function declaration: %0")
    481           << FD->getName();
    482     }
    483     return false;
    484   }
    485 
    486   // Anything tagged as a kernel is definitely used with ForEach.
    487   if (hasKernelAttr) {
    488     return true;
    489   }
    490 
    491   if (isGraphicsRootRSFunc(targetAPI, FD)) {
    492     return false;
    493   }
    494 
    495   // Check if first parameter is a pointer (which is required for ForEach).
    496   unsigned int numParams = FD->getNumParams();
    497 
    498   if (numParams > 0) {
    499     const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
    500     clang::QualType QT = PVD->getType().getCanonicalType();
    501 
    502     if (QT->isPointerType()) {
    503       return true;
    504     }
    505 
    506     // Any non-graphics root() is automatically a ForEach candidate.
    507     // At this point, however, we know that it is not going to be a valid
    508     // compute root() function (due to not having a pointer parameter). We
    509     // still want to return true here, so that we can issue appropriate
    510     // diagnostics.
    511     if (isRootRSFunc(FD)) {
    512       return true;
    513     }
    514   }
    515 
    516   return false;
    517 }
    518 
    519 bool
    520 RSExportForEach::validateSpecialFuncDecl(unsigned int targetAPI,
    521                                          slang::RSContext *Context,
    522                                          clang::FunctionDecl const *FD) {
    523   slangAssert(Context && FD);
    524   bool valid = true;
    525   const clang::ASTContext &C = FD->getASTContext();
    526   const clang::QualType &IntType = FD->getASTContext().IntTy;
    527 
    528   if (isGraphicsRootRSFunc(targetAPI, FD)) {
    529     if ((targetAPI < SLANG_ICS_TARGET_API) && (FD->getNumParams() == 1)) {
    530       // Legacy graphics root function
    531       const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
    532       clang::QualType QT = PVD->getType().getCanonicalType();
    533       if (QT != IntType) {
    534         Context->ReportError(PVD->getLocation(),
    535                              "invalid parameter type for legacy "
    536                              "graphics root() function: %0")
    537             << PVD->getType();
    538         valid = false;
    539       }
    540     }
    541 
    542     // Graphics root function, so verify that it returns an int
    543     if (FD->getReturnType().getCanonicalType() != IntType) {
    544       Context->ReportError(FD->getLocation(),
    545                            "root() is required to return "
    546                            "an int for graphics usage");
    547       valid = false;
    548     }
    549   } else if (isInitRSFunc(FD) || isDtorRSFunc(FD)) {
    550     if (FD->getNumParams() != 0) {
    551       Context->ReportError(FD->getLocation(),
    552                            "%0(void) is required to have no "
    553                            "parameters")
    554           << FD->getName();
    555       valid = false;
    556     }
    557 
    558     if (FD->getReturnType().getCanonicalType() != C.VoidTy) {
    559       Context->ReportError(FD->getLocation(),
    560                            "%0(void) is required to have a void "
    561                            "return type")
    562           << FD->getName();
    563       valid = false;
    564     }
    565   } else {
    566     slangAssert(false && "must be called on root, init or .rs.dtor function!");
    567   }
    568 
    569   return valid;
    570 }
    571 
    572 }  // namespace slang
    573