Home | History | Annotate | Download | only in slang
      1 /*
      2  * Copyright 2010, The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *     http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "slang_rs_export_func.h"
     18 
     19 #include <string>
     20 
     21 #include "clang/AST/ASTContext.h"
     22 #include "clang/AST/Decl.h"
     23 
     24 #include "llvm/IR/DataLayout.h"
     25 #include "llvm/IR/DerivedTypes.h"
     26 
     27 #include "slang_assert.h"
     28 #include "slang_rs_context.h"
     29 
     30 namespace slang {
     31 
     32 namespace {
     33 
     34 // Ensure that the exported function is actually valid
     35 static bool ValidateFuncDecl(clang::DiagnosticsEngine *DiagEngine,
     36                              const clang::FunctionDecl *FD) {
     37   slangAssert(DiagEngine && FD);
     38   const clang::ASTContext &C = FD->getASTContext();
     39   if (FD->getResultType().getCanonicalType() != C.VoidTy) {
     40     DiagEngine->Report(
     41       clang::FullSourceLoc(FD->getLocation(), DiagEngine->getSourceManager()),
     42       DiagEngine->getCustomDiagID(clang::DiagnosticsEngine::Error,
     43                                   "invokable non-static functions are "
     44                                   "required to return void"));
     45     return false;
     46   }
     47   return true;
     48 }
     49 
     50 }  // namespace
     51 
     52 RSExportFunc *RSExportFunc::Create(RSContext *Context,
     53                                    const clang::FunctionDecl *FD) {
     54   llvm::StringRef Name = FD->getName();
     55   RSExportFunc *F;
     56 
     57   slangAssert(!Name.empty() && "Function must have a name");
     58 
     59   if (!ValidateFuncDecl(Context->getDiagnostics(), FD)) {
     60     return NULL;
     61   }
     62 
     63   F = new RSExportFunc(Context, Name, FD);
     64 
     65   // Initialize mParamPacketType
     66   if (FD->getNumParams() <= 0) {
     67     F->mParamPacketType = NULL;
     68   } else {
     69     clang::ASTContext &Ctx = Context->getASTContext();
     70 
     71     std::string Id(DUMMY_RS_TYPE_NAME_PREFIX"helper_func_param:");
     72     Id.append(F->getName()).append(DUMMY_RS_TYPE_NAME_POSTFIX);
     73 
     74     clang::RecordDecl *RD =
     75         clang::RecordDecl::Create(Ctx, clang::TTK_Struct,
     76                                   Ctx.getTranslationUnitDecl(),
     77                                   clang::SourceLocation(),
     78                                   clang::SourceLocation(),
     79                                   &Ctx.Idents.get(Id));
     80 
     81     for (unsigned i = 0; i < FD->getNumParams(); i++) {
     82       const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
     83       llvm::StringRef ParamName = PVD->getName();
     84 
     85       if (PVD->hasDefaultArg())
     86         fprintf(stderr, "Note: parameter '%s' in function '%s' has default "
     87                         "value which is not supported\n",
     88                         ParamName.str().c_str(),
     89                         F->getName().c_str());
     90 
     91       clang::FieldDecl *FD =
     92           clang::FieldDecl::Create(Ctx,
     93                                    RD,
     94                                    clang::SourceLocation(),
     95                                    clang::SourceLocation(),
     96                                    PVD->getIdentifier(),
     97                                    PVD->getOriginalType(),
     98                                    NULL,
     99                                    /* BitWidth = */ NULL,
    100                                    /* Mutable = */ false,
    101                                    /* HasInit = */ clang::ICIS_NoInit);
    102       RD->addDecl(FD);
    103     }
    104 
    105     RD->completeDefinition();
    106 
    107     clang::QualType T = Ctx.getTagDeclType(RD);
    108     slangAssert(!T.isNull());
    109 
    110     RSExportType *ET =
    111       RSExportType::Create(Context, T.getTypePtr());
    112 
    113     if (ET == NULL) {
    114       fprintf(stderr, "Failed to export the function %s. There's at least one "
    115                       "parameter whose type is not supported by the "
    116                       "reflection\n", F->getName().c_str());
    117       return NULL;
    118     }
    119 
    120     slangAssert((ET->getClass() == RSExportType::ExportClassRecord) &&
    121            "Parameter packet must be a record");
    122 
    123     F->mParamPacketType = static_cast<RSExportRecordType *>(ET);
    124   }
    125 
    126   return F;
    127 }
    128 
    129 bool
    130 RSExportFunc::checkParameterPacketType(llvm::StructType *ParamTy) const {
    131   if (ParamTy == NULL)
    132     return !hasParam();
    133   else if (!hasParam())
    134     return false;
    135 
    136   slangAssert(mParamPacketType != NULL);
    137 
    138   const RSExportRecordType *ERT = mParamPacketType;
    139   // must have same number of elements
    140   if (ERT->getFields().size() != ParamTy->getNumElements())
    141     return false;
    142 
    143   const llvm::StructLayout *ParamTySL =
    144       getRSContext()->getDataLayout()->getStructLayout(ParamTy);
    145 
    146   unsigned Index = 0;
    147   for (RSExportRecordType::const_field_iterator FI = ERT->fields_begin(),
    148        FE = ERT->fields_end(); FI != FE; FI++, Index++) {
    149     const RSExportRecordType::Field *F = *FI;
    150 
    151     llvm::Type *T1 = F->getType()->getLLVMType();
    152     llvm::Type *T2 = ParamTy->getTypeAtIndex(Index);
    153 
    154     // Fast check
    155     if (T1 == T2)
    156       continue;
    157 
    158     // Check offset
    159     size_t T1Offset = F->getOffsetInParent();
    160     size_t T2Offset = ParamTySL->getElementOffset(Index);
    161 
    162     if (T1Offset != T2Offset)
    163       return false;
    164 
    165     // Check size
    166     size_t T1Size = RSExportType::GetTypeAllocSize(F->getType());
    167     size_t T2Size = getRSContext()->getDataLayout()->getTypeAllocSize(T2);
    168 
    169     if (T1Size != T2Size)
    170       return false;
    171   }
    172 
    173   return true;
    174 }
    175 
    176 }  // namespace slang
    177