Home | History | Annotate | Download | only in slang
      1 /*
      2  * Copyright 2010, The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *     http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
#include "slang_rs_metadata_spec.h"

#include <cstdlib>
#include <list>
#include <map>
#include <string>
#include <vector>

#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"

#include "llvm/Metadata.h"
#include "llvm/Module.h"

#include "slang_assert.h"
#include "slang_rs_type_spec.h"
     32 
// Names of the named metadata nodes written by the encoder in this file.

// Node that receives the string table (see flushStringTable()).
#define RS_METADATA_STRTAB_MN   "#rs_metadata_strtab"
// Node that receives the encoded type stream (see flushTypeInfo()).
#define RS_TYPE_INFO_MN         "#rs_type_info"
// Node that receives one operand per exported variable (see encodeRSVar()).
#define RS_EXPORT_VAR_MN        "#rs_export_var"
// Node that receives one operand per exported function (see encodeRSFunc()).
#define RS_EXPORT_FUNC_MN       "#rs_export_func"
// Prefix prepended to a record name to form the name of that record's own
// named metadata node (see encodeRecordType()).
#define RS_EXPORT_RECORD_TYPE_NAME_MN_PREFIX  "%"
     38 
     39 ///////////////////////////////////////////////////////////////////////////////
     40 // Useful utility functions
     41 ///////////////////////////////////////////////////////////////////////////////
     42 static bool EncodeInteger(llvm::LLVMContext &C,
     43                           unsigned I,
     44                           llvm::SmallVectorImpl<llvm::Value*> &Op) {
     45   llvm::StringRef S(reinterpret_cast<const char*>(&I), sizeof(I));
     46   llvm::MDString *MDS = llvm::MDString::get(C, S);
     47 
     48   if (MDS == NULL)
     49     return false;
     50   Op.push_back(MDS);
     51   return true;
     52 }
     53 
     54 ///////////////////////////////////////////////////////////////////////////////
     55 // class RSMetadataEncoderInternal
     56 ///////////////////////////////////////////////////////////////////////////////
     57 namespace {
     58 
     59 class RSMetadataEncoderInternal {
     60  private:
     61   llvm::Module *mModule;
     62 
     63   typedef std::map</* key */unsigned, unsigned/* index */> TypesMapTy;
     64   TypesMapTy mTypes;
     65   std::list<unsigned> mEncodedRSTypeInfo;  // simply a sequece of integers
     66   unsigned mCurTypeIndex;
     67 
     68   // A special type for lookup created record type. It uses record name as key.
     69   typedef std::map</* name */std::string, unsigned/* index */> RecordTypesMapTy;
     70   RecordTypesMapTy mRecordTypes;
     71 
     72   typedef std::map<std::string, unsigned/* index */> StringsMapTy;
     73   StringsMapTy mStrings;
     74   std::list<const char*> mEncodedStrings;
     75   unsigned mCurStringIndex;
     76 
     77   llvm::NamedMDNode *mVarInfoMetadata;
     78   llvm::NamedMDNode *mFuncInfoMetadata;
     79 
     80   // This function check the return value of function:
     81   //   joinString, encodeTypeBase, encode*Type(), encodeRSType, encodeRSVar,
     82   //   and encodeRSFunc. Return false if the value of Index indicates failure.
     83   inline bool checkReturnIndex(unsigned *Index) {
     84     if (*Index == 0)
     85       return false;
     86     else
     87       (*Index)--;
     88     return true;
     89   }
     90 
     91   unsigned joinString(const std::string &S);
     92 
     93   unsigned encodeTypeBase(const struct RSTypeBase *Base);
     94   unsigned encodeTypeBaseAsKey(const struct RSTypeBase *Base);
     95 #define ENUM_RS_DATA_TYPE_CLASS(x)  \
     96   unsigned encode ## x ## Type(const union RSType *T);
     97 RS_DATA_TYPE_CLASS_ENUMS
     98 #undef ENUM_RS_DATA_TYPE_CLASS
     99 
    100   unsigned encodeRSType(const union RSType *T);
    101 
    102   int flushStringTable();
    103   int flushTypeInfo();
    104 
    105  public:
    106   explicit RSMetadataEncoderInternal(llvm::Module *M);
    107 
    108   int encodeRSVar(const RSVar *V);
    109   int encodeRSFunc(const RSFunction *F);
    110 
    111   int finalize();
    112 };
    113 }
    114 
    115 RSMetadataEncoderInternal::RSMetadataEncoderInternal(llvm::Module *M)
    116     : mModule(M),
    117       mCurTypeIndex(0),
    118       mCurStringIndex(0),
    119       mVarInfoMetadata(NULL),
    120       mFuncInfoMetadata(NULL) {
    121   mTypes.clear();
    122   mEncodedRSTypeInfo.clear();
    123   mRecordTypes.clear();
    124   mStrings.clear();
    125 
    126   return;
    127 }
    128 
    129 // Return (StringIndex + 1) when successfully join the string and 0 if there's
    130 // any error.
    131 unsigned RSMetadataEncoderInternal::joinString(const std::string &S) {
    132   StringsMapTy::const_iterator I = mStrings.find(S);
    133 
    134   if (I != mStrings.end())
    135     return (I->second + 1);
    136 
    137   // Add S into mStrings
    138   std::pair<StringsMapTy::iterator, bool> Res =
    139       mStrings.insert(std::make_pair(S, mCurStringIndex));
    140   // Insertion failed
    141   if (!Res.second)
    142     return 0;
    143 
    144   // Add S into mEncodedStrings
    145   mEncodedStrings.push_back(Res.first->first.c_str());
    146   mCurStringIndex++;
    147 
    148   // Return (StringIndex + 1)
    149   return (Res.first->second + 1);
    150 }
    151 
    152 unsigned
    153 RSMetadataEncoderInternal::encodeTypeBase(const struct RSTypeBase *Base) {
    154   mEncodedRSTypeInfo.push_back(Base->bits);
    155   return ++mCurTypeIndex;
    156 }
    157 
    158 unsigned RSMetadataEncoderInternal::encodeTypeBaseAsKey(
    159     const struct RSTypeBase *Base) {
    160   TypesMapTy::const_iterator I = mTypes.find(Base->bits);
    161   if (I != mTypes.end())
    162     return (I->second + 1);
    163 
    164   // Add Base into mTypes
    165   std::pair<TypesMapTy::iterator, bool> Res =
    166       mTypes.insert(std::make_pair(Base->bits, mCurTypeIndex));
    167   // Insertion failed
    168   if (!Res.second)
    169     return 0;
    170 
    171   // Push to mEncodedRSTypeInfo. This will also update mCurTypeIndex.
    172   return encodeTypeBase(Base);
    173 }
    174 
    175 unsigned RSMetadataEncoderInternal::encodePrimitiveType(const union RSType *T) {
    176   return encodeTypeBaseAsKey(RS_GET_TYPE_BASE(T));
    177 }
    178 
    179 unsigned RSMetadataEncoderInternal::encodePointerType(const union RSType *T) {
    180   // Encode pointee type first
    181   unsigned PointeeType = encodeRSType(RS_POINTER_TYPE_GET_POINTEE_TYPE(T));
    182   if (!checkReturnIndex(&PointeeType))
    183     return 0;
    184 
    185   unsigned Res = encodeTypeBaseAsKey(RS_GET_TYPE_BASE(T));
    186   // Push PointeeType after the base type
    187   mEncodedRSTypeInfo.push_back(PointeeType);
    188   return Res;
    189 }
    190 
    191 unsigned RSMetadataEncoderInternal::encodeVectorType(const union RSType *T) {
    192   return encodeTypeBaseAsKey(RS_GET_TYPE_BASE(T));
    193 }
    194 
    195 unsigned RSMetadataEncoderInternal::encodeMatrixType(const union RSType *T) {
    196   return encodeTypeBaseAsKey(RS_GET_TYPE_BASE(T));
    197 }
    198 
    199 unsigned
    200 RSMetadataEncoderInternal::encodeConstantArrayType(const union RSType *T) {
    201   // Encode element type
    202   unsigned ElementType =
    203       encodeRSType(RS_CONSTANT_ARRAY_TYPE_GET_ELEMENT_TYPE(T));
    204   if (!checkReturnIndex(&ElementType))
    205     return 0;
    206 
    207   unsigned Res = encodeTypeBase(RS_GET_TYPE_BASE(T));
    208   // Push the ElementType after the type base
    209   mEncodedRSTypeInfo.push_back(ElementType);
    210   return Res;
    211 }
    212 
    213 unsigned RSMetadataEncoderInternal::encodeRecordType(const union RSType *T) {
    214   // Construct record name
    215   std::string RecordInfoMetadataName(RS_EXPORT_RECORD_TYPE_NAME_MN_PREFIX);
    216   RecordInfoMetadataName.append(RS_RECORD_TYPE_GET_NAME(T));
    217 
    218   // Try to find it in mRecordTypes
    219   RecordTypesMapTy::const_iterator I =
    220       mRecordTypes.find(RecordInfoMetadataName);
    221 
    222   // This record type has been encoded before. Fast return its index here.
    223   if (I != mRecordTypes.end())
    224     return (I->second + 1);
    225 
    226   // Encode this record type into mTypes. Encode record name string first.
    227   unsigned RecordName = joinString(RecordInfoMetadataName);
    228   if (!checkReturnIndex(&RecordName))
    229     return 0;
    230 
    231   unsigned Base = encodeTypeBase(RS_GET_TYPE_BASE(T));
    232   if (!checkReturnIndex(&Base))
    233     return 0;
    234 
    235   // Push record name after encoding the type base
    236   mEncodedRSTypeInfo.push_back(RecordName);
    237 
    238   // Add this record type into the map
    239   std::pair<StringsMapTy::iterator, bool> Res =
    240       mRecordTypes.insert(std::make_pair(RecordInfoMetadataName, Base));
    241   // Insertion failed
    242   if (!Res.second)
    243     return 0;
    244 
    245   // Create a named MDNode for this record type. We cannot create this before
    246   // encoding type base into Types and updating mRecordTypes. This is because
    247   // we may have structure like:
    248   //
    249   //            struct foo {
    250   //              ...
    251   //              struct foo *bar;  // self type reference
    252   //              ...
    253   //            }
    254   llvm::NamedMDNode *RecordInfoMetadata =
    255       mModule->getOrInsertNamedMetadata(RecordInfoMetadataName);
    256 
    257   slangAssert((RecordInfoMetadata->getNumOperands() == 0) &&
    258               "Record created before!");
    259 
    260   // Encode field info into this named MDNode
    261   llvm::SmallVector<llvm::Value*, 3> FieldInfo;
    262 
    263   for (unsigned i = 0; i < RS_RECORD_TYPE_GET_NUM_FIELDS(T); i++) {
    264     // 1. field name
    265     unsigned FieldName = joinString(RS_RECORD_TYPE_GET_FIELD_NAME(T, i));
    266     if (!checkReturnIndex(&FieldName))
    267       return 0;
    268     if (!EncodeInteger(mModule->getContext(),
    269                        FieldName,
    270                        FieldInfo)) {
    271       return 0;
    272     }
    273 
    274     // 2. field type
    275     unsigned FieldType = encodeRSType(RS_RECORD_TYPE_GET_FIELD_TYPE(T, i));
    276     if (!checkReturnIndex(&FieldType))
    277       return 0;
    278     if (!EncodeInteger(mModule->getContext(),
    279                        FieldType,
    280                        FieldInfo)) {
    281       return 0;
    282     }
    283 
    284     // 3. field data kind
    285     if (!EncodeInteger(mModule->getContext(),
    286                        RS_RECORD_TYPE_GET_FIELD_DATA_KIND(T, i),
    287                        FieldInfo)) {
    288       return 0;
    289     }
    290 
    291     RecordInfoMetadata->addOperand(llvm::MDNode::get(mModule->getContext(),
    292                                                      FieldInfo));
    293     FieldInfo.clear();
    294   }
    295 
    296   return (Res.first->second + 1);
    297 }
    298 
    299 unsigned RSMetadataEncoderInternal::encodeRSType(const union RSType *T) {
    300   switch (static_cast<enum RSTypeClass>(RS_TYPE_GET_CLASS(T))) {
    301 #define ENUM_RS_DATA_TYPE_CLASS(x)  \
    302     case RS_TC_ ## x: return encode ## x ## Type(T);
    303     RS_DATA_TYPE_CLASS_ENUMS
    304 #undef ENUM_RS_DATA_TYPE_CLASS
    305     default: return 0;
    306   }
    307   return 0;
    308 }
    309 
    310 int RSMetadataEncoderInternal::encodeRSVar(const RSVar *V) {
    311   // check parameter
    312   if ((V == NULL) || (V->name == NULL) || (V->type == NULL))
    313     return -1;
    314 
    315   // 1. var name
    316   unsigned VarName = joinString(V->name);
    317   if (!checkReturnIndex(&VarName)) {
    318     return -2;
    319   }
    320 
    321   // 2. type
    322   unsigned Type = encodeRSType(V->type);
    323 
    324   llvm::SmallVector<llvm::Value*, 1> VarInfo;
    325 
    326   if (!EncodeInteger(mModule->getContext(), VarName, VarInfo)) {
    327     return -3;
    328   }
    329   if (!EncodeInteger(mModule->getContext(), Type, VarInfo)) {
    330     return -4;
    331   }
    332 
    333   if (mVarInfoMetadata == NULL)
    334     mVarInfoMetadata = mModule->getOrInsertNamedMetadata(RS_EXPORT_VAR_MN);
    335 
    336   mVarInfoMetadata->addOperand(llvm::MDNode::get(mModule->getContext(),
    337                                                  VarInfo));
    338 
    339   return 0;
    340 }
    341 
    342 int RSMetadataEncoderInternal::encodeRSFunc(const RSFunction *F) {
    343   // check parameter
    344   if ((F == NULL) || (F->name == NULL)) {
    345     return -1;
    346   }
    347 
    348   // 1. var name
    349   unsigned FuncName = joinString(F->name);
    350   if (!checkReturnIndex(&FuncName)) {
    351     return -2;
    352   }
    353 
    354   llvm::SmallVector<llvm::Value*, 1> FuncInfo;
    355   if (!EncodeInteger(mModule->getContext(), FuncName, FuncInfo)) {
    356     return -3;
    357   }
    358 
    359   if (mFuncInfoMetadata == NULL)
    360     mFuncInfoMetadata = mModule->getOrInsertNamedMetadata(RS_EXPORT_FUNC_MN);
    361 
    362   mFuncInfoMetadata->addOperand(llvm::MDNode::get(mModule->getContext(),
    363                                                   FuncInfo));
    364 
    365   return 0;
    366 }
    367 
    368 // Write string table and string index table
    369 int RSMetadataEncoderInternal::flushStringTable() {
    370   slangAssert((mCurStringIndex == mEncodedStrings.size()));
    371   slangAssert((mCurStringIndex == mStrings.size()));
    372 
    373   if (mCurStringIndex == 0)
    374     return 0;
    375 
    376   // Prepare named MDNode for string table and string index table.
    377   llvm::NamedMDNode *RSMetadataStrTab =
    378       mModule->getOrInsertNamedMetadata(RS_METADATA_STRTAB_MN);
    379   RSMetadataStrTab->dropAllReferences();
    380 
    381   unsigned StrTabSize = 0;
    382   unsigned *StrIdx = reinterpret_cast<unsigned*>(
    383                         ::malloc((mStrings.size() + 1) * sizeof(unsigned)));
    384 
    385   if (StrIdx == NULL)
    386     return -1;
    387 
    388   unsigned StrIdxI = 0;  // iterator for array StrIdx
    389 
    390   // count StrTabSize and fill StrIdx by the way
    391   for (std::list<const char*>::const_iterator I = mEncodedStrings.begin(),
    392           E = mEncodedStrings.end();
    393        I != E;
    394        I++) {
    395     StrIdx[StrIdxI++] = StrTabSize;
    396     StrTabSize += ::strlen(*I) + 1 /* for '\0' */;
    397   }
    398   StrIdx[StrIdxI] = StrTabSize;
    399 
    400   // Allocate
    401   char *StrTab = reinterpret_cast<char*>(::malloc(StrTabSize));
    402   if (StrTab == NULL) {
    403     free(StrIdx);
    404     return -1;
    405   }
    406 
    407   llvm::StringRef StrTabData(StrTab, StrTabSize);
    408   llvm::StringRef StrIdxData(reinterpret_cast<const char*>(StrIdx),
    409                              mStrings.size() * sizeof(unsigned));
    410 
    411   // Copy
    412   StrIdxI = 1;
    413   for (std::list<const char*>::const_iterator I = mEncodedStrings.begin(),
    414           E = mEncodedStrings.end();
    415        I != E;
    416        I++) {
    417     // Get string length from StrIdx (O(1)) instead of call strlen again (O(n)).
    418     unsigned CurStrLength = StrIdx[StrIdxI] - StrIdx[StrIdxI - 1];
    419     ::memcpy(StrTab, *I, CurStrLength);
    420     // Move forward the pointer
    421     StrTab += CurStrLength;
    422     StrIdxI++;
    423   }
    424 
    425   // Flush to metadata
    426   llvm::Value *StrTabMDS =
    427       llvm::MDString::get(mModule->getContext(), StrTabData);
    428   llvm::Value *StrIdxMDS =
    429       llvm::MDString::get(mModule->getContext(), StrIdxData);
    430 
    431   if ((StrTabMDS == NULL) || (StrIdxMDS == NULL)) {
    432     free(StrIdx);
    433     free(StrTab);
    434     return -1;
    435   }
    436 
    437   llvm::SmallVector<llvm::Value*, 2> StrTabVal;
    438   StrTabVal.push_back(StrTabMDS);
    439   StrTabVal.push_back(StrIdxMDS);
    440   RSMetadataStrTab->addOperand(llvm::MDNode::get(mModule->getContext(),
    441                                                  StrTabVal));
    442 
    443   return 0;
    444 }
    445 
    446 // Write RS type stream
    447 int RSMetadataEncoderInternal::flushTypeInfo() {
    448   unsigned TypeInfoCount = mEncodedRSTypeInfo.size();
    449   if (TypeInfoCount <= 0) {
    450     return 0;
    451   }
    452 
    453   llvm::NamedMDNode *RSTypeInfo =
    454       mModule->getOrInsertNamedMetadata(RS_TYPE_INFO_MN);
    455   RSTypeInfo->dropAllReferences();
    456 
    457   unsigned *TypeInfos =
    458       reinterpret_cast<unsigned*>(::malloc(TypeInfoCount * sizeof(unsigned)));
    459   unsigned TypeInfosIdx = 0;  // iterator for array TypeInfos
    460 
    461   if (TypeInfos == NULL)
    462     return -1;
    463 
    464   for (std::list<unsigned>::const_iterator I = mEncodedRSTypeInfo.begin(),
    465           E = mEncodedRSTypeInfo.end();
    466        I != E;
    467        I++)
    468     TypeInfos[TypeInfosIdx++] = *I;
    469 
    470   llvm::StringRef TypeInfoData(reinterpret_cast<const char*>(TypeInfos),
    471                                TypeInfoCount * sizeof(unsigned));
    472   llvm::Value *TypeInfoMDS =
    473       llvm::MDString::get(mModule->getContext(), TypeInfoData);
    474   if (TypeInfoMDS == NULL) {
    475     free(TypeInfos);
    476     return -1;
    477   }
    478 
    479   llvm::SmallVector<llvm::Value*, 1> TypeInfo;
    480   TypeInfo.push_back(TypeInfoMDS);
    481 
    482   RSTypeInfo->addOperand(llvm::MDNode::get(mModule->getContext(),
    483                                            TypeInfo));
    484   free(TypeInfos);
    485 
    486   return 0;
    487 }
    488 
    489 int RSMetadataEncoderInternal::finalize() {
    490   int Res = flushStringTable();
    491   if (Res != 0)
    492     return Res;
    493 
    494   Res = flushTypeInfo();
    495   if (Res != 0)
    496     return Res;
    497 
    498   return 0;
    499 }
    500 
    501 ///////////////////////////////////////////////////////////////////////////////
    502 // APIs
    503 ///////////////////////////////////////////////////////////////////////////////
    504 RSMetadataEncoder *CreateRSMetadataEncoder(llvm::Module *M) {
    505   return reinterpret_cast<RSMetadataEncoder*>(new RSMetadataEncoderInternal(M));
    506 }
    507 
    508 int RSEncodeVarMetadata(RSMetadataEncoder *E, const RSVar *V) {
    509   return reinterpret_cast<RSMetadataEncoderInternal*>(E)->encodeRSVar(V);
    510 }
    511 
    512 int RSEncodeFunctionMetadata(RSMetadataEncoder *E, const RSFunction *F) {
    513   return reinterpret_cast<RSMetadataEncoderInternal*>(E)->encodeRSFunc(F);
    514 }
    515 
    516 void DestroyRSMetadataEncoder(RSMetadataEncoder *E) {
    517   RSMetadataEncoderInternal *C =
    518       reinterpret_cast<RSMetadataEncoderInternal*>(E);
    519   delete C;
    520   return;
    521 }
    522 
    523 int FinalizeRSMetadataEncoder(RSMetadataEncoder *E) {
    524   RSMetadataEncoderInternal *C =
    525       reinterpret_cast<RSMetadataEncoderInternal*>(E);
    526   int Res = C->finalize();
    527   DestroyRSMetadataEncoder(E);
    528   return Res;
    529 }
    530