Home | History | Annotate | Download | only in slang
      1 /*
      2  * Copyright 2011-2012, The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *     http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "slang_rs_export_foreach.h"
     18 
     19 #include <string>
     20 
     21 #include "clang/AST/ASTContext.h"
     22 #include "clang/AST/Decl.h"
     23 #include "clang/AST/TypeLoc.h"
     24 
     25 #include "llvm/DerivedTypes.h"
     26 #include "llvm/Target/TargetData.h"
     27 
     28 #include "slang_assert.h"
     29 #include "slang_rs_context.h"
     30 #include "slang_rs_export_type.h"
     31 #include "slang_version.h"
     32 
     33 namespace slang {
     34 
     35 namespace {
     36 
     37 static void ReportNameError(clang::DiagnosticsEngine *DiagEngine,
     38                             clang::ParmVarDecl const *PVD) {
     39   slangAssert(DiagEngine && PVD);
     40   const clang::SourceManager &SM = DiagEngine->getSourceManager();
     41 
     42   DiagEngine->Report(
     43     clang::FullSourceLoc(PVD->getLocation(), SM),
     44     DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
     45                                 "Duplicate parameter entry "
     46                                 "(by position/name): '%0'"))
     47     << PVD->getName();
     48   return;
     49 }
     50 
     51 }  // namespace
     52 
     53 // This function takes care of additional validation and construction of
     54 // parameters related to forEach_* reflection.
     55 bool RSExportForEach::validateAndConstructParams(
     56     RSContext *Context, const clang::FunctionDecl *FD) {
     57   slangAssert(Context && FD);
     58   bool valid = true;
     59   clang::ASTContext &C = Context->getASTContext();
     60   clang::DiagnosticsEngine *DiagEngine = Context->getDiagnostics();
     61 
     62   numParams = FD->getNumParams();
     63   slangAssert(numParams > 0);
     64 
     65   if (Context->getTargetAPI() < SLANG_JB_TARGET_API) {
     66     if (!isRootRSFunc(FD)) {
     67       DiagEngine->Report(
     68         clang::FullSourceLoc(FD->getLocation(), DiagEngine->getSourceManager()),
     69         DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
     70                                     "Non-root compute kernel %0() is "
     71                                     "not supported in SDK levels %1-%2"))
     72         << FD->getName()
     73         << SLANG_MINIMUM_TARGET_API
     74         << (SLANG_JB_TARGET_API - 1);
     75       return false;
     76     }
     77   }
     78 
     79   // Compute kernel functions are required to return a void type for now
     80   if (FD->getResultType().getCanonicalType() != C.VoidTy) {
     81     DiagEngine->Report(
     82       clang::FullSourceLoc(FD->getLocation(), DiagEngine->getSourceManager()),
     83       DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
     84                                   "Compute kernel %0() is required to return a "
     85                                   "void type")) << FD->getName();
     86     valid = false;
     87   }
     88 
     89   // Validate remaining parameter types
     90   // TODO(all): Add support for LOD/face when we have them
     91 
     92   size_t i = 0;
     93   const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
     94   clang::QualType QT = PVD->getType().getCanonicalType();
     95 
     96   // Check for const T1 *in
     97   if (QT->isPointerType() && QT->getPointeeType().isConstQualified()) {
     98     mIn = PVD;
     99     i++;  // advance parameter pointer
    100   }
    101 
    102   // Check for T2 *out
    103   if (i < numParams) {
    104     PVD = FD->getParamDecl(i);
    105     QT = PVD->getType().getCanonicalType();
    106     if (QT->isPointerType() && !QT->getPointeeType().isConstQualified()) {
    107       mOut = PVD;
    108       i++;  // advance parameter pointer
    109     }
    110   }
    111 
    112   if (!mIn && !mOut) {
    113     DiagEngine->Report(
    114       clang::FullSourceLoc(FD->getLocation(),
    115                            DiagEngine->getSourceManager()),
    116       DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
    117                                   "Compute kernel %0() must have at least one "
    118                                   "parameter for in or out")) << FD->getName();
    119     valid = false;
    120   }
    121 
    122   // Check for T3 *usrData
    123   if (i < numParams) {
    124     PVD = FD->getParamDecl(i);
    125     QT = PVD->getType().getCanonicalType();
    126     if (QT->isPointerType() && QT->getPointeeType().isConstQualified()) {
    127       mUsrData = PVD;
    128       i++;  // advance parameter pointer
    129     }
    130   }
    131 
    132   while (i < numParams) {
    133     PVD = FD->getParamDecl(i);
    134     QT = PVD->getType().getCanonicalType();
    135 
    136     if (QT.getUnqualifiedType() != C.UnsignedIntTy) {
    137       DiagEngine->Report(
    138         clang::FullSourceLoc(PVD->getLocation(),
    139                              DiagEngine->getSourceManager()),
    140         DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
    141                                     "Unexpected kernel %0() parameter '%1' "
    142                                     "of type '%2'"))
    143         << FD->getName() << PVD->getName() << PVD->getType().getAsString();
    144       valid = false;
    145     } else {
    146       llvm::StringRef ParamName = PVD->getName();
    147       if (ParamName.equals("x")) {
    148         if (mX) {
    149           ReportNameError(DiagEngine, PVD);
    150           valid = false;
    151         } else if (mY) {
    152           // Can't go back to X after skipping Y
    153           ReportNameError(DiagEngine, PVD);
    154           valid = false;
    155         } else {
    156           mX = PVD;
    157         }
    158       } else if (ParamName.equals("y")) {
    159         if (mY) {
    160           ReportNameError(DiagEngine, PVD);
    161           valid = false;
    162         } else {
    163           mY = PVD;
    164         }
    165       } else {
    166         if (!mX && !mY) {
    167           mX = PVD;
    168         } else if (!mY) {
    169           mY = PVD;
    170         } else {
    171           DiagEngine->Report(
    172             clang::FullSourceLoc(PVD->getLocation(),
    173                                  DiagEngine->getSourceManager()),
    174             DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
    175                                         "Unexpected kernel %0() parameter '%1' "
    176                                         "of type '%2'"))
    177             << FD->getName() << PVD->getName() << PVD->getType().getAsString();
    178           valid = false;
    179         }
    180       }
    181     }
    182 
    183     i++;
    184   }
    185 
    186   mSignatureMetadata = 0;
    187   if (valid) {
    188     // Set up the bitwise metadata encoding for runtime argument passing.
    189     mSignatureMetadata |= (mIn ?       0x01 : 0);
    190     mSignatureMetadata |= (mOut ?      0x02 : 0);
    191     mSignatureMetadata |= (mUsrData ?  0x04 : 0);
    192     mSignatureMetadata |= (mX ?        0x08 : 0);
    193     mSignatureMetadata |= (mY ?        0x10 : 0);
    194   }
    195 
    196   if (Context->getTargetAPI() < SLANG_ICS_TARGET_API) {
    197     // APIs before ICS cannot skip between parameters. It is ok, however, for
    198     // them to omit further parameters (i.e. skipping X is ok if you skip Y).
    199     if (mSignatureMetadata != 0x1f &&  // In, Out, UsrData, X, Y
    200         mSignatureMetadata != 0x0f &&  // In, Out, UsrData, X
    201         mSignatureMetadata != 0x07 &&  // In, Out, UsrData
    202         mSignatureMetadata != 0x03 &&  // In, Out
    203         mSignatureMetadata != 0x01) {  // In
    204       DiagEngine->Report(
    205         clang::FullSourceLoc(FD->getLocation(),
    206                              DiagEngine->getSourceManager()),
    207         DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
    208                                     "Compute kernel %0() targeting SDK levels "
    209                                     "%1-%2 may not skip parameters"))
    210         << FD->getName() << SLANG_MINIMUM_TARGET_API
    211         << (SLANG_ICS_TARGET_API - 1);
    212       valid = false;
    213     }
    214   }
    215 
    216   return valid;
    217 }
    218 
    219 RSExportForEach *RSExportForEach::Create(RSContext *Context,
    220                                          const clang::FunctionDecl *FD) {
    221   slangAssert(Context && FD);
    222   llvm::StringRef Name = FD->getName();
    223   RSExportForEach *FE;
    224 
    225   slangAssert(!Name.empty() && "Function must have a name");
    226 
    227   FE = new RSExportForEach(Context, Name);
    228 
    229   if (!FE->validateAndConstructParams(Context, FD)) {
    230     return NULL;
    231   }
    232 
    233   clang::ASTContext &Ctx = Context->getASTContext();
    234 
    235   std::string Id(DUMMY_RS_TYPE_NAME_PREFIX"helper_foreach_param:");
    236   Id.append(FE->getName()).append(DUMMY_RS_TYPE_NAME_POSTFIX);
    237 
    238   // Extract the usrData parameter (if we have one)
    239   if (FE->mUsrData) {
    240     const clang::ParmVarDecl *PVD = FE->mUsrData;
    241     clang::QualType QT = PVD->getType().getCanonicalType();
    242     slangAssert(QT->isPointerType() &&
    243                 QT->getPointeeType().isConstQualified());
    244 
    245     const clang::ASTContext &C = Context->getASTContext();
    246     if (QT->getPointeeType().getCanonicalType().getUnqualifiedType() ==
    247         C.VoidTy) {
    248       // In the case of using const void*, we can't reflect an appopriate
    249       // Java type, so we fall back to just reflecting the ain/aout parameters
    250       FE->mUsrData = NULL;
    251     } else {
    252       clang::RecordDecl *RD =
    253           clang::RecordDecl::Create(Ctx, clang::TTK_Struct,
    254                                     Ctx.getTranslationUnitDecl(),
    255                                     clang::SourceLocation(),
    256                                     clang::SourceLocation(),
    257                                     &Ctx.Idents.get(Id));
    258 
    259       clang::FieldDecl *FD =
    260           clang::FieldDecl::Create(Ctx,
    261                                    RD,
    262                                    clang::SourceLocation(),
    263                                    clang::SourceLocation(),
    264                                    PVD->getIdentifier(),
    265                                    QT->getPointeeType(),
    266                                    NULL,
    267                                    /* BitWidth = */ NULL,
    268                                    /* Mutable = */ false,
    269                                    /* HasInit = */ false);
    270       RD->addDecl(FD);
    271       RD->completeDefinition();
    272 
    273       // Create an export type iff we have a valid usrData type
    274       clang::QualType T = Ctx.getTagDeclType(RD);
    275       slangAssert(!T.isNull());
    276 
    277       RSExportType *ET = RSExportType::Create(Context, T.getTypePtr());
    278 
    279       if (ET == NULL) {
    280         fprintf(stderr, "Failed to export the function %s. There's at least "
    281                         "one parameter whose type is not supported by the "
    282                         "reflection\n", FE->getName().c_str());
    283         return NULL;
    284       }
    285 
    286       slangAssert((ET->getClass() == RSExportType::ExportClassRecord) &&
    287                   "Parameter packet must be a record");
    288 
    289       FE->mParamPacketType = static_cast<RSExportRecordType *>(ET);
    290     }
    291   }
    292 
    293   if (FE->mIn) {
    294     const clang::Type *T = FE->mIn->getType().getCanonicalType().getTypePtr();
    295     FE->mInType = RSExportType::Create(Context, T);
    296   }
    297 
    298   if (FE->mOut) {
    299     const clang::Type *T = FE->mOut->getType().getCanonicalType().getTypePtr();
    300     FE->mOutType = RSExportType::Create(Context, T);
    301   }
    302 
    303   return FE;
    304 }
    305 
    306 RSExportForEach *RSExportForEach::CreateDummyRoot(RSContext *Context) {
    307   slangAssert(Context);
    308   llvm::StringRef Name = "root";
    309   RSExportForEach *FE = new RSExportForEach(Context, Name);
    310   FE->mDummyRoot = true;
    311   return FE;
    312 }
    313 
    314 bool RSExportForEach::isGraphicsRootRSFunc(int targetAPI,
    315                                            const clang::FunctionDecl *FD) {
    316   if (!isRootRSFunc(FD)) {
    317     return false;
    318   }
    319 
    320   if (FD->getNumParams() == 0) {
    321     // Graphics root function
    322     return true;
    323   }
    324 
    325   // Check for legacy graphics root function (with single parameter).
    326   if ((targetAPI < SLANG_ICS_TARGET_API) && (FD->getNumParams() == 1)) {
    327     const clang::QualType &IntType = FD->getASTContext().IntTy;
    328     if (FD->getResultType().getCanonicalType() == IntType) {
    329       return true;
    330     }
    331   }
    332 
    333   return false;
    334 }
    335 
    336 bool RSExportForEach::isRSForEachFunc(int targetAPI,
    337     const clang::FunctionDecl *FD) {
    338   if (isGraphicsRootRSFunc(targetAPI, FD)) {
    339     return false;
    340   }
    341 
    342   // Check if first parameter is a pointer (which is required for ForEach).
    343   unsigned int numParams = FD->getNumParams();
    344 
    345   if (numParams > 0) {
    346     const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
    347     clang::QualType QT = PVD->getType().getCanonicalType();
    348 
    349     if (QT->isPointerType()) {
    350       return true;
    351     }
    352 
    353     // Any non-graphics root() is automatically a ForEach candidate.
    354     // At this point, however, we know that it is not going to be a valid
    355     // compute root() function (due to not having a pointer parameter). We
    356     // still want to return true here, so that we can issue appropriate
    357     // diagnostics.
    358     if (isRootRSFunc(FD)) {
    359       return true;
    360     }
    361   }
    362 
    363   return false;
    364 }
    365 
    366 bool
    367 RSExportForEach::validateSpecialFuncDecl(int targetAPI,
    368                                          clang::DiagnosticsEngine *DiagEngine,
    369                                          clang::FunctionDecl const *FD) {
    370   slangAssert(DiagEngine && FD);
    371   bool valid = true;
    372   const clang::ASTContext &C = FD->getASTContext();
    373   const clang::QualType &IntType = FD->getASTContext().IntTy;
    374 
    375   if (isGraphicsRootRSFunc(targetAPI, FD)) {
    376     if ((targetAPI < SLANG_ICS_TARGET_API) && (FD->getNumParams() == 1)) {
    377       // Legacy graphics root function
    378       const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
    379       clang::QualType QT = PVD->getType().getCanonicalType();
    380       if (QT != IntType) {
    381         DiagEngine->Report(
    382           clang::FullSourceLoc(PVD->getLocation(),
    383                                DiagEngine->getSourceManager()),
    384           DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
    385                                       "invalid parameter type for legacy "
    386                                       "graphics root() function: %0"))
    387           << PVD->getType();
    388         valid = false;
    389       }
    390     }
    391 
    392     // Graphics root function, so verify that it returns an int
    393     if (FD->getResultType().getCanonicalType() != IntType) {
    394       DiagEngine->Report(
    395         clang::FullSourceLoc(FD->getLocation(),
    396                              DiagEngine->getSourceManager()),
    397         DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
    398                                     "root() is required to return "
    399                                     "an int for graphics usage"));
    400       valid = false;
    401     }
    402   } else if (isInitRSFunc(FD) || isDtorRSFunc(FD)) {
    403     if (FD->getNumParams() != 0) {
    404       DiagEngine->Report(
    405           clang::FullSourceLoc(FD->getLocation(),
    406                                DiagEngine->getSourceManager()),
    407           DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
    408                                       "%0(void) is required to have no "
    409                                       "parameters")) << FD->getName();
    410       valid = false;
    411     }
    412 
    413     if (FD->getResultType().getCanonicalType() != C.VoidTy) {
    414       DiagEngine->Report(
    415           clang::FullSourceLoc(FD->getLocation(),
    416                                DiagEngine->getSourceManager()),
    417           DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
    418                                       "%0(void) is required to have a void "
    419                                       "return type")) << FD->getName();
    420       valid = false;
    421     }
    422   } else {
    423     slangAssert(false && "must be called on root, init or .rs.dtor function!");
    424   }
    425 
    426   return valid;
    427 }
    428 
    429 }  // namespace slang
    430