Home | History | Annotate | Download | only in slang
      1 /*
      2  * Copyright 2010, The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *     http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "slang_rs_metadata_spec.h"
     18 
     19 #include <cstdlib>
     20 #include <list>
     21 #include <map>
     22 #include <string>
     23 
     24 #include "llvm/ADT/SmallVector.h"
     25 #include "llvm/ADT/StringRef.h"
     26 
     27 #include "llvm/Metadata.h"
     28 #include "llvm/Module.h"
     29 
     30 #include "slang_assert.h"
     31 #include "slang_rs_type_spec.h"
     32 
// Names of the named metadata nodes written by the encoder below.
// String table + string index table (written by flushStringTable()).
#define RS_METADATA_STRTAB_MN   "#rs_metadata_strtab"
// Flat RS type stream (written by flushTypeInfo()).
#define RS_TYPE_INFO_MN         "#rs_type_info"
// One operand per exported variable (written by encodeRSVar()).
#define RS_EXPORT_VAR_MN        "#rs_export_var"
// One operand per exported function (written by encodeRSFunc()).
#define RS_EXPORT_FUNC_MN       "#rs_export_func"
// Prefix prepended to a record's name to form both its lookup key and the
// name of the named metadata node holding its field list.
#define RS_EXPORT_RECORD_TYPE_NAME_MN_PREFIX  "%"
     38 
     39 ///////////////////////////////////////////////////////////////////////////////
     40 // Useful utility functions
     41 ///////////////////////////////////////////////////////////////////////////////
     42 static bool EncodeInteger(llvm::LLVMContext &C,
     43                           unsigned I,
     44                           llvm::SmallVectorImpl<llvm::Value*> &Op) {
     45   llvm::StringRef S(reinterpret_cast<const char*>(&I), sizeof(I));
     46   llvm::MDString *MDS = llvm::MDString::get(C, S);
     47 
     48   if (MDS == NULL)
     49     return false;
     50   Op.push_back(MDS);
     51   return true;
     52 }
     53 
     54 ///////////////////////////////////////////////////////////////////////////////
     55 // class RSMetadataEncoderInternal
     56 ///////////////////////////////////////////////////////////////////////////////
     57 namespace {
     58 
     59 class RSMetadataEncoderInternal {
     60  private:
     61   llvm::Module *mModule;
     62 
     63   typedef std::map</* key */unsigned, unsigned/* index */> TypesMapTy;
     64   TypesMapTy mTypes;
     65   std::list<unsigned> mEncodedRSTypeInfo;  // simply a sequece of integers
     66   unsigned mCurTypeIndex;
     67 
     68   // A special type for lookup created record type. It uses record name as key.
     69   typedef std::map</* name */std::string, unsigned/* index */> RecordTypesMapTy;
     70   RecordTypesMapTy mRecordTypes;
     71 
     72   typedef std::map<std::string, unsigned/* index */> StringsMapTy;
     73   StringsMapTy mStrings;
     74   std::list<const char*> mEncodedStrings;
     75   unsigned mCurStringIndex;
     76 
     77   llvm::NamedMDNode *mVarInfoMetadata;
     78   llvm::NamedMDNode *mFuncInfoMetadata;
     79 
     80   // This function check the return value of function:
     81   //   joinString, encodeTypeBase, encode*Type(), encodeRSType, encodeRSVar,
     82   //   and encodeRSFunc. Return false if the value of Index indicates failure.
     83   inline bool checkReturnIndex(unsigned *Index) {
     84     if (*Index == 0)
     85       return false;
     86     else
     87       (*Index)--;
     88     return true;
     89   }
     90 
     91   unsigned joinString(const std::string &S);
     92 
     93   unsigned encodeTypeBase(const struct RSTypeBase *Base);
     94   unsigned encodeTypeBaseAsKey(const struct RSTypeBase *Base);
     95 #define ENUM_RS_DATA_TYPE_CLASS(x)  \
     96   unsigned encode ## x ## Type(const union RSType *T);
     97 RS_DATA_TYPE_CLASS_ENUMS
     98 #undef ENUM_RS_DATA_TYPE_CLASS
     99 
    100   unsigned encodeRSType(const union RSType *T);
    101 
    102   int flushStringTable();
    103   int flushTypeInfo();
    104 
    105  public:
    106   explicit RSMetadataEncoderInternal(llvm::Module *M);
    107 
    108   int encodeRSVar(const RSVar *V);
    109   int encodeRSFunc(const RSFunction *F);
    110 
    111   int finalize();
    112 };
    113 
    114 }  // namespace
    115 
    116 RSMetadataEncoderInternal::RSMetadataEncoderInternal(llvm::Module *M)
    117     : mModule(M),
    118       mCurTypeIndex(0),
    119       mCurStringIndex(0),
    120       mVarInfoMetadata(NULL),
    121       mFuncInfoMetadata(NULL) {
    122   mTypes.clear();
    123   mEncodedRSTypeInfo.clear();
    124   mRecordTypes.clear();
    125   mStrings.clear();
    126 
    127   return;
    128 }
    129 
    130 // Return (StringIndex + 1) when successfully join the string and 0 if there's
    131 // any error.
    132 unsigned RSMetadataEncoderInternal::joinString(const std::string &S) {
    133   StringsMapTy::const_iterator I = mStrings.find(S);
    134 
    135   if (I != mStrings.end())
    136     return (I->second + 1);
    137 
    138   // Add S into mStrings
    139   std::pair<StringsMapTy::iterator, bool> Res =
    140       mStrings.insert(std::make_pair(S, mCurStringIndex));
    141   // Insertion failed
    142   if (!Res.second)
    143     return 0;
    144 
    145   // Add S into mEncodedStrings
    146   mEncodedStrings.push_back(Res.first->first.c_str());
    147   mCurStringIndex++;
    148 
    149   // Return (StringIndex + 1)
    150   return (Res.first->second + 1);
    151 }
    152 
    153 unsigned
    154 RSMetadataEncoderInternal::encodeTypeBase(const struct RSTypeBase *Base) {
    155   mEncodedRSTypeInfo.push_back(Base->bits);
    156   return ++mCurTypeIndex;
    157 }
    158 
    159 unsigned RSMetadataEncoderInternal::encodeTypeBaseAsKey(
    160     const struct RSTypeBase *Base) {
    161   TypesMapTy::const_iterator I = mTypes.find(Base->bits);
    162   if (I != mTypes.end())
    163     return (I->second + 1);
    164 
    165   // Add Base into mTypes
    166   std::pair<TypesMapTy::iterator, bool> Res =
    167       mTypes.insert(std::make_pair(Base->bits, mCurTypeIndex));
    168   // Insertion failed
    169   if (!Res.second)
    170     return 0;
    171 
    172   // Push to mEncodedRSTypeInfo. This will also update mCurTypeIndex.
    173   return encodeTypeBase(Base);
    174 }
    175 
    176 unsigned RSMetadataEncoderInternal::encodePrimitiveType(const union RSType *T) {
    177   return encodeTypeBaseAsKey(RS_GET_TYPE_BASE(T));
    178 }
    179 
    180 unsigned RSMetadataEncoderInternal::encodePointerType(const union RSType *T) {
    181   // Encode pointee type first
    182   unsigned PointeeType = encodeRSType(RS_POINTER_TYPE_GET_POINTEE_TYPE(T));
    183   if (!checkReturnIndex(&PointeeType))
    184     return 0;
    185 
    186   unsigned Res = encodeTypeBaseAsKey(RS_GET_TYPE_BASE(T));
    187   // Push PointeeType after the base type
    188   mEncodedRSTypeInfo.push_back(PointeeType);
    189   return Res;
    190 }
    191 
    192 unsigned RSMetadataEncoderInternal::encodeVectorType(const union RSType *T) {
    193   return encodeTypeBaseAsKey(RS_GET_TYPE_BASE(T));
    194 }
    195 
    196 unsigned RSMetadataEncoderInternal::encodeMatrixType(const union RSType *T) {
    197   return encodeTypeBaseAsKey(RS_GET_TYPE_BASE(T));
    198 }
    199 
    200 unsigned
    201 RSMetadataEncoderInternal::encodeConstantArrayType(const union RSType *T) {
    202   // Encode element type
    203   unsigned ElementType =
    204       encodeRSType(RS_CONSTANT_ARRAY_TYPE_GET_ELEMENT_TYPE(T));
    205   if (!checkReturnIndex(&ElementType))
    206     return 0;
    207 
    208   unsigned Res = encodeTypeBase(RS_GET_TYPE_BASE(T));
    209   // Push the ElementType after the type base
    210   mEncodedRSTypeInfo.push_back(ElementType);
    211   return Res;
    212 }
    213 
    214 unsigned RSMetadataEncoderInternal::encodeRecordType(const union RSType *T) {
    215   // Construct record name
    216   std::string RecordInfoMetadataName(RS_EXPORT_RECORD_TYPE_NAME_MN_PREFIX);
    217   RecordInfoMetadataName.append(RS_RECORD_TYPE_GET_NAME(T));
    218 
    219   // Try to find it in mRecordTypes
    220   RecordTypesMapTy::const_iterator I =
    221       mRecordTypes.find(RecordInfoMetadataName);
    222 
    223   // This record type has been encoded before. Fast return its index here.
    224   if (I != mRecordTypes.end())
    225     return (I->second + 1);
    226 
    227   // Encode this record type into mTypes. Encode record name string first.
    228   unsigned RecordName = joinString(RecordInfoMetadataName);
    229   if (!checkReturnIndex(&RecordName))
    230     return 0;
    231 
    232   unsigned Base = encodeTypeBase(RS_GET_TYPE_BASE(T));
    233   if (!checkReturnIndex(&Base))
    234     return 0;
    235 
    236   // Push record name after encoding the type base
    237   mEncodedRSTypeInfo.push_back(RecordName);
    238 
    239   // Add this record type into the map
    240   std::pair<StringsMapTy::iterator, bool> Res =
    241       mRecordTypes.insert(std::make_pair(RecordInfoMetadataName, Base));
    242   // Insertion failed
    243   if (!Res.second)
    244     return 0;
    245 
    246   // Create a named MDNode for this record type. We cannot create this before
    247   // encoding type base into Types and updating mRecordTypes. This is because
    248   // we may have structure like:
    249   //
    250   //            struct foo {
    251   //              ...
    252   //              struct foo *bar;  // self type reference
    253   //              ...
    254   //            }
    255   llvm::NamedMDNode *RecordInfoMetadata =
    256       mModule->getOrInsertNamedMetadata(RecordInfoMetadataName);
    257 
    258   slangAssert((RecordInfoMetadata->getNumOperands() == 0) &&
    259               "Record created before!");
    260 
    261   // Encode field info into this named MDNode
    262   llvm::SmallVector<llvm::Value*, 3> FieldInfo;
    263 
    264   for (unsigned i = 0; i < RS_RECORD_TYPE_GET_NUM_FIELDS(T); i++) {
    265     // 1. field name
    266     unsigned FieldName = joinString(RS_RECORD_TYPE_GET_FIELD_NAME(T, i));
    267     if (!checkReturnIndex(&FieldName))
    268       return 0;
    269     if (!EncodeInteger(mModule->getContext(),
    270                        FieldName,
    271                        FieldInfo)) {
    272       return 0;
    273     }
    274 
    275     // 2. field type
    276     unsigned FieldType = encodeRSType(RS_RECORD_TYPE_GET_FIELD_TYPE(T, i));
    277     if (!checkReturnIndex(&FieldType))
    278       return 0;
    279     if (!EncodeInteger(mModule->getContext(),
    280                        FieldType,
    281                        FieldInfo)) {
    282       return 0;
    283     }
    284 
    285     RecordInfoMetadata->addOperand(llvm::MDNode::get(mModule->getContext(),
    286                                                      FieldInfo));
    287     FieldInfo.clear();
    288   }
    289 
    290   return (Res.first->second + 1);
    291 }
    292 
    293 unsigned RSMetadataEncoderInternal::encodeRSType(const union RSType *T) {
    294   switch (static_cast<enum RSTypeClass>(RS_TYPE_GET_CLASS(T))) {
    295 #define ENUM_RS_DATA_TYPE_CLASS(x)  \
    296     case RS_TC_ ## x: return encode ## x ## Type(T);
    297     RS_DATA_TYPE_CLASS_ENUMS
    298 #undef ENUM_RS_DATA_TYPE_CLASS
    299     default: return 0;
    300   }
    301   return 0;
    302 }
    303 
    304 int RSMetadataEncoderInternal::encodeRSVar(const RSVar *V) {
    305   // check parameter
    306   if ((V == NULL) || (V->name == NULL) || (V->type == NULL))
    307     return -1;
    308 
    309   // 1. var name
    310   unsigned VarName = joinString(V->name);
    311   if (!checkReturnIndex(&VarName)) {
    312     return -2;
    313   }
    314 
    315   // 2. type
    316   unsigned Type = encodeRSType(V->type);
    317 
    318   llvm::SmallVector<llvm::Value*, 1> VarInfo;
    319 
    320   if (!EncodeInteger(mModule->getContext(), VarName, VarInfo)) {
    321     return -3;
    322   }
    323   if (!EncodeInteger(mModule->getContext(), Type, VarInfo)) {
    324     return -4;
    325   }
    326 
    327   if (mVarInfoMetadata == NULL)
    328     mVarInfoMetadata = mModule->getOrInsertNamedMetadata(RS_EXPORT_VAR_MN);
    329 
    330   mVarInfoMetadata->addOperand(llvm::MDNode::get(mModule->getContext(),
    331                                                  VarInfo));
    332 
    333   return 0;
    334 }
    335 
    336 int RSMetadataEncoderInternal::encodeRSFunc(const RSFunction *F) {
    337   // check parameter
    338   if ((F == NULL) || (F->name == NULL)) {
    339     return -1;
    340   }
    341 
    342   // 1. var name
    343   unsigned FuncName = joinString(F->name);
    344   if (!checkReturnIndex(&FuncName)) {
    345     return -2;
    346   }
    347 
    348   llvm::SmallVector<llvm::Value*, 1> FuncInfo;
    349   if (!EncodeInteger(mModule->getContext(), FuncName, FuncInfo)) {
    350     return -3;
    351   }
    352 
    353   if (mFuncInfoMetadata == NULL)
    354     mFuncInfoMetadata = mModule->getOrInsertNamedMetadata(RS_EXPORT_FUNC_MN);
    355 
    356   mFuncInfoMetadata->addOperand(llvm::MDNode::get(mModule->getContext(),
    357                                                   FuncInfo));
    358 
    359   return 0;
    360 }
    361 
    362 // Write string table and string index table
    363 int RSMetadataEncoderInternal::flushStringTable() {
    364   slangAssert((mCurStringIndex == mEncodedStrings.size()));
    365   slangAssert((mCurStringIndex == mStrings.size()));
    366 
    367   if (mCurStringIndex == 0)
    368     return 0;
    369 
    370   // Prepare named MDNode for string table and string index table.
    371   llvm::NamedMDNode *RSMetadataStrTab =
    372       mModule->getOrInsertNamedMetadata(RS_METADATA_STRTAB_MN);
    373   RSMetadataStrTab->dropAllReferences();
    374 
    375   unsigned StrTabSize = 0;
    376   unsigned *StrIdx = reinterpret_cast<unsigned*>(
    377                         ::malloc((mStrings.size() + 1) * sizeof(unsigned)));
    378 
    379   if (StrIdx == NULL)
    380     return -1;
    381 
    382   unsigned StrIdxI = 0;  // iterator for array StrIdx
    383 
    384   // count StrTabSize and fill StrIdx by the way
    385   for (std::list<const char*>::const_iterator I = mEncodedStrings.begin(),
    386           E = mEncodedStrings.end();
    387        I != E;
    388        I++) {
    389     StrIdx[StrIdxI++] = StrTabSize;
    390     StrTabSize += ::strlen(*I) + 1 /* for '\0' */;
    391   }
    392   StrIdx[StrIdxI] = StrTabSize;
    393 
    394   // Allocate
    395   char *StrTab = reinterpret_cast<char*>(::malloc(StrTabSize));
    396   if (StrTab == NULL) {
    397     free(StrIdx);
    398     return -1;
    399   }
    400 
    401   llvm::StringRef StrTabData(StrTab, StrTabSize);
    402   llvm::StringRef StrIdxData(reinterpret_cast<const char*>(StrIdx),
    403                              mStrings.size() * sizeof(unsigned));
    404 
    405   // Copy
    406   StrIdxI = 1;
    407   for (std::list<const char*>::const_iterator I = mEncodedStrings.begin(),
    408           E = mEncodedStrings.end();
    409        I != E;
    410        I++) {
    411     // Get string length from StrIdx (O(1)) instead of call strlen again (O(n)).
    412     unsigned CurStrLength = StrIdx[StrIdxI] - StrIdx[StrIdxI - 1];
    413     ::memcpy(StrTab, *I, CurStrLength);
    414     // Move forward the pointer
    415     StrTab += CurStrLength;
    416     StrIdxI++;
    417   }
    418 
    419   // Flush to metadata
    420   llvm::Value *StrTabMDS =
    421       llvm::MDString::get(mModule->getContext(), StrTabData);
    422   llvm::Value *StrIdxMDS =
    423       llvm::MDString::get(mModule->getContext(), StrIdxData);
    424 
    425   if ((StrTabMDS == NULL) || (StrIdxMDS == NULL)) {
    426     free(StrIdx);
    427     free(StrTab);
    428     return -1;
    429   }
    430 
    431   llvm::SmallVector<llvm::Value*, 2> StrTabVal;
    432   StrTabVal.push_back(StrTabMDS);
    433   StrTabVal.push_back(StrIdxMDS);
    434   RSMetadataStrTab->addOperand(llvm::MDNode::get(mModule->getContext(),
    435                                                  StrTabVal));
    436 
    437   return 0;
    438 }
    439 
    440 // Write RS type stream
    441 int RSMetadataEncoderInternal::flushTypeInfo() {
    442   unsigned TypeInfoCount = mEncodedRSTypeInfo.size();
    443   if (TypeInfoCount <= 0) {
    444     return 0;
    445   }
    446 
    447   llvm::NamedMDNode *RSTypeInfo =
    448       mModule->getOrInsertNamedMetadata(RS_TYPE_INFO_MN);
    449   RSTypeInfo->dropAllReferences();
    450 
    451   unsigned *TypeInfos =
    452       reinterpret_cast<unsigned*>(::malloc(TypeInfoCount * sizeof(unsigned)));
    453   unsigned TypeInfosIdx = 0;  // iterator for array TypeInfos
    454 
    455   if (TypeInfos == NULL)
    456     return -1;
    457 
    458   for (std::list<unsigned>::const_iterator I = mEncodedRSTypeInfo.begin(),
    459           E = mEncodedRSTypeInfo.end();
    460        I != E;
    461        I++)
    462     TypeInfos[TypeInfosIdx++] = *I;
    463 
    464   llvm::StringRef TypeInfoData(reinterpret_cast<const char*>(TypeInfos),
    465                                TypeInfoCount * sizeof(unsigned));
    466   llvm::Value *TypeInfoMDS =
    467       llvm::MDString::get(mModule->getContext(), TypeInfoData);
    468   if (TypeInfoMDS == NULL) {
    469     free(TypeInfos);
    470     return -1;
    471   }
    472 
    473   llvm::SmallVector<llvm::Value*, 1> TypeInfo;
    474   TypeInfo.push_back(TypeInfoMDS);
    475 
    476   RSTypeInfo->addOperand(llvm::MDNode::get(mModule->getContext(),
    477                                            TypeInfo));
    478   free(TypeInfos);
    479 
    480   return 0;
    481 }
    482 
    483 int RSMetadataEncoderInternal::finalize() {
    484   int Res = flushStringTable();
    485   if (Res != 0)
    486     return Res;
    487 
    488   Res = flushTypeInfo();
    489   if (Res != 0)
    490     return Res;
    491 
    492   return 0;
    493 }
    494 
    495 ///////////////////////////////////////////////////////////////////////////////
    496 // APIs
    497 ///////////////////////////////////////////////////////////////////////////////
    498 RSMetadataEncoder *CreateRSMetadataEncoder(llvm::Module *M) {
    499   return reinterpret_cast<RSMetadataEncoder*>(new RSMetadataEncoderInternal(M));
    500 }
    501 
    502 int RSEncodeVarMetadata(RSMetadataEncoder *E, const RSVar *V) {
    503   return reinterpret_cast<RSMetadataEncoderInternal*>(E)->encodeRSVar(V);
    504 }
    505 
    506 int RSEncodeFunctionMetadata(RSMetadataEncoder *E, const RSFunction *F) {
    507   return reinterpret_cast<RSMetadataEncoderInternal*>(E)->encodeRSFunc(F);
    508 }
    509 
    510 void DestroyRSMetadataEncoder(RSMetadataEncoder *E) {
    511   RSMetadataEncoderInternal *C =
    512       reinterpret_cast<RSMetadataEncoderInternal*>(E);
    513   delete C;
    514   return;
    515 }
    516 
    517 int FinalizeRSMetadataEncoder(RSMetadataEncoder *E) {
    518   RSMetadataEncoderInternal *C =
    519       reinterpret_cast<RSMetadataEncoderInternal*>(E);
    520   int Res = C->finalize();
    521   DestroyRSMetadataEncoder(E);
    522   return Res;
    523 }
    524