Home | History | Annotate | Download | only in slang
      1 /*
      2  * Copyright 2010, The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *     http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "slang_rs_export_func.h"
     18 
     19 #include <string>
     20 
     21 #include "clang/AST/ASTContext.h"
     22 #include "clang/AST/Decl.h"
     23 
     24 #include "llvm/IR/DataLayout.h"
     25 #include "llvm/IR/DerivedTypes.h"
     26 
     27 #include "slang_assert.h"
     28 #include "slang_rs_context.h"
     29 
     30 namespace slang {
     31 
     32 namespace {
     33 
     34 // Ensure that the exported function is actually valid
     35 static bool ValidateFuncDecl(slang::RSContext *Context,
     36                              const clang::FunctionDecl *FD) {
     37   slangAssert(Context && FD);
     38   const clang::ASTContext &C = FD->getASTContext();
     39   if (FD->getReturnType().getCanonicalType() != C.VoidTy) {
     40     Context->ReportError(
     41         FD->getLocation(),
     42         "invokable non-static functions are required to return void");
     43     return false;
     44   }
     45   return true;
     46 }
     47 
     48 }  // namespace
     49 
     50 RSExportFunc *RSExportFunc::Create(RSContext *Context,
     51                                    const clang::FunctionDecl *FD) {
     52   llvm::StringRef Name = FD->getName();
     53   RSExportFunc *F;
     54 
     55   slangAssert(!Name.empty() && "Function must have a name");
     56 
     57   if (!ValidateFuncDecl(Context, FD)) {
     58     return nullptr;
     59   }
     60 
     61   F = new RSExportFunc(Context, Name, FD);
     62 
     63   // Initialize mParamPacketType
     64   if (FD->getNumParams() <= 0) {
     65     F->mParamPacketType = nullptr;
     66   } else {
     67     clang::ASTContext &Ctx = Context->getASTContext();
     68 
     69     std::string Id = CreateDummyName("helper_func_param", F->getName());
     70 
     71     clang::RecordDecl *RD =
     72         clang::RecordDecl::Create(Ctx, clang::TTK_Struct,
     73                                   Ctx.getTranslationUnitDecl(),
     74                                   clang::SourceLocation(),
     75                                   clang::SourceLocation(),
     76                                   &Ctx.Idents.get(Id));
     77 
     78     for (unsigned i = 0; i < FD->getNumParams(); i++) {
     79       const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
     80       llvm::StringRef ParamName = PVD->getName();
     81 
     82       if (PVD->hasDefaultArg())
     83         fprintf(stderr, "Note: parameter '%s' in function '%s' has default "
     84                         "value which is not supported\n",
     85                         ParamName.str().c_str(),
     86                         F->getName().c_str());
     87 
     88       clang::FieldDecl *FD =
     89           clang::FieldDecl::Create(Ctx,
     90                                    RD,
     91                                    clang::SourceLocation(),
     92                                    clang::SourceLocation(),
     93                                    PVD->getIdentifier(),
     94                                    PVD->getOriginalType(),
     95                                    nullptr,
     96                                    /* BitWidth = */ nullptr,
     97                                    /* Mutable = */ false,
     98                                    /* HasInit = */ clang::ICIS_NoInit);
     99       RD->addDecl(FD);
    100     }
    101 
    102     RD->completeDefinition();
    103 
    104     clang::QualType T = Ctx.getTagDeclType(RD);
    105     slangAssert(!T.isNull());
    106 
    107     RSExportType *ET =
    108       RSExportType::Create(Context, T.getTypePtr(), NotLegacyKernelArgument);
    109 
    110     if (ET == nullptr) {
    111       fprintf(stderr, "Failed to export the function %s. There's at least one "
    112                       "parameter whose type is not supported by the "
    113                       "reflection\n", F->getName().c_str());
    114       return nullptr;
    115     }
    116 
    117     slangAssert((ET->getClass() == RSExportType::ExportClassRecord) &&
    118            "Parameter packet must be a record");
    119 
    120     F->mParamPacketType = static_cast<RSExportRecordType *>(ET);
    121   }
    122 
    123   return F;
    124 }
    125 
    126 bool
    127 RSExportFunc::checkParameterPacketType(llvm::StructType *ParamTy) const {
    128   if (ParamTy == nullptr)
    129     return !hasParam();
    130   else if (!hasParam())
    131     return false;
    132 
    133   slangAssert(mParamPacketType != nullptr);
    134 
    135   const RSExportRecordType *ERT = mParamPacketType;
    136   // must have same number of elements
    137   if (ERT->getFields().size() != ParamTy->getNumElements())
    138     return false;
    139 
    140   const llvm::StructLayout *ParamTySL =
    141       getRSContext()->getDataLayout().getStructLayout(ParamTy);
    142 
    143   unsigned Index = 0;
    144   for (RSExportRecordType::const_field_iterator FI = ERT->fields_begin(),
    145        FE = ERT->fields_end(); FI != FE; FI++, Index++) {
    146     const RSExportRecordType::Field *F = *FI;
    147 
    148     llvm::Type *T1 = F->getType()->getLLVMType();
    149     llvm::Type *T2 = ParamTy->getTypeAtIndex(Index);
    150 
    151     // Fast check
    152     if (T1 == T2)
    153       continue;
    154 
    155     // Check offset
    156     size_t T1Offset = F->getOffsetInParent();
    157     size_t T2Offset = ParamTySL->getElementOffset(Index);
    158 
    159     if (T1Offset != T2Offset)
    160       return false;
    161 
    162     // Check size
    163     size_t T1Size = F->getType()->getAllocSize();
    164     size_t T2Size = getRSContext()->getDataLayout().getTypeAllocSize(T2);
    165 
    166     if (T1Size != T2Size)
    167       return false;
    168   }
    169 
    170   return true;
    171 }
    172 
    173 }  // namespace slang
    174