Home | History | Annotate | Download | only in spirit
      1 /*
      2  * Copyright 2017, The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *     http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "module.h"
     18 
     19 #include <set>
     20 
     21 #include "builder.h"
     22 #include "core_defs.h"
     23 #include "instructions.h"
     24 #include "types_generated.h"
     25 #include "word_stream.h"
     26 
     27 namespace android {
     28 namespace spirit {
     29 
     30 Module *Module::mInstance = nullptr;
     31 
     32 Module *Module::getCurrentModule() {
     33   if (mInstance == nullptr) {
     34     return mInstance = new Module();
     35   }
     36   return mInstance;
     37 }
     38 
     39 Module::Module()
     40     : mNextId(1), mCapabilitiesDeleter(mCapabilities),
     41       mExtensionsDeleter(mExtensions), mExtInstImportsDeleter(mExtInstImports),
     42       mEntryPointInstsDeleter(mEntryPointInsts),
     43       mExecutionModesDeleter(mExecutionModes),
     44       mEntryPointsDeleter(mEntryPoints),
     45       mFunctionDefinitionsDeleter(mFunctionDefinitions) {
     46   mInstance = this;
     47 }
     48 
     49 Module::Module(Builder *b)
     50     : Entity(b), mNextId(1), mCapabilitiesDeleter(mCapabilities),
     51       mExtensionsDeleter(mExtensions), mExtInstImportsDeleter(mExtInstImports),
     52       mEntryPointInstsDeleter(mEntryPointInsts),
     53       mExecutionModesDeleter(mExecutionModes),
     54       mEntryPointsDeleter(mEntryPoints),
     55       mFunctionDefinitionsDeleter(mFunctionDefinitions) {
     56   mInstance = this;
     57 }
     58 
     59 bool Module::resolveIds() {
     60   auto &table = mIdTable;
     61 
     62   std::unique_ptr<IVisitor> v0(
     63       CreateInstructionVisitor([&table](Instruction *inst) {
     64         if (inst->hasResult()) {
     65           table.insert(std::make_pair(inst->getId(), inst));
     66         }
     67       }));
     68   v0->visit(this);
     69 
     70   mNextId = mIdTable.rbegin()->first + 1;
     71 
     72   int err = 0;
     73   std::unique_ptr<IVisitor> v(
     74       CreateInstructionVisitor([&table, &err](Instruction *inst) {
     75         for (auto ref : inst->getAllIdRefs()) {
     76           if (ref) {
     77             auto it = table.find(ref->mId);
     78             if (it != table.end()) {
     79               ref->mInstruction = it->second;
     80             } else {
     81               std::cout << "Found no instruction for id " << ref->mId
     82                         << std::endl;
     83               err++;
     84             }
     85           }
     86         }
     87       }));
     88   v->visit(this);
     89   return err == 0;
     90 }
     91 
     92 bool Module::DeserializeInternal(InputWordStream &IS) {
     93   if (IS.empty()) {
     94     return false;
     95   }
     96 
     97   IS >> &mMagicNumber;
     98   if (mMagicNumber != 0x07230203) {
     99     errs() << "Wrong Magic Number: " << mMagicNumber;
    100     return false;
    101   }
    102 
    103   if (IS.empty()) {
    104     return false;
    105   }
    106 
    107   IS >> &mVersion.mWord;
    108   if (mVersion.mBytes[0] != 0 || mVersion.mBytes[3] != 0) {
    109     return false;
    110   }
    111 
    112   if (IS.empty()) {
    113     return false;
    114   }
    115 
    116   IS >> &mGeneratorMagicNumber >> &mBound >> &mReserved;
    117 
    118   DeserializeZeroOrMore<CapabilityInst>(IS, mCapabilities);
    119   DeserializeZeroOrMore<ExtensionInst>(IS, mExtensions);
    120   DeserializeZeroOrMore<ExtInstImportInst>(IS, mExtInstImports);
    121 
    122   mMemoryModel.reset(Deserialize<MemoryModelInst>(IS));
    123   if (!mMemoryModel) {
    124     errs() << "Missing memory model specification.\n";
    125     return false;
    126   }
    127 
    128   DeserializeZeroOrMore<EntryPointDefinition>(IS, mEntryPoints);
    129   DeserializeZeroOrMore<ExecutionModeInst>(IS, mExecutionModes);
    130   for (auto entry : mEntryPoints) {
    131     mEntryPointInsts.push_back(entry->getInstruction());
    132     for (auto mode : mExecutionModes) {
    133       entry->applyExecutionMode(mode);
    134     }
    135   }
    136 
    137   mDebugInfo.reset(Deserialize<DebugInfoSection>(IS));
    138   mAnnotations.reset(Deserialize<AnnotationSection>(IS));
    139   mGlobals.reset(Deserialize<GlobalSection>(IS));
    140 
    141   DeserializeZeroOrMore<FunctionDefinition>(IS, mFunctionDefinitions);
    142 
    143   if (mFunctionDefinitions.empty()) {
    144     errs() << "Missing function definitions.\n";
    145     for (int i = 0; i < 4; i++) {
    146       uint32_t w;
    147       IS >> &w;
    148       std::cout << std::hex << w << " ";
    149     }
    150     std::cout << std::endl;
    151     return false;
    152   }
    153 
    154   return true;
    155 }
    156 
    157 void Module::initialize() {
    158   mMagicNumber = 0x07230203;
    159   mVersion.mMajorMinor = {.mMinorNumber = 1, .mMajorNumber = 1};
    160   mGeneratorMagicNumber = 0x00070000;
    161   mBound = 0;
    162   mReserved = 0;
    163   mAnnotations.reset(new AnnotationSection());
    164 }
    165 
    166 void Module::SerializeHeader(OutputWordStream &OS) const {
    167   OS << mMagicNumber;
    168   OS << mVersion.mWord << mGeneratorMagicNumber;
    169   if (mBound == 0) {
    170     OS << mIdTable.end()->first + 1;
    171   } else {
    172     OS << std::max(mBound, mNextId);
    173   }
    174   OS << mReserved;
    175 }
    176 
    177 void Module::Serialize(OutputWordStream &OS) const {
    178   SerializeHeader(OS);
    179   Entity::Serialize(OS);
    180 }
    181 
    182 Module *Module::addCapability(Capability cap) {
    183   mCapabilities.push_back(mBuilder->MakeCapability(cap));
    184   return this;
    185 }
    186 
    187 Module *Module::setMemoryModel(AddressingModel am, MemoryModel mm) {
    188   mMemoryModel.reset(mBuilder->MakeMemoryModel(am, mm));
    189   return this;
    190 }
    191 
    192 Module *Module::addExtInstImport(const char *extName) {
    193   ExtInstImportInst *extInst = mBuilder->MakeExtInstImport(extName);
    194   mExtInstImports.push_back(extInst);
    195   if (strcmp(extName, "GLSL.std.450") == 0) {
    196     mGLExt = extInst;
    197   }
    198   return this;
    199 }
    200 
    201 Module *Module::addSource(SourceLanguage lang, int version) {
    202   if (!mDebugInfo) {
    203     mDebugInfo.reset(mBuilder->MakeDebugInfoSection());
    204   }
    205   mDebugInfo->addSource(lang, version);
    206   return this;
    207 }
    208 
    209 Module *Module::addSourceExtension(const char *ext) {
    210   if (!mDebugInfo) {
    211     mDebugInfo.reset(mBuilder->MakeDebugInfoSection());
    212   }
    213   mDebugInfo->addSourceExtension(ext);
    214   return this;
    215 }
    216 
    217 Module *Module::addString(const char *str) {
    218   if (!mDebugInfo) {
    219     mDebugInfo.reset(mBuilder->MakeDebugInfoSection());
    220   }
    221   mDebugInfo->addString(str);
    222   return this;
    223 }
    224 
    225 Module *Module::addEntryPoint(EntryPointDefinition *entry) {
    226   mEntryPoints.push_back(entry);
    227   auto newModes = entry->getExecutionModes();
    228   mExecutionModes.insert(mExecutionModes.end(), newModes.begin(),
    229                          newModes.end());
    230   return this;
    231 }
    232 
    233 GlobalSection *Module::getGlobalSection() {
    234   if (!mGlobals) {
    235     mGlobals.reset(new GlobalSection());
    236   }
    237   return mGlobals.get();
    238 }
    239 
    240 ConstantInst *Module::getConstant(TypeIntInst *type, int32_t value) {
    241   return getGlobalSection()->getConstant(type, value);
    242 }
    243 
    244 ConstantInst *Module::getConstant(TypeIntInst *type, uint32_t value) {
    245   return getGlobalSection()->getConstant(type, value);
    246 }
    247 
    248 ConstantInst *Module::getConstant(TypeFloatInst *type, float value) {
    249   return getGlobalSection()->getConstant(type, value);
    250 }
    251 
    252 ConstantCompositeInst *Module::getConstantComposite(TypeVectorInst *type,
    253                                                     ConstantInst *components[],
    254                                                     size_t width) {
    255   return getGlobalSection()->getConstantComposite(type, components, width);
    256 }
    257 
    258 ConstantCompositeInst *Module::getConstantComposite(TypeVectorInst *type,
    259                                                     ConstantInst *comp0,
    260                                                     ConstantInst *comp1,
    261                                                     ConstantInst *comp2) {
    262   // TODO: verify that component types are the same and consistent with the
    263   // resulting vector type
    264   ConstantInst *comps[] = {comp0, comp1, comp2};
    265   return getConstantComposite(type, comps, 3);
    266 }
    267 
    268 ConstantCompositeInst *Module::getConstantComposite(TypeVectorInst *type,
    269                                                     ConstantInst *comp0,
    270                                                     ConstantInst *comp1,
    271                                                     ConstantInst *comp2,
    272                                                     ConstantInst *comp3) {
    273   // TODO: verify that component types are the same and consistent with the
    274   // resulting vector type
    275   ConstantInst *comps[] = {comp0, comp1, comp2, comp3};
    276   return getConstantComposite(type, comps, 4);
    277 }
    278 
    279 TypeVoidInst *Module::getVoidType() {
    280   return getGlobalSection()->getVoidType();
    281 }
    282 
    283 TypeIntInst *Module::getIntType(int bits, bool isSigned) {
    284   return getGlobalSection()->getIntType(bits, isSigned);
    285 }
    286 
    287 TypeIntInst *Module::getUnsignedIntType(int bits) {
    288   return getIntType(bits, false);
    289 }
    290 
    291 TypeFloatInst *Module::getFloatType(int bits) {
    292   return getGlobalSection()->getFloatType(bits);
    293 }
    294 
    295 TypeVectorInst *Module::getVectorType(Instruction *componentType, int width) {
    296   return getGlobalSection()->getVectorType(componentType, width);
    297 }
    298 
    299 TypePointerInst *Module::getPointerType(StorageClass storage,
    300                                         Instruction *pointeeType) {
    301   return getGlobalSection()->getPointerType(storage, pointeeType);
    302 }
    303 
    304 TypeRuntimeArrayInst *Module::getRuntimeArrayType(Instruction *elementType) {
    305   return getGlobalSection()->getRuntimeArrayType(elementType);
    306 }
    307 
    308 TypeStructInst *Module::getStructType(Instruction *fieldType[], int numField) {
    309   return getGlobalSection()->getStructType(fieldType, numField);
    310 }
    311 
    312 TypeStructInst *Module::getStructType(Instruction *fieldType) {
    313   return getStructType(&fieldType, 1);
    314 }
    315 
    316 TypeFunctionInst *Module::getFunctionType(Instruction *retType,
    317                                           Instruction *const argType[],
    318                                           size_t numArg) {
    319   return getGlobalSection()->getFunctionType(retType, argType, numArg);
    320 }
    321 
    322 TypeFunctionInst *
    323 Module::getFunctionType(Instruction *retType,
    324                         const std::vector<Instruction *> &argTypes) {
    325   return getGlobalSection()->getFunctionType(retType, argTypes.data(),
    326                                              argTypes.size());
    327 }
    328 
    329 size_t Module::getSize(TypeVoidInst *) { return 0; }
    330 
    331 size_t Module::getSize(TypeIntInst *intTy) { return intTy->mOperand1 / 8; }
    332 
    333 size_t Module::getSize(TypeFloatInst *fpTy) { return fpTy->mOperand1 / 8; }
    334 
    335 size_t Module::getSize(TypeVectorInst *vTy) {
    336   return getSize(vTy->mOperand1.mInstruction) * vTy->mOperand2;
    337 }
    338 
    339 size_t Module::getSize(TypePointerInst *) {
    340   return 4; // TODO: or 8?
    341 }
    342 
    343 size_t Module::getSize(TypeStructInst *structTy) {
    344   size_t sz = 0;
    345   for (auto ty : structTy->mOperand1) {
    346     sz += getSize(ty.mInstruction);
    347   }
    348   return sz;
    349 }
    350 
    351 size_t Module::getSize(TypeFunctionInst *) {
    352   return 4; // TODO: or 8? Is this just the size of a pointer?
    353 }
    354 
    355 size_t Module::getSize(Instruction *inst) {
    356   switch (inst->getOpCode()) {
    357   case OpTypeVoid:
    358     return getSize(static_cast<TypeVoidInst *>(inst));
    359   case OpTypeInt:
    360     return getSize(static_cast<TypeIntInst *>(inst));
    361   case OpTypeFloat:
    362     return getSize(static_cast<TypeFloatInst *>(inst));
    363   case OpTypeVector:
    364     return getSize(static_cast<TypeVectorInst *>(inst));
    365   case OpTypeStruct:
    366     return getSize(static_cast<TypeStructInst *>(inst));
    367   case OpTypeFunction:
    368     return getSize(static_cast<TypeFunctionInst *>(inst));
    369   default:
    370     return 0;
    371   }
    372 }
    373 
    374 Module *Module::addFunctionDefinition(FunctionDefinition *func) {
    375   mFunctionDefinitions.push_back(func);
    376   return this;
    377 }
    378 
    379 Instruction *Module::lookupByName(const char *name) const {
    380   return mDebugInfo->lookupByName(name);
    381 }
    382 
    383 FunctionDefinition *
    384 Module::getFunctionDefinitionFromInstruction(FunctionInst *inst) const {
    385   for (auto fdef : mFunctionDefinitions) {
    386     if (fdef->getInstruction() == inst) {
    387       return fdef;
    388     }
    389   }
    390   return nullptr;
    391 }
    392 
    393 FunctionDefinition *
    394 Module::lookupFunctionDefinitionByName(const char *name) const {
    395   FunctionInst *inst = static_cast<FunctionInst *>(lookupByName(name));
    396   return getFunctionDefinitionFromInstruction(inst);
    397 }
    398 
    399 const char *Module::lookupNameByInstruction(const Instruction *inst) const {
    400   return mDebugInfo->lookupNameByInstruction(inst);
    401 }
    402 
    403 VariableInst *Module::getInvocationId() {
    404   return getGlobalSection()->getInvocationId();
    405 }
    406 
    407 VariableInst *Module::getNumWorkgroups() {
    408   return getGlobalSection()->getNumWorkgroups();
    409 }
    410 
    411 Module *Module::addStructType(TypeStructInst *structType) {
    412   getGlobalSection()->addStructType(structType);
    413   return this;
    414 }
    415 
    416 Module *Module::addVariable(VariableInst *var) {
    417   getGlobalSection()->addVariable(var);
    418   return this;
    419 }
    420 
    421 void Module::consolidateAnnotations() {
    422   std::vector<Instruction *> annotations(mAnnotations->begin(),
    423                                       mAnnotations->end());
    424   std::unique_ptr<IVisitor> v(
    425       CreateInstructionVisitor([&annotations](Instruction *inst) -> void {
    426         const auto &ann = inst->getAnnotations();
    427         annotations.insert(annotations.end(), ann.begin(), ann.end());
    428       }));
    429   v->visit(this);
    430   mAnnotations->clear();
    431   mAnnotations->addAnnotations(annotations.begin(), annotations.end());
    432 }
    433 
    434 EntryPointDefinition::EntryPointDefinition(Builder *builder,
    435                                            ExecutionModel execModel,
    436                                            FunctionDefinition *func,
    437                                            const char *name)
    438     : Entity(builder), mFunction(func->getInstruction()),
    439       mExecutionModel(execModel) {
    440   mName = strndup(name, strlen(name));
    441   mEntryPointInst = mBuilder->MakeEntryPoint(execModel, mFunction, mName);
    442 }
    443 
    444 bool EntryPointDefinition::DeserializeInternal(InputWordStream &IS) {
    445   if (IS.empty()) {
    446     return false;
    447   }
    448 
    449   if ((mEntryPointInst = Deserialize<EntryPointInst>(IS))) {
    450     return true;
    451   }
    452 
    453   return false;
    454 }
    455 
    456 EntryPointDefinition *
    457 EntryPointDefinition::applyExecutionMode(ExecutionModeInst *mode) {
    458   if (mode->mOperand1.mInstruction == mFunction) {
    459     addExecutionMode(mode);
    460   }
    461   return this;
    462 }
    463 
    464 EntryPointDefinition *EntryPointDefinition::addToInterface(VariableInst *var) {
    465   mInterface.push_back(var);
    466   mEntryPointInst->mOperand4.push_back(var);
    467   return this;
    468 }
    469 
    470 EntryPointDefinition *EntryPointDefinition::setLocalSize(uint32_t width,
    471                                                          uint32_t height,
    472                                                          uint32_t depth) {
    473   mLocalSize.mWidth = width;
    474   mLocalSize.mHeight = height;
    475   mLocalSize.mDepth = depth;
    476 
    477   auto mode = mBuilder->MakeExecutionMode(mFunction, ExecutionMode::LocalSize);
    478   mode->addExtraOperand(width)->addExtraOperand(height)->addExtraOperand(depth);
    479 
    480   addExecutionMode(mode);
    481 
    482   return this;
    483 }
    484 
    485 bool DebugInfoSection::DeserializeInternal(InputWordStream &IS) {
    486   while (true) {
    487     if (auto str = Deserialize<StringInst>(IS)) {
    488       mSources.push_back(str);
    489     } else if (auto src = Deserialize<SourceInst>(IS)) {
    490       mSources.push_back(src);
    491     } else if (auto srcExt = Deserialize<SourceExtensionInst>(IS)) {
    492       mSources.push_back(srcExt);
    493     } else if (auto srcCont = Deserialize<SourceContinuedInst>(IS)) {
    494       mSources.push_back(srcCont);
    495     } else {
    496       break;
    497     }
    498   }
    499 
    500   while (true) {
    501     if (auto name = Deserialize<NameInst>(IS)) {
    502       mNames.push_back(name);
    503     } else if (auto memName = Deserialize<MemberNameInst>(IS)) {
    504       mNames.push_back(memName);
    505     } else {
    506       break;
    507     }
    508   }
    509 
    510   return true;
    511 }
    512 
    513 DebugInfoSection *DebugInfoSection::addSource(SourceLanguage lang,
    514                                               int version) {
    515   SourceInst *source = mBuilder->MakeSource(lang, version);
    516   mSources.push_back(source);
    517   return this;
    518 }
    519 
    520 DebugInfoSection *DebugInfoSection::addSourceExtension(const char *ext) {
    521   SourceExtensionInst *inst = mBuilder->MakeSourceExtension(ext);
    522   mSources.push_back(inst);
    523   return this;
    524 }
    525 
    526 DebugInfoSection *DebugInfoSection::addString(const char *str) {
    527   StringInst *source = mBuilder->MakeString(str);
    528   mSources.push_back(source);
    529   return this;
    530 }
    531 
    532 Instruction *DebugInfoSection::lookupByName(const char *name) const {
    533   for (auto inst : mNames) {
    534     if (inst->getOpCode() == OpName) {
    535       NameInst *nameInst = static_cast<NameInst *>(inst);
    536       if (nameInst->mOperand2.compare(name) == 0) {
    537         return nameInst->mOperand1.mInstruction;
    538       }
    539     }
    540     // Ignore member names
    541   }
    542   return nullptr;
    543 }
    544 
    545 const char *
    546 DebugInfoSection::lookupNameByInstruction(const Instruction *target) const {
    547   for (auto inst : mNames) {
    548     if (inst->getOpCode() == OpName) {
    549       NameInst *nameInst = static_cast<NameInst *>(inst);
    550       if (nameInst->mOperand1.mInstruction == target) {
    551         return nameInst->mOperand2.c_str();
    552       }
    553     }
    554     // Ignore member names
    555   }
    556   return nullptr;
    557 }
    558 
    559 AnnotationSection::AnnotationSection() : mAnnotationsDeleter(mAnnotations) {}
    560 
    561 AnnotationSection::AnnotationSection(Builder *b)
    562     : Entity(b), mAnnotationsDeleter(mAnnotations) {}
    563 
    564 bool AnnotationSection::DeserializeInternal(InputWordStream &IS) {
    565   while (true) {
    566     if (auto decor = Deserialize<DecorateInst>(IS)) {
    567       mAnnotations.push_back(decor);
    568     } else if (auto decor = Deserialize<MemberDecorateInst>(IS)) {
    569       mAnnotations.push_back(decor);
    570     } else if (auto decor = Deserialize<GroupDecorateInst>(IS)) {
    571       mAnnotations.push_back(decor);
    572     } else if (auto decor = Deserialize<GroupMemberDecorateInst>(IS)) {
    573       mAnnotations.push_back(decor);
    574     } else if (auto decor = Deserialize<DecorationGroupInst>(IS)) {
    575       mAnnotations.push_back(decor);
    576     } else {
    577       break;
    578     }
    579   }
    580   return true;
    581 }
    582 
    583 GlobalSection::GlobalSection() : mGlobalDefsDeleter(mGlobalDefs) {}
    584 
    585 GlobalSection::GlobalSection(Builder *builder)
    586     : Entity(builder), mGlobalDefsDeleter(mGlobalDefs) {}
    587 
    588 namespace {
    589 
    590 template <typename T>
    591 T *findOrCreate(std::function<bool(T *)> criteria, std::function<T *()> factory,
    592                 std::vector<Instruction *> *globals) {
    593   T *derived;
    594   for (auto inst : *globals) {
    595     if (inst->getOpCode() == T::mOpCode) {
    596       T *derived = static_cast<T *>(inst);
    597       if (criteria(derived)) {
    598         return derived;
    599       }
    600     }
    601   }
    602   derived = factory();
    603   globals->push_back(derived);
    604   return derived;
    605 }
    606 
    607 } // anonymous namespace
    608 
    609 bool GlobalSection::DeserializeInternal(InputWordStream &IS) {
    610   while (true) {
    611 #define HANDLE_INSTRUCTION(OPCODE, INST_CLASS)                                 \
    612   if (auto typeInst = Deserialize<INST_CLASS>(IS)) {                           \
    613     mGlobalDefs.push_back(typeInst);                                           \
    614     continue;                                                                  \
    615   }
    616 #include "const_inst_dispatches_generated.h"
    617 #include "type_inst_dispatches_generated.h"
    618 #undef HANDLE_INSTRUCTION
    619 
    620     if (auto globalInst = Deserialize<VariableInst>(IS)) {
    621       // Check if this is function scoped
    622       if (globalInst->mOperand1 == StorageClass::Function) {
    623         Module::errs() << "warning: Variable (id = " << globalInst->mResult;
    624         Module::errs() << ") has function scope in global section.\n";
    625         // Khronos LLVM-SPIRV convertor emits "Function" storage-class globals.
    626         // As a workaround, accept such SPIR-V code here, and fix it up later
    627         // in the rs2spirv compiler by correcting the storage class.
    628         // In a stricter deserializer, such code should be rejected, and we
    629         // should return false here.
    630       }
    631       mGlobalDefs.push_back(globalInst);
    632       continue;
    633     }
    634 
    635     if (auto globalInst = Deserialize<UndefInst>(IS)) {
    636       mGlobalDefs.push_back(globalInst);
    637       continue;
    638     }
    639     break;
    640   }
    641   return true;
    642 }
    643 
    644 ConstantInst *GlobalSection::getConstant(TypeIntInst *type, int32_t value) {
    645   return findOrCreate<ConstantInst>(
    646       [=](ConstantInst *c) { return c->mOperand1.intValue == value; },
    647       [=]() -> ConstantInst * {
    648         LiteralContextDependentNumber cdn = {.intValue = value};
    649         return mBuilder->MakeConstant(type, cdn);
    650       },
    651       &mGlobalDefs);
    652 }
    653 
    654 ConstantInst *GlobalSection::getConstant(TypeIntInst *type, uint32_t value) {
    655   return findOrCreate<ConstantInst>(
    656       [=](ConstantInst *c) { return c->mOperand1.intValue == (int)value; },
    657       [=]() -> ConstantInst * {
    658         LiteralContextDependentNumber cdn = {.intValue = (int)value};
    659         return mBuilder->MakeConstant(type, cdn);
    660       },
    661       &mGlobalDefs);
    662 }
    663 
    664 ConstantInst *GlobalSection::getConstant(TypeFloatInst *type, float value) {
    665   return findOrCreate<ConstantInst>(
    666       [=](ConstantInst *c) { return c->mOperand1.floatValue == value; },
    667       [=]() -> ConstantInst * {
    668         LiteralContextDependentNumber cdn = {.floatValue = value};
    669         return mBuilder->MakeConstant(type, cdn);
    670       },
    671       &mGlobalDefs);
    672 }
    673 
    674 ConstantCompositeInst *
    675 GlobalSection::getConstantComposite(TypeVectorInst *type,
    676                                     ConstantInst *components[], size_t width) {
    677   return findOrCreate<ConstantCompositeInst>(
    678       [=](ConstantCompositeInst *c) {
    679         if (c->mOperand1.size() != width) {
    680           return false;
    681         }
    682         for (size_t i = 0; i < width; i++) {
    683           if (c->mOperand1[i].mInstruction != components[i]) {
    684             return false;
    685           }
    686         }
    687         return true;
    688       },
    689       [=]() -> ConstantCompositeInst * {
    690         ConstantCompositeInst *c = mBuilder->MakeConstantComposite(type);
    691         for (size_t i = 0; i < width; i++) {
    692           c->mOperand1.push_back(components[i]);
    693         }
    694         return c;
    695       },
    696       &mGlobalDefs);
    697 }
    698 
    699 TypeVoidInst *GlobalSection::getVoidType() {
    700   return findOrCreate<TypeVoidInst>(
    701       [=](TypeVoidInst *) -> bool { return true; },
    702       [=]() -> TypeVoidInst * { return mBuilder->MakeTypeVoid(); },
    703       &mGlobalDefs);
    704 }
    705 
    706 TypeIntInst *GlobalSection::getIntType(int bits, bool isSigned) {
    707   if (isSigned) {
    708     switch (bits) {
    709 #define HANDLE_INT_SIZE(INT_TYPE, BITS, SIGNED)                                \
    710   case BITS: {                                                                 \
    711     return findOrCreate<TypeIntInst>(                                          \
    712         [=](TypeIntInst *intTy) -> bool {                                      \
    713           return intTy->mOperand1 == BITS && intTy->mOperand2 == SIGNED;       \
    714         },                                                                     \
    715         [=]() -> TypeIntInst * {                                               \
    716           return mBuilder->MakeTypeInt(BITS, SIGNED);                          \
    717         },                                                                     \
    718         &mGlobalDefs);                                                         \
    719   }
    720       HANDLE_INT_SIZE(Int, 8, 1);
    721       HANDLE_INT_SIZE(Int, 16, 1);
    722       HANDLE_INT_SIZE(Int, 32, 1);
    723       HANDLE_INT_SIZE(Int, 64, 1);
    724     default:
    725       Module::errs() << "unexpected int type";
    726     }
    727   } else {
    728     switch (bits) {
    729       HANDLE_INT_SIZE(UInt, 8, 0);
    730       HANDLE_INT_SIZE(UInt, 16, 0);
    731       HANDLE_INT_SIZE(UInt, 32, 0);
    732       HANDLE_INT_SIZE(UInt, 64, 0);
    733     default:
    734       Module::errs() << "unexpected int type";
    735     }
    736   }
    737 #undef HANDLE_INT_SIZE
    738   return nullptr;
    739 }
    740 
    741 TypeFloatInst *GlobalSection::getFloatType(int bits) {
    742   switch (bits) {
    743 #define HANDLE_FLOAT_SIZE(BITS)                                                \
    744   case BITS: {                                                                 \
    745     return findOrCreate<TypeFloatInst>(                                        \
    746         [=](TypeFloatInst *floatTy) -> bool {                                  \
    747           return floatTy->mOperand1 == BITS;                                   \
    748         },                                                                     \
    749         [=]() -> TypeFloatInst * { return mBuilder->MakeTypeFloat(BITS); },    \
    750         &mGlobalDefs);                                                         \
    751   }
    752     HANDLE_FLOAT_SIZE(16);
    753     HANDLE_FLOAT_SIZE(32);
    754     HANDLE_FLOAT_SIZE(64);
    755   default:
    756     Module::errs() << "unexpeced floating point type";
    757   }
    758 #undef HANDLE_FLOAT_SIZE
    759   return nullptr;
    760 }
    761 
    762 TypeVectorInst *GlobalSection::getVectorType(Instruction *componentType,
    763                                              int width) {
    764   // TODO: verify that componentType is basic numeric types
    765 
    766   return findOrCreate<TypeVectorInst>(
    767       [=](TypeVectorInst *vecTy) -> bool {
    768         return vecTy->mOperand1.mInstruction == componentType &&
    769                vecTy->mOperand2 == width;
    770       },
    771       [=]() -> TypeVectorInst * {
    772         return mBuilder->MakeTypeVector(componentType, width);
    773       },
    774       &mGlobalDefs);
    775 }
    776 
    777 TypePointerInst *GlobalSection::getPointerType(StorageClass storage,
    778                                                Instruction *pointeeType) {
    779   return findOrCreate<TypePointerInst>(
    780       [=](TypePointerInst *type) -> bool {
    781         return type->mOperand1 == storage &&
    782                type->mOperand2.mInstruction == pointeeType;
    783       },
    784       [=]() -> TypePointerInst * {
    785         return mBuilder->MakeTypePointer(storage, pointeeType);
    786       },
    787       &mGlobalDefs);
    788 }
    789 
    790 TypeRuntimeArrayInst *
    791 GlobalSection::getRuntimeArrayType(Instruction *elemType) {
    792   return findOrCreate<TypeRuntimeArrayInst>(
    793       [=](TypeRuntimeArrayInst * /*type*/) -> bool {
    794         // return type->mOperand1.mInstruction == elemType;
    795         return false;
    796       },
    797       [=]() -> TypeRuntimeArrayInst * {
    798         return mBuilder->MakeTypeRuntimeArray(elemType);
    799       },
    800       &mGlobalDefs);
    801 }
    802 
    803 TypeStructInst *GlobalSection::getStructType(Instruction *fieldType[],
    804                                              int numField) {
    805   TypeStructInst *structTy = mBuilder->MakeTypeStruct();
    806   for (int i = 0; i < numField; i++) {
    807     structTy->mOperand1.push_back(fieldType[i]);
    808   }
    809   mGlobalDefs.push_back(structTy);
    810   return structTy;
    811 }
    812 
    813 TypeFunctionInst *GlobalSection::getFunctionType(Instruction *retType,
    814                                                  Instruction *const argType[],
    815                                                  size_t numArg) {
    816   return findOrCreate<TypeFunctionInst>(
    817       [=](TypeFunctionInst *type) -> bool {
    818         if (type->mOperand1.mInstruction != retType ||
    819             type->mOperand2.size() != numArg) {
    820           return false;
    821         }
    822         for (size_t i = 0; i < numArg; i++) {
    823           if (type->mOperand2[i].mInstruction != argType[i]) {
    824             return false;
    825           }
    826         }
    827         return true;
    828       },
    829       [=]() -> TypeFunctionInst * {
    830         TypeFunctionInst *funcTy = mBuilder->MakeTypeFunction(retType);
    831         for (size_t i = 0; i < numArg; i++) {
    832           funcTy->mOperand2.push_back(argType[i]);
    833         }
    834         return funcTy;
    835       },
    836       &mGlobalDefs);
    837 }
    838 
    839 GlobalSection *GlobalSection::addStructType(TypeStructInst *structType) {
    840   mGlobalDefs.push_back(structType);
    841   return this;
    842 }
    843 
    844 GlobalSection *GlobalSection::addVariable(VariableInst *var) {
    845   mGlobalDefs.push_back(var);
    846   return this;
    847 }
    848 
    849 VariableInst *GlobalSection::getInvocationId() {
    850   if (mInvocationId) {
    851     return mInvocationId.get();
    852   }
    853 
    854   TypeIntInst *UIntTy = getIntType(32, false);
    855   TypeVectorInst *V3UIntTy = getVectorType(UIntTy, 3);
    856   TypePointerInst *V3UIntPtrTy = getPointerType(StorageClass::Input, V3UIntTy);
    857 
    858   VariableInst *InvocationId =
    859       mBuilder->MakeVariable(V3UIntPtrTy, StorageClass::Input);
    860   InvocationId->decorate(Decoration::BuiltIn)
    861       ->addExtraOperand(static_cast<uint32_t>(BuiltIn::GlobalInvocationId));
    862 
    863   mInvocationId.reset(InvocationId);
    864 
    865   return InvocationId;
    866 }
    867 
    868 VariableInst *GlobalSection::getNumWorkgroups() {
    869   if (mNumWorkgroups) {
    870     return mNumWorkgroups.get();
    871   }
    872 
    873   TypeIntInst *UIntTy = getIntType(32, false);
    874   TypeVectorInst *V3UIntTy = getVectorType(UIntTy, 3);
    875   TypePointerInst *V3UIntPtrTy = getPointerType(StorageClass::Input, V3UIntTy);
    876 
    877   VariableInst *GNum = mBuilder->MakeVariable(V3UIntPtrTy, StorageClass::Input);
    878   GNum->decorate(Decoration::BuiltIn)
    879       ->addExtraOperand(static_cast<uint32_t>(BuiltIn::NumWorkgroups));
    880 
    881   mNumWorkgroups.reset(GNum);
    882 
    883   return GNum;
    884 }
    885 
    886 bool FunctionDeclaration::DeserializeInternal(InputWordStream &IS) {
    887   if (!Deserialize<FunctionInst>(IS)) {
    888     return false;
    889   }
    890 
    891   DeserializeZeroOrMore<FunctionParameterInst>(IS, mParams);
    892 
    893   if (!Deserialize<FunctionEndInst>(IS)) {
    894     return false;
    895   }
    896 
    897   return true;
    898 }
    899 
    900 template <> Instruction *Deserialize(InputWordStream &IS) {
    901   Instruction *inst;
    902 
    903   switch ((*IS) & 0xFFFF) {
    904 #define HANDLE_INSTRUCTION(OPCODE, INST_CLASS)                                 \
    905   case OPCODE:                                                                 \
    906     inst = Deserialize<INST_CLASS>(IS);                                        \
    907     break;
    908 #include "instruction_dispatches_generated.h"
    909 #undef HANDLE_INSTRUCTION
    910   default:
    911     Module::errs() << "unrecognized instruction";
    912     inst = nullptr;
    913   }
    914 
    915   return inst;
    916 }
    917 
    918 bool Block::DeserializeInternal(InputWordStream &IS) {
    919   Instruction *inst;
    920   while (((*IS) & 0xFFFF) != OpFunctionEnd &&
    921          (inst = Deserialize<Instruction>(IS))) {
    922     mInsts.push_back(inst);
    923     if (inst->getOpCode() == OpBranch ||
    924         inst->getOpCode() == OpBranchConditional ||
    925         inst->getOpCode() == OpSwitch || inst->getOpCode() == OpKill ||
    926         inst->getOpCode() == OpReturn || inst->getOpCode() == OpReturnValue ||
    927         inst->getOpCode() == OpUnreachable) {
    928       break;
    929     }
    930   }
    931   return !mInsts.empty();
    932 }
    933 
    934 FunctionDefinition::FunctionDefinition()
    935     : mParamsDeleter(mParams), mBlocksDeleter(mBlocks) {}
    936 
    937 FunctionDefinition::FunctionDefinition(Builder *builder, FunctionInst *func,
    938                                        FunctionEndInst *end)
    939     : Entity(builder), mFunc(func), mFuncEnd(end), mParamsDeleter(mParams),
    940       mBlocksDeleter(mBlocks) {}
    941 
    942 bool FunctionDefinition::DeserializeInternal(InputWordStream &IS) {
    943   mFunc.reset(Deserialize<FunctionInst>(IS));
    944   if (!mFunc) {
    945     return false;
    946   }
    947 
    948   DeserializeZeroOrMore<FunctionParameterInst>(IS, mParams);
    949   DeserializeZeroOrMore<Block>(IS, mBlocks);
    950 
    951   mFuncEnd.reset(Deserialize<FunctionEndInst>(IS));
    952   if (!mFuncEnd) {
    953     return false;
    954   }
    955 
    956   return true;
    957 }
    958 
    959 Instruction *FunctionDefinition::getReturnType() const {
    960   return mFunc->mResultType.mInstruction;
    961 }
    962 
    963 } // namespace spirit
    964 } // namespace android
    965