Home | History | Annotate | Download | only in spirit
      1 /*
      2  * Copyright 2017, The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *     http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #ifndef MODULE_H
     18 #define MODULE_H
     19 
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "core_defs.h"
#include "entity.h"
#include "instructions.h"
#include "stl_util.h"
#include "types_generated.h"
#include "visitor.h"
     30 
     31 namespace android {
     32 namespace spirit {
     33 
     34 class Builder;
     35 class AnnotationSection;
     36 class CapabilityInst;
     37 class DebugInfoSection;
     38 class ExtensionInst;
     39 class ExtInstImportInst;
     40 class EntryPointInst;
     41 class ExecutionModeInst;
     42 class EntryPointDefinition;
     43 class FunctionDeclaration;
     44 class FunctionDefinition;
     45 class GlobalSection;
     46 class InputWordStream;
     47 class Instruction;
     48 class MemoryModelInst;
     49 
     50 union VersionNumber {
     51   struct {
     52     uint8_t mLowZero;
     53     uint8_t mMinorNumber;
     54     uint8_t mMajorNumber;
     55     uint8_t mHighZero;
     56   } mMajorMinor;
     57   uint8_t mBytes[4];
     58   uint32_t mWord;
     59 };
     60 
     61 class Module : public Entity {
     62 public:
     63   static Module *getCurrentModule();
     64   uint32_t nextId() { return mNextId++; }
     65 
     66   Module();
     67 
     68   Module(Builder *b);
     69 
     70   virtual ~Module() {}
     71 
     72   bool DeserializeInternal(InputWordStream &IS) override;
     73 
     74   void Serialize(OutputWordStream &OS) const override;
     75 
     76   void SerializeHeader(OutputWordStream &OS) const;
     77 
     78   void registerId(uint32_t id, Instruction *inst) {
     79     mIdTable.insert(std::make_pair(id, inst));
     80   }
     81 
     82   void initialize();
     83 
     84   bool resolveIds();
     85 
     86   void accept(IVisitor *v) override {
     87     for (auto cap : mCapabilities) {
     88       v->visit(cap);
     89     }
     90     for (auto ext : mExtensions) {
     91       v->visit(ext);
     92     }
     93     for (auto imp : mExtInstImports) {
     94       v->visit(imp);
     95     }
     96 
     97     v->visit(mMemoryModel.get());
     98 
     99     for (auto entry : mEntryPoints) {
    100       v->visit(entry);
    101     }
    102 
    103     for (auto mode : mExecutionModes) {
    104       v->visit(mode);
    105     }
    106 
    107     v->visit(mDebugInfo.get());
    108     if (mAnnotations) {
    109       v->visit(mAnnotations.get());
    110     }
    111     if (mGlobals) {
    112       v->visit(mGlobals.get());
    113     }
    114 
    115     for (auto def : mFunctionDefinitions) {
    116       v->visit(def);
    117     }
    118   }
    119 
    120   static std::ostream &errs() { return std::cerr; }
    121 
    122   Module *addCapability(Capability cap);
    123   Module *setMemoryModel(AddressingModel am, MemoryModel mm);
    124   Module *addExtInstImport(const char *extName);
    125   Module *addSource(SourceLanguage lang, int version);
    126   Module *addSourceExtension(const char *ext);
    127   Module *addString(const char *ext);
    128   Module *addEntryPoint(EntryPointDefinition *entry);
    129 
    130   ExtInstImportInst *getGLExt() const { return mGLExt; }
    131 
    132   const std::string findStringOfPrefix(const char *prefix) const;
    133 
    134   GlobalSection *getGlobalSection();
    135 
    136   Instruction *lookupByName(const char *) const;
    137   FunctionDefinition *
    138   getFunctionDefinitionFromInstruction(FunctionInst *) const;
    139   FunctionDefinition *lookupFunctionDefinitionByName(const char *name) const;
    140 
    141   // Find the name of the instruction, e.g., the name of a function (OpFunction
    142   // instruction).
    143   // The returned string is owned by the OpName instruction, whose first operand
    144   // is the instruction being queried on.
    145   const char *lookupNameByInstruction(const Instruction *) const;
    146 
    147   VariableInst *getInvocationId();
    148   VariableInst *getNumWorkgroups();
    149 
    150   // Adds a struct type built somewhere else.
    151   Module *addStructType(TypeStructInst *structType);
    152   Module *addVariable(VariableInst *var);
    153 
    154   // Methods to look up types. Create them if not found.
    155   TypeVoidInst *getVoidType();
    156   TypeIntInst *getIntType(int bits, bool isSigned = true);
    157   TypeIntInst *getUnsignedIntType(int bits);
    158   TypeFloatInst *getFloatType(int bits);
    159   TypeVectorInst *getVectorType(Instruction *componentType, int width);
    160   TypePointerInst *getPointerType(StorageClass storage,
    161                                   Instruction *pointeeType);
    162   TypeRuntimeArrayInst *getRuntimeArrayType(Instruction *elementType);
    163 
    164   // This implies that struct types are strictly structural equivalent, i.e.,
    165   // two structs are equivalent i.f.f. their fields are equivalent, recursively.
    166   TypeStructInst *getStructType(Instruction *fieldType[], int numField);
    167   TypeStructInst *getStructType(const std::vector<Instruction *> &fieldType);
    168   TypeStructInst *getStructType(Instruction *field0Type);
    169   TypeStructInst *getStructType(Instruction *field0Type,
    170                                 Instruction *field1Type);
    171   TypeStructInst *getStructType(Instruction *field0Type,
    172                                 Instruction *field1Type,
    173                                 Instruction *field2Type);
    174 
    175   // TODO: Can function types of different decorations be considered the same?
    176   TypeFunctionInst *getFunctionType(Instruction *retType,
    177                                     Instruction *const argType[],
    178                                     size_t numArg);
    179   TypeFunctionInst *getFunctionType(Instruction *retType,
    180                                     const std::vector<Instruction *> &argTypes);
    181 
    182   size_t getSize(TypeVoidInst *voidTy);
    183   size_t getSize(TypeIntInst *intTy);
    184   size_t getSize(TypeFloatInst *fpTy);
    185   size_t getSize(TypeVectorInst *vTy);
    186   size_t getSize(TypePointerInst *ptrTy);
    187   size_t getSize(TypeStructInst *structTy);
    188   size_t getSize(TypeFunctionInst *funcTy);
    189   size_t getSize(Instruction *inst);
    190 
    191   ConstantInst *getConstant(TypeIntInst *type, int32_t value);
    192   ConstantInst *getConstant(TypeIntInst *type, uint32_t value);
    193   ConstantInst *getConstant(TypeFloatInst *type, float value);
    194 
    195   ConstantCompositeInst *getConstantComposite(TypeVectorInst *type,
    196                                               ConstantInst *components[],
    197                                               size_t width);
    198   ConstantCompositeInst *
    199   getConstantComposite(Instruction *type,
    200                        const std::vector<ConstantInst *> &components);
    201   ConstantCompositeInst *getConstantComposite(Instruction *type,
    202                                               ConstantInst *comp0,
    203                                               ConstantInst *comp1);
    204   ConstantCompositeInst *getConstantComposite(TypeVectorInst *type,
    205                                               ConstantInst *comp0,
    206                                               ConstantInst *comp1,
    207                                               ConstantInst *comp2);
    208   ConstantCompositeInst *getConstantComposite(TypeVectorInst *type,
    209                                               ConstantInst *comp0,
    210                                               ConstantInst *comp1,
    211                                               ConstantInst *comp2,
    212                                               ConstantInst *comp3);
    213 
    214   Module *addFunctionDefinition(FunctionDefinition *func);
    215 
    216   void consolidateAnnotations();
    217 
    218 private:
    219   static Module *mInstance;
    220   uint32_t mNextId;
    221   std::map<uint32_t, Instruction *> mIdTable;
    222 
    223   uint32_t mMagicNumber;
    224   VersionNumber mVersion;
    225   uint32_t mGeneratorMagicNumber;
    226   uint32_t mBound;
    227   uint32_t mReserved;
    228 
    229   std::vector<CapabilityInst *> mCapabilities;
    230   std::vector<ExtensionInst *> mExtensions;
    231   std::vector<ExtInstImportInst *> mExtInstImports;
    232   std::unique_ptr<MemoryModelInst> mMemoryModel;
    233   std::vector<EntryPointInst *> mEntryPointInsts;
    234   std::vector<ExecutionModeInst *> mExecutionModes;
    235   std::vector<EntryPointDefinition *> mEntryPoints;
    236   std::unique_ptr<DebugInfoSection> mDebugInfo;
    237   std::unique_ptr<AnnotationSection> mAnnotations;
    238   std::unique_ptr<GlobalSection> mGlobals;
    239   std::vector<FunctionDefinition *> mFunctionDefinitions;
    240 
    241   ExtInstImportInst *mGLExt;
    242 
    243   ContainerDeleter<std::vector<CapabilityInst *>> mCapabilitiesDeleter;
    244   ContainerDeleter<std::vector<ExtensionInst *>> mExtensionsDeleter;
    245   ContainerDeleter<std::vector<ExtInstImportInst *>> mExtInstImportsDeleter;
    246   ContainerDeleter<std::vector<EntryPointInst *>> mEntryPointInstsDeleter;
    247   ContainerDeleter<std::vector<ExecutionModeInst *>> mExecutionModesDeleter;
    248   ContainerDeleter<std::vector<EntryPointDefinition *>> mEntryPointsDeleter;
    249   ContainerDeleter<std::vector<FunctionDefinition *>>
    250       mFunctionDefinitionsDeleter;
    251 };
    252 
    253 struct Extent3D {
    254   uint32_t mWidth;
    255   uint32_t mHeight;
    256   uint32_t mDepth;
    257 };
    258 
    259 class EntryPointDefinition : public Entity {
    260 public:
    261   EntryPointDefinition() {}
    262   EntryPointDefinition(Builder *builder, ExecutionModel execModel,
    263                        FunctionDefinition *func, const char *name);
    264 
    265   virtual ~EntryPointDefinition() {
    266     // Nothing to do here since ~Module() will delete entities referenced here
    267   }
    268 
    269   void accept(IVisitor *visitor) override {
    270     visitor->visit(mEntryPointInst);
    271     // Do not visit the ExecutionMode instructions here. They are linked here
    272     // for convinience, and for convinience only. They are all grouped, stored,
    273     // and serialized directly in the module in a section right after all
    274     // EntryPoint instructions. Visit them from there.
    275   }
    276 
    277   bool DeserializeInternal(InputWordStream &IS) override;
    278 
    279   EntryPointDefinition *addToInterface(VariableInst *var);
    280   EntryPointDefinition *addExecutionMode(ExecutionModeInst *mode) {
    281     mExecutionModeInsts.push_back(mode);
    282     return this;
    283   }
    284   const std::vector<ExecutionModeInst *> &getExecutionModes() const {
    285     return mExecutionModeInsts;
    286   }
    287 
    288   EntryPointDefinition *setLocalSize(uint32_t width, uint32_t height,
    289                                      uint32_t depth);
    290 
    291   EntryPointDefinition *applyExecutionMode(ExecutionModeInst *mode);
    292 
    293   EntryPointInst *getInstruction() const { return mEntryPointInst; }
    294 
    295 private:
    296   const char *mName;
    297   FunctionInst *mFunction;
    298   ExecutionModel mExecutionModel;
    299   std::vector<VariableInst *> mInterface;
    300   Extent3D mLocalSize;
    301 
    302   EntryPointInst *mEntryPointInst;
    303   std::vector<ExecutionModeInst *> mExecutionModeInsts;
    304 };
    305 
    306 class DebugInfoSection : public Entity {
    307 public:
    308   DebugInfoSection() : mSourcesDeleter(mSources), mNamesDeleter(mNames) {}
    309   DebugInfoSection(Builder *b)
    310       : Entity(b), mSourcesDeleter(mSources), mNamesDeleter(mNames) {}
    311 
    312   virtual ~DebugInfoSection() {}
    313 
    314   bool DeserializeInternal(InputWordStream &IS) override;
    315 
    316   DebugInfoSection *addSource(SourceLanguage lang, int version);
    317   DebugInfoSection *addSourceExtension(const char *ext);
    318   DebugInfoSection *addString(const char *str);
    319 
    320   std::string findStringOfPrefix(const char *prefix);
    321 
    322   Instruction *lookupByName(const char *name) const;
    323   const char *lookupNameByInstruction(const Instruction *) const;
    324 
    325   void accept(IVisitor *v) override {
    326     for (auto source : mSources) {
    327       v->visit(source);
    328     }
    329     for (auto name : mNames) {
    330       v->visit(name);
    331     }
    332   }
    333 
    334 private:
    335   // (OpString|OpSource|OpSourceExtension|OpSourceContinued)*
    336   std::vector<Instruction *> mSources;
    337   // (OpName|OpMemberName)*
    338   std::vector<Instruction *> mNames;
    339 
    340   ContainerDeleter<std::vector<Instruction *>> mSourcesDeleter;
    341   ContainerDeleter<std::vector<Instruction *>> mNamesDeleter;
    342 };
    343 
    344 class AnnotationSection : public Entity {
    345 public:
    346   AnnotationSection();
    347   AnnotationSection(Builder *b);
    348 
    349   virtual ~AnnotationSection() {}
    350 
    351   bool DeserializeInternal(InputWordStream &IS) override;
    352 
    353   void accept(IVisitor *v) override {
    354     for (auto inst : mAnnotations) {
    355       v->visit(inst);
    356     }
    357   }
    358 
    359   template <typename T> void addAnnotations(T begin, T end) {
    360     mAnnotations.insert<T>(std::end(mAnnotations), begin, end);
    361   }
    362 
    363   std::vector<Instruction *>::const_iterator begin() const {
    364     return mAnnotations.begin();
    365   }
    366 
    367   std::vector<Instruction *>::const_iterator end() const {
    368     return mAnnotations.end();
    369   }
    370 
    371   void clear() { mAnnotations.clear(); }
    372 
    373 private:
    374   std::vector<Instruction *> mAnnotations; // OpDecorate, etc.
    375 
    376   ContainerDeleter<std::vector<Instruction *>> mAnnotationsDeleter;
    377 };
    378 
    379 // Types, constants, and globals
    380 class GlobalSection : public Entity {
    381 public:
    382   GlobalSection();
    383   GlobalSection(Builder *builder);
    384 
    385   virtual ~GlobalSection() {}
    386 
    387   bool DeserializeInternal(InputWordStream &IS) override;
    388 
    389   void accept(IVisitor *v) override {
    390     for (auto inst : mGlobalDefs) {
    391       v->visit(inst);
    392     }
    393 
    394     if (mInvocationId) {
    395       v->visit(mInvocationId.get());
    396     }
    397 
    398     if (mNumWorkgroups) {
    399       v->visit(mNumWorkgroups.get());
    400     }
    401   }
    402 
    403   ConstantInst *getConstant(TypeIntInst *type, int32_t value);
    404   ConstantInst *getConstant(TypeIntInst *type, uint32_t value);
    405   ConstantInst *getConstant(TypeFloatInst *type, float value);
    406   ConstantCompositeInst *getConstantComposite(TypeVectorInst *type,
    407                                               ConstantInst *components[],
    408                                               size_t width);
    409 
    410   // Methods to look up types. Create them if not found.
    411   TypeVoidInst *getVoidType();
    412   TypeIntInst *getIntType(int bits, bool isSigned = true);
    413   TypeFloatInst *getFloatType(int bits);
    414   TypeVectorInst *getVectorType(Instruction *componentType, int width);
    415   TypePointerInst *getPointerType(StorageClass storage,
    416                                   Instruction *pointeeType);
    417   TypeRuntimeArrayInst *getRuntimeArrayType(Instruction *elementType);
    418 
    419   // This implies that struct types are strictly structural equivalent, i.e.,
    420   // two structs are equivalent i.f.f. their fields are equivalent, recursively.
    421   TypeStructInst *getStructType(Instruction *fieldType[], int numField);
    422   // TypeStructInst *getStructType(const std::vector<Instruction *>
    423   // &fieldTypes);
    424 
    425   // TODO: Can function types of different decorations be considered the same?
    426   TypeFunctionInst *getFunctionType(Instruction *retType,
    427                                     Instruction *const argType[],
    428                                     size_t numArg);
    429   // TypeStructInst *addStructType(Instruction *fieldType[], int numField);
    430   GlobalSection *addStructType(TypeStructInst *structType);
    431   GlobalSection *addVariable(VariableInst *var);
    432 
    433   VariableInst *getInvocationId();
    434   VariableInst *getNumWorkgroups();
    435 
    436 private:
    437   // TODO: Add structure to this.
    438   // Separate types, constants, variables, etc.
    439   std::vector<Instruction *> mGlobalDefs;
    440   std::unique_ptr<VariableInst> mInvocationId;
    441   std::unique_ptr<VariableInst> mNumWorkgroups;
    442 
    443   ContainerDeleter<std::vector<Instruction *>> mGlobalDefsDeleter;
    444 };
    445 
    446 class FunctionDeclaration : public Entity {
    447 public:
    448   virtual ~FunctionDeclaration() {}
    449 
    450   bool DeserializeInternal(InputWordStream &IS) override;
    451 
    452   void accept(IVisitor *v) override {
    453     v->visit(mFunc);
    454     for (auto param : mParams) {
    455       v->visit(param);
    456     }
    457     v->visit(mFuncEnd);
    458   }
    459 
    460 private:
    461   FunctionInst *mFunc;
    462   std::vector<FunctionParameterInst *> mParams;
    463   FunctionEndInst *mFuncEnd;
    464 };
    465 
    466 class Block : public Entity {
    467 public:
    468   Block() {}
    469   Block(Builder *b) : Entity(b) {}
    470 
    471   virtual ~Block() {}
    472 
    473   bool DeserializeInternal(InputWordStream &IS) override;
    474 
    475   void accept(IVisitor *v) override {
    476     for (auto inst : mInsts) {
    477       v->visit(inst);
    478     }
    479   }
    480 
    481   Block *addInstruction(Instruction *inst) {
    482     mInsts.push_back(inst);
    483     return this;
    484   }
    485 
    486 private:
    487   std::vector<Instruction *> mInsts;
    488 };
    489 
    490 class FunctionDefinition : public Entity {
    491 public:
    492   FunctionDefinition();
    493   FunctionDefinition(Builder *builder, FunctionInst *func,
    494                      FunctionEndInst *end);
    495 
    496   virtual ~FunctionDefinition() {}
    497 
    498   bool DeserializeInternal(InputWordStream &IS) override;
    499 
    500   void accept(IVisitor *v) override {
    501     v->visit(mFunc.get());
    502     for (auto param : mParams) {
    503       v->visit(param);
    504     }
    505     for (auto block : mBlocks) {
    506       v->visit(block);
    507     }
    508     v->visit(mFuncEnd.get());
    509   }
    510 
    511   FunctionDefinition *addBlock(Block *b) {
    512     mBlocks.push_back(b);
    513     return this;
    514   }
    515 
    516   FunctionInst *getInstruction() const { return mFunc.get(); }
    517   FunctionParameterInst *getParameter(uint32_t i) const { return mParams[i]; }
    518 
    519   Instruction *getReturnType() const;
    520 
    521 private:
    522   std::unique_ptr<FunctionInst> mFunc;
    523   std::vector<FunctionParameterInst *> mParams;
    524   std::vector<Block *> mBlocks;
    525   std::unique_ptr<FunctionEndInst> mFuncEnd;
    526 
    527   ContainerDeleter<std::vector<FunctionParameterInst *>> mParamsDeleter;
    528   ContainerDeleter<std::vector<Block *>> mBlocksDeleter;
    529 };
    530 
    531 } // namespace spirit
    532 } // namespace android
    533 
    534 #endif // MODULE_H
    535