// (Code-browser navigation residue: Home | History | Annotate | Download | only in spirit)
      1 /*
      2  * Copyright 2017, The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *     http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "module.h"
     18 
     19 #include <set>
     20 
     21 #include "builder.h"
     22 #include "core_defs.h"
     23 #include "instructions.h"
     24 #include "types_generated.h"
     25 #include "word_stream.h"
     26 
     27 namespace android {
     28 namespace spirit {
     29 
     30 Module *Module::mInstance = nullptr;
     31 
     32 Module *Module::getCurrentModule() {
     33   if (mInstance == nullptr) {
     34     return mInstance = new Module();
     35   }
     36   return mInstance;
     37 }
     38 
     39 Module::Module()
     40     : mNextId(1), mCapabilitiesDeleter(mCapabilities),
     41       mExtensionsDeleter(mExtensions), mExtInstImportsDeleter(mExtInstImports),
     42       mEntryPointInstsDeleter(mEntryPointInsts),
     43       mExecutionModesDeleter(mExecutionModes),
     44       mEntryPointsDeleter(mEntryPoints),
     45       mFunctionDefinitionsDeleter(mFunctionDefinitions) {
     46   mInstance = this;
     47 }
     48 
     49 Module::Module(Builder *b)
     50     : Entity(b), mNextId(1), mCapabilitiesDeleter(mCapabilities),
     51       mExtensionsDeleter(mExtensions), mExtInstImportsDeleter(mExtInstImports),
     52       mEntryPointInstsDeleter(mEntryPointInsts),
     53       mExecutionModesDeleter(mExecutionModes),
     54       mEntryPointsDeleter(mEntryPoints),
     55       mFunctionDefinitionsDeleter(mFunctionDefinitions) {
     56   mInstance = this;
     57 }
     58 
     59 bool Module::resolveIds() {
     60   auto &table = mIdTable;
     61 
     62   std::unique_ptr<IVisitor> v0(
     63       CreateInstructionVisitor([&table](Instruction *inst) {
     64         if (inst->hasResult()) {
     65           table.insert(std::make_pair(inst->getId(), inst));
     66         }
     67       }));
     68   v0->visit(this);
     69 
     70   mNextId = mIdTable.rbegin()->first + 1;
     71 
     72   int err = 0;
     73   std::unique_ptr<IVisitor> v(
     74       CreateInstructionVisitor([&table, &err](Instruction *inst) {
     75         for (auto ref : inst->getAllIdRefs()) {
     76           if (ref) {
     77             auto it = table.find(ref->mId);
     78             if (it != table.end()) {
     79               ref->mInstruction = it->second;
     80             } else {
     81               std::cout << "Found no instruction for id " << ref->mId
     82                         << std::endl;
     83               err++;
     84             }
     85           }
     86         }
     87       }));
     88   v->visit(this);
     89   return err == 0;
     90 }
     91 
     92 bool Module::DeserializeInternal(InputWordStream &IS) {
     93   if (IS.empty()) {
     94     return false;
     95   }
     96 
     97   IS >> &mMagicNumber;
     98   if (mMagicNumber != 0x07230203) {
     99     errs() << "Wrong Magic Number: " << mMagicNumber;
    100     return false;
    101   }
    102 
    103   if (IS.empty()) {
    104     return false;
    105   }
    106 
    107   IS >> &mVersion.mWord;
    108   if (mVersion.mBytes[0] != 0 || mVersion.mBytes[3] != 0) {
    109     return false;
    110   }
    111 
    112   if (IS.empty()) {
    113     return false;
    114   }
    115 
    116   IS >> &mGeneratorMagicNumber >> &mBound >> &mReserved;
    117 
    118   DeserializeZeroOrMore<CapabilityInst>(IS, mCapabilities);
    119   DeserializeZeroOrMore<ExtensionInst>(IS, mExtensions);
    120   DeserializeZeroOrMore<ExtInstImportInst>(IS, mExtInstImports);
    121 
    122   mMemoryModel.reset(Deserialize<MemoryModelInst>(IS));
    123   if (!mMemoryModel) {
    124     errs() << "Missing memory model specification.\n";
    125     return false;
    126   }
    127 
    128   DeserializeZeroOrMore<EntryPointDefinition>(IS, mEntryPoints);
    129   DeserializeZeroOrMore<ExecutionModeInst>(IS, mExecutionModes);
    130   for (auto entry : mEntryPoints) {
    131     mEntryPointInsts.push_back(entry->getInstruction());
    132     for (auto mode : mExecutionModes) {
    133       entry->applyExecutionMode(mode);
    134     }
    135   }
    136 
    137   mDebugInfo.reset(Deserialize<DebugInfoSection>(IS));
    138   mAnnotations.reset(Deserialize<AnnotationSection>(IS));
    139   mGlobals.reset(Deserialize<GlobalSection>(IS));
    140 
    141   DeserializeZeroOrMore<FunctionDefinition>(IS, mFunctionDefinitions);
    142 
    143   if (mFunctionDefinitions.empty()) {
    144     errs() << "Missing function definitions.\n";
    145     for (int i = 0; i < 4; i++) {
    146       uint32_t w;
    147       IS >> &w;
    148       std::cout << std::hex << w << " ";
    149     }
    150     std::cout << std::endl;
    151     return false;
    152   }
    153 
    154   return true;
    155 }
    156 
    157 void Module::initialize() {
    158   mMagicNumber = 0x07230203;
    159   mVersion.mMajorMinor = {.mMinorNumber = 1, .mMajorNumber = 1};
    160   mGeneratorMagicNumber = 0x00070000;
    161   mBound = 0;
    162   mReserved = 0;
    163   mAnnotations.reset(new AnnotationSection());
    164 }
    165 
    166 void Module::SerializeHeader(OutputWordStream &OS) const {
    167   OS << mMagicNumber;
    168   OS << mVersion.mWord << mGeneratorMagicNumber;
    169   if (mBound == 0) {
    170     OS << mIdTable.end()->first + 1;
    171   } else {
    172     OS << std::max(mBound, mNextId);
    173   }
    174   OS << mReserved;
    175 }
    176 
    177 void Module::Serialize(OutputWordStream &OS) const {
    178   SerializeHeader(OS);
    179   Entity::Serialize(OS);
    180 }
    181 
    182 Module *Module::addCapability(Capability cap) {
    183   mCapabilities.push_back(mBuilder->MakeCapability(cap));
    184   return this;
    185 }
    186 
    187 Module *Module::setMemoryModel(AddressingModel am, MemoryModel mm) {
    188   mMemoryModel.reset(mBuilder->MakeMemoryModel(am, mm));
    189   return this;
    190 }
    191 
    192 Module *Module::addExtInstImport(const char *extName) {
    193   ExtInstImportInst *extInst = mBuilder->MakeExtInstImport(extName);
    194   mExtInstImports.push_back(extInst);
    195   if (strcmp(extName, "GLSL.std.450") == 0) {
    196     mGLExt = extInst;
    197   }
    198   return this;
    199 }
    200 
    201 Module *Module::addSource(SourceLanguage lang, int version) {
    202   if (!mDebugInfo) {
    203     mDebugInfo.reset(mBuilder->MakeDebugInfoSection());
    204   }
    205   mDebugInfo->addSource(lang, version);
    206   return this;
    207 }
    208 
    209 Module *Module::addSourceExtension(const char *ext) {
    210   if (!mDebugInfo) {
    211     mDebugInfo.reset(mBuilder->MakeDebugInfoSection());
    212   }
    213   mDebugInfo->addSourceExtension(ext);
    214   return this;
    215 }
    216 
    217 Module *Module::addString(const char *str) {
    218   if (!mDebugInfo) {
    219     mDebugInfo.reset(mBuilder->MakeDebugInfoSection());
    220   }
    221   mDebugInfo->addString(str);
    222   return this;
    223 }
    224 
    225 Module *Module::addEntryPoint(EntryPointDefinition *entry) {
    226   mEntryPoints.push_back(entry);
    227   auto newModes = entry->getExecutionModes();
    228   mExecutionModes.insert(mExecutionModes.end(), newModes.begin(),
    229                          newModes.end());
    230   return this;
    231 }
    232 
    233 const std::string Module::findStringOfPrefix(const char *prefix) const {
    234   if (!mDebugInfo) {
    235     return std::string();
    236   }
    237   return mDebugInfo->findStringOfPrefix(prefix);
    238 }
    239 
    240 GlobalSection *Module::getGlobalSection() {
    241   if (!mGlobals) {
    242     mGlobals.reset(new GlobalSection());
    243   }
    244   return mGlobals.get();
    245 }
    246 
    247 ConstantInst *Module::getConstant(TypeIntInst *type, int32_t value) {
    248   return getGlobalSection()->getConstant(type, value);
    249 }
    250 
    251 ConstantInst *Module::getConstant(TypeIntInst *type, uint32_t value) {
    252   return getGlobalSection()->getConstant(type, value);
    253 }
    254 
    255 ConstantInst *Module::getConstant(TypeFloatInst *type, float value) {
    256   return getGlobalSection()->getConstant(type, value);
    257 }
    258 
    259 ConstantCompositeInst *Module::getConstantComposite(TypeVectorInst *type,
    260                                                     ConstantInst *components[],
    261                                                     size_t width) {
    262   return getGlobalSection()->getConstantComposite(type, components, width);
    263 }
    264 
    265 ConstantCompositeInst *Module::getConstantComposite(TypeVectorInst *type,
    266                                                     ConstantInst *comp0,
    267                                                     ConstantInst *comp1,
    268                                                     ConstantInst *comp2) {
    269   // TODO: verify that component types are the same and consistent with the
    270   // resulting vector type
    271   ConstantInst *comps[] = {comp0, comp1, comp2};
    272   return getConstantComposite(type, comps, 3);
    273 }
    274 
    275 ConstantCompositeInst *Module::getConstantComposite(TypeVectorInst *type,
    276                                                     ConstantInst *comp0,
    277                                                     ConstantInst *comp1,
    278                                                     ConstantInst *comp2,
    279                                                     ConstantInst *comp3) {
    280   // TODO: verify that component types are the same and consistent with the
    281   // resulting vector type
    282   ConstantInst *comps[] = {comp0, comp1, comp2, comp3};
    283   return getConstantComposite(type, comps, 4);
    284 }
    285 
    286 TypeVoidInst *Module::getVoidType() {
    287   return getGlobalSection()->getVoidType();
    288 }
    289 
    290 TypeIntInst *Module::getIntType(int bits, bool isSigned) {
    291   return getGlobalSection()->getIntType(bits, isSigned);
    292 }
    293 
    294 TypeIntInst *Module::getUnsignedIntType(int bits) {
    295   return getIntType(bits, false);
    296 }
    297 
    298 TypeFloatInst *Module::getFloatType(int bits) {
    299   return getGlobalSection()->getFloatType(bits);
    300 }
    301 
    302 TypeVectorInst *Module::getVectorType(Instruction *componentType, int width) {
    303   return getGlobalSection()->getVectorType(componentType, width);
    304 }
    305 
    306 TypePointerInst *Module::getPointerType(StorageClass storage,
    307                                         Instruction *pointeeType) {
    308   return getGlobalSection()->getPointerType(storage, pointeeType);
    309 }
    310 
    311 TypeRuntimeArrayInst *Module::getRuntimeArrayType(Instruction *elementType) {
    312   return getGlobalSection()->getRuntimeArrayType(elementType);
    313 }
    314 
    315 TypeStructInst *Module::getStructType(Instruction *fieldType[], int numField) {
    316   return getGlobalSection()->getStructType(fieldType, numField);
    317 }
    318 
    319 TypeStructInst *Module::getStructType(Instruction *fieldType) {
    320   return getStructType(&fieldType, 1);
    321 }
    322 
    323 TypeFunctionInst *Module::getFunctionType(Instruction *retType,
    324                                           Instruction *const argType[],
    325                                           size_t numArg) {
    326   return getGlobalSection()->getFunctionType(retType, argType, numArg);
    327 }
    328 
    329 TypeFunctionInst *
    330 Module::getFunctionType(Instruction *retType,
    331                         const std::vector<Instruction *> &argTypes) {
    332   return getGlobalSection()->getFunctionType(retType, argTypes.data(),
    333                                              argTypes.size());
    334 }
    335 
    336 size_t Module::getSize(TypeVoidInst *) { return 0; }
    337 
    338 size_t Module::getSize(TypeIntInst *intTy) { return intTy->mOperand1 / 8; }
    339 
    340 size_t Module::getSize(TypeFloatInst *fpTy) { return fpTy->mOperand1 / 8; }
    341 
    342 size_t Module::getSize(TypeVectorInst *vTy) {
    343   return getSize(vTy->mOperand1.mInstruction) * vTy->mOperand2;
    344 }
    345 
    346 size_t Module::getSize(TypePointerInst *) {
    347   return 4; // TODO: or 8?
    348 }
    349 
    350 size_t Module::getSize(TypeStructInst *structTy) {
    351   size_t sz = 0;
    352   for (auto ty : structTy->mOperand1) {
    353     sz += getSize(ty.mInstruction);
    354   }
    355   return sz;
    356 }
    357 
    358 size_t Module::getSize(TypeFunctionInst *) {
    359   return 4; // TODO: or 8? Is this just the size of a pointer?
    360 }
    361 
    362 size_t Module::getSize(Instruction *inst) {
    363   switch (inst->getOpCode()) {
    364   case OpTypeVoid:
    365     return getSize(static_cast<TypeVoidInst *>(inst));
    366   case OpTypeInt:
    367     return getSize(static_cast<TypeIntInst *>(inst));
    368   case OpTypeFloat:
    369     return getSize(static_cast<TypeFloatInst *>(inst));
    370   case OpTypeVector:
    371     return getSize(static_cast<TypeVectorInst *>(inst));
    372   case OpTypeStruct:
    373     return getSize(static_cast<TypeStructInst *>(inst));
    374   case OpTypeFunction:
    375     return getSize(static_cast<TypeFunctionInst *>(inst));
    376   default:
    377     return 0;
    378   }
    379 }
    380 
    381 Module *Module::addFunctionDefinition(FunctionDefinition *func) {
    382   mFunctionDefinitions.push_back(func);
    383   return this;
    384 }
    385 
    386 Instruction *Module::lookupByName(const char *name) const {
    387   return mDebugInfo->lookupByName(name);
    388 }
    389 
    390 FunctionDefinition *
    391 Module::getFunctionDefinitionFromInstruction(FunctionInst *inst) const {
    392   for (auto fdef : mFunctionDefinitions) {
    393     if (fdef->getInstruction() == inst) {
    394       return fdef;
    395     }
    396   }
    397   return nullptr;
    398 }
    399 
    400 FunctionDefinition *
    401 Module::lookupFunctionDefinitionByName(const char *name) const {
    402   FunctionInst *inst = static_cast<FunctionInst *>(lookupByName(name));
    403   return getFunctionDefinitionFromInstruction(inst);
    404 }
    405 
    406 const char *Module::lookupNameByInstruction(const Instruction *inst) const {
    407   return mDebugInfo->lookupNameByInstruction(inst);
    408 }
    409 
    410 VariableInst *Module::getInvocationId() {
    411   return getGlobalSection()->getInvocationId();
    412 }
    413 
    414 VariableInst *Module::getNumWorkgroups() {
    415   return getGlobalSection()->getNumWorkgroups();
    416 }
    417 
    418 Module *Module::addStructType(TypeStructInst *structType) {
    419   getGlobalSection()->addStructType(structType);
    420   return this;
    421 }
    422 
    423 Module *Module::addVariable(VariableInst *var) {
    424   getGlobalSection()->addVariable(var);
    425   return this;
    426 }
    427 
    428 void Module::consolidateAnnotations() {
    429   std::vector<Instruction *> annotations(mAnnotations->begin(),
    430                                       mAnnotations->end());
    431   std::unique_ptr<IVisitor> v(
    432       CreateInstructionVisitor([&annotations](Instruction *inst) -> void {
    433         const auto &ann = inst->getAnnotations();
    434         annotations.insert(annotations.end(), ann.begin(), ann.end());
    435       }));
    436   v->visit(this);
    437   mAnnotations->clear();
    438   mAnnotations->addAnnotations(annotations.begin(), annotations.end());
    439 }
    440 
    441 EntryPointDefinition::EntryPointDefinition(Builder *builder,
    442                                            ExecutionModel execModel,
    443                                            FunctionDefinition *func,
    444                                            const char *name)
    445     : Entity(builder), mFunction(func->getInstruction()),
    446       mExecutionModel(execModel) {
    447   mName = strndup(name, strlen(name));
    448   mEntryPointInst = mBuilder->MakeEntryPoint(execModel, mFunction, mName);
    449   (void)mExecutionModel; // suppress unused private field warning
    450 }
    451 
    452 bool EntryPointDefinition::DeserializeInternal(InputWordStream &IS) {
    453   if (IS.empty()) {
    454     return false;
    455   }
    456 
    457   if ((mEntryPointInst = Deserialize<EntryPointInst>(IS))) {
    458     return true;
    459   }
    460 
    461   return false;
    462 }
    463 
    464 EntryPointDefinition *
    465 EntryPointDefinition::applyExecutionMode(ExecutionModeInst *mode) {
    466   if (mode->mOperand1.mInstruction == mFunction) {
    467     addExecutionMode(mode);
    468   }
    469   return this;
    470 }
    471 
    472 EntryPointDefinition *EntryPointDefinition::addToInterface(VariableInst *var) {
    473   mInterface.push_back(var);
    474   mEntryPointInst->mOperand4.push_back(var);
    475   return this;
    476 }
    477 
    478 EntryPointDefinition *EntryPointDefinition::setLocalSize(uint32_t width,
    479                                                          uint32_t height,
    480                                                          uint32_t depth) {
    481   mLocalSize.mWidth = width;
    482   mLocalSize.mHeight = height;
    483   mLocalSize.mDepth = depth;
    484 
    485   auto mode = mBuilder->MakeExecutionMode(mFunction, ExecutionMode::LocalSize);
    486   mode->addExtraOperand(width)->addExtraOperand(height)->addExtraOperand(depth);
    487 
    488   addExecutionMode(mode);
    489 
    490   return this;
    491 }
    492 
    493 bool DebugInfoSection::DeserializeInternal(InputWordStream &IS) {
    494   while (true) {
    495     if (auto str = Deserialize<StringInst>(IS)) {
    496       mSources.push_back(str);
    497     } else if (auto src = Deserialize<SourceInst>(IS)) {
    498       mSources.push_back(src);
    499     } else if (auto srcExt = Deserialize<SourceExtensionInst>(IS)) {
    500       mSources.push_back(srcExt);
    501     } else if (auto srcCont = Deserialize<SourceContinuedInst>(IS)) {
    502       mSources.push_back(srcCont);
    503     } else {
    504       break;
    505     }
    506   }
    507 
    508   while (true) {
    509     if (auto name = Deserialize<NameInst>(IS)) {
    510       mNames.push_back(name);
    511     } else if (auto memName = Deserialize<MemberNameInst>(IS)) {
    512       mNames.push_back(memName);
    513     } else {
    514       break;
    515     }
    516   }
    517 
    518   return true;
    519 }
    520 
    521 DebugInfoSection *DebugInfoSection::addSource(SourceLanguage lang,
    522                                               int version) {
    523   SourceInst *source = mBuilder->MakeSource(lang, version);
    524   mSources.push_back(source);
    525   return this;
    526 }
    527 
    528 DebugInfoSection *DebugInfoSection::addSourceExtension(const char *ext) {
    529   SourceExtensionInst *inst = mBuilder->MakeSourceExtension(ext);
    530   mSources.push_back(inst);
    531   return this;
    532 }
    533 
    534 DebugInfoSection *DebugInfoSection::addString(const char *str) {
    535   StringInst *source = mBuilder->MakeString(str);
    536   mSources.push_back(source);
    537   return this;
    538 }
    539 
    540 std::string DebugInfoSection::findStringOfPrefix(const char *prefix) {
    541   auto it = std::find_if(
    542       mSources.begin(), mSources.end(), [prefix](Instruction *inst) -> bool {
    543         if (inst->getOpCode() != OpString) {
    544           return false;
    545         }
    546         const StringInst *strInst = static_cast<const StringInst *>(inst);
    547         const std::string &str = strInst->mOperand1;
    548         return str.find(prefix) == 0;
    549       });
    550   if (it == mSources.end()) {
    551     return "";
    552   }
    553   StringInst *strInst = static_cast<StringInst *>(*it);
    554   return strInst->mOperand1;
    555 }
    556 
    557 Instruction *DebugInfoSection::lookupByName(const char *name) const {
    558   for (auto inst : mNames) {
    559     if (inst->getOpCode() == OpName) {
    560       NameInst *nameInst = static_cast<NameInst *>(inst);
    561       if (nameInst->mOperand2.compare(name) == 0) {
    562         return nameInst->mOperand1.mInstruction;
    563       }
    564     }
    565     // Ignore member names
    566   }
    567   return nullptr;
    568 }
    569 
    570 const char *
    571 DebugInfoSection::lookupNameByInstruction(const Instruction *target) const {
    572   for (auto inst : mNames) {
    573     if (inst->getOpCode() == OpName) {
    574       NameInst *nameInst = static_cast<NameInst *>(inst);
    575       if (nameInst->mOperand1.mInstruction == target) {
    576         return nameInst->mOperand2.c_str();
    577       }
    578     }
    579     // Ignore member names
    580   }
    581   return nullptr;
    582 }
    583 
    584 AnnotationSection::AnnotationSection() : mAnnotationsDeleter(mAnnotations) {}
    585 
    586 AnnotationSection::AnnotationSection(Builder *b)
    587     : Entity(b), mAnnotationsDeleter(mAnnotations) {}
    588 
    589 bool AnnotationSection::DeserializeInternal(InputWordStream &IS) {
    590   while (true) {
    591     if (auto decor = Deserialize<DecorateInst>(IS)) {
    592       mAnnotations.push_back(decor);
    593     } else if (auto decor = Deserialize<MemberDecorateInst>(IS)) {
    594       mAnnotations.push_back(decor);
    595     } else if (auto decor = Deserialize<GroupDecorateInst>(IS)) {
    596       mAnnotations.push_back(decor);
    597     } else if (auto decor = Deserialize<GroupMemberDecorateInst>(IS)) {
    598       mAnnotations.push_back(decor);
    599     } else if (auto decor = Deserialize<DecorationGroupInst>(IS)) {
    600       mAnnotations.push_back(decor);
    601     } else {
    602       break;
    603     }
    604   }
    605   return true;
    606 }
    607 
    608 GlobalSection::GlobalSection() : mGlobalDefsDeleter(mGlobalDefs) {}
    609 
    610 GlobalSection::GlobalSection(Builder *builder)
    611     : Entity(builder), mGlobalDefsDeleter(mGlobalDefs) {}
    612 
    613 namespace {
    614 
    615 template <typename T>
    616 T *findOrCreate(std::function<bool(T *)> criteria, std::function<T *()> factory,
    617                 std::vector<Instruction *> *globals) {
    618   T *derived;
    619   for (auto inst : *globals) {
    620     if (inst->getOpCode() == T::mOpCode) {
    621       T *derived = static_cast<T *>(inst);
    622       if (criteria(derived)) {
    623         return derived;
    624       }
    625     }
    626   }
    627   derived = factory();
    628   globals->push_back(derived);
    629   return derived;
    630 }
    631 
    632 } // anonymous namespace
    633 
    634 bool GlobalSection::DeserializeInternal(InputWordStream &IS) {
    635   while (true) {
    636 #define HANDLE_INSTRUCTION(OPCODE, INST_CLASS)                                 \
    637   if (auto typeInst = Deserialize<INST_CLASS>(IS)) {                           \
    638     mGlobalDefs.push_back(typeInst);                                           \
    639     continue;                                                                  \
    640   }
    641 #include "const_inst_dispatches_generated.h"
    642 #include "type_inst_dispatches_generated.h"
    643 #undef HANDLE_INSTRUCTION
    644 
    645     if (auto globalInst = Deserialize<VariableInst>(IS)) {
    646       // Check if this is function scoped
    647       if (globalInst->mOperand1 == StorageClass::Function) {
    648         Module::errs() << "warning: Variable (id = " << globalInst->mResult;
    649         Module::errs() << ") has function scope in global section.\n";
    650         // Khronos LLVM-SPIRV convertor emits "Function" storage-class globals.
    651         // As a workaround, accept such SPIR-V code here, and fix it up later
    652         // in the rs2spirv compiler by correcting the storage class.
    653         // In a stricter deserializer, such code should be rejected, and we
    654         // should return false here.
    655       }
    656       mGlobalDefs.push_back(globalInst);
    657       continue;
    658     }
    659 
    660     if (auto globalInst = Deserialize<UndefInst>(IS)) {
    661       mGlobalDefs.push_back(globalInst);
    662       continue;
    663     }
    664     break;
    665   }
    666   return true;
    667 }
    668 
    669 ConstantInst *GlobalSection::getConstant(TypeIntInst *type, int32_t value) {
    670   return findOrCreate<ConstantInst>(
    671       [=](ConstantInst *c) { return c->mOperand1.intValue == value; },
    672       [=]() -> ConstantInst * {
    673         LiteralContextDependentNumber cdn = {.intValue = value};
    674         return mBuilder->MakeConstant(type, cdn);
    675       },
    676       &mGlobalDefs);
    677 }
    678 
    679 ConstantInst *GlobalSection::getConstant(TypeIntInst *type, uint32_t value) {
    680   return findOrCreate<ConstantInst>(
    681       [=](ConstantInst *c) { return c->mOperand1.intValue == (int)value; },
    682       [=]() -> ConstantInst * {
    683         LiteralContextDependentNumber cdn = {.intValue = (int)value};
    684         return mBuilder->MakeConstant(type, cdn);
    685       },
    686       &mGlobalDefs);
    687 }
    688 
    689 ConstantInst *GlobalSection::getConstant(TypeFloatInst *type, float value) {
    690   return findOrCreate<ConstantInst>(
    691       [=](ConstantInst *c) { return c->mOperand1.floatValue == value; },
    692       [=]() -> ConstantInst * {
    693         LiteralContextDependentNumber cdn = {.floatValue = value};
    694         return mBuilder->MakeConstant(type, cdn);
    695       },
    696       &mGlobalDefs);
    697 }
    698 
    699 ConstantCompositeInst *
    700 GlobalSection::getConstantComposite(TypeVectorInst *type,
    701                                     ConstantInst *components[], size_t width) {
    702   return findOrCreate<ConstantCompositeInst>(
    703       [=](ConstantCompositeInst *c) {
    704         if (c->mOperand1.size() != width) {
    705           return false;
    706         }
    707         for (size_t i = 0; i < width; i++) {
    708           if (c->mOperand1[i].mInstruction != components[i]) {
    709             return false;
    710           }
    711         }
    712         return true;
    713       },
    714       [=]() -> ConstantCompositeInst * {
    715         ConstantCompositeInst *c = mBuilder->MakeConstantComposite(type);
    716         for (size_t i = 0; i < width; i++) {
    717           c->mOperand1.push_back(components[i]);
    718         }
    719         return c;
    720       },
    721       &mGlobalDefs);
    722 }
    723 
    724 TypeVoidInst *GlobalSection::getVoidType() {
    725   return findOrCreate<TypeVoidInst>(
    726       [=](TypeVoidInst *) -> bool { return true; },
    727       [=]() -> TypeVoidInst * { return mBuilder->MakeTypeVoid(); },
    728       &mGlobalDefs);
    729 }
    730 
    731 TypeIntInst *GlobalSection::getIntType(int bits, bool isSigned) {
    732   if (isSigned) {
    733     switch (bits) {
    734 #define HANDLE_INT_SIZE(INT_TYPE, BITS, SIGNED)                                \
    735   case BITS: {                                                                 \
    736     return findOrCreate<TypeIntInst>(                                          \
    737         [=](TypeIntInst *intTy) -> bool {                                      \
    738           return intTy->mOperand1 == BITS && intTy->mOperand2 == SIGNED;       \
    739         },                                                                     \
    740         [=]() -> TypeIntInst * {                                               \
    741           return mBuilder->MakeTypeInt(BITS, SIGNED);                          \
    742         },                                                                     \
    743         &mGlobalDefs);                                                         \
    744   }
    745       HANDLE_INT_SIZE(Int, 8, 1);
    746       HANDLE_INT_SIZE(Int, 16, 1);
    747       HANDLE_INT_SIZE(Int, 32, 1);
    748       HANDLE_INT_SIZE(Int, 64, 1);
    749     default:
    750       Module::errs() << "unexpected int type";
    751     }
    752   } else {
    753     switch (bits) {
    754       HANDLE_INT_SIZE(UInt, 8, 0);
    755       HANDLE_INT_SIZE(UInt, 16, 0);
    756       HANDLE_INT_SIZE(UInt, 32, 0);
    757       HANDLE_INT_SIZE(UInt, 64, 0);
    758     default:
    759       Module::errs() << "unexpected int type";
    760     }
    761   }
    762 #undef HANDLE_INT_SIZE
    763   return nullptr;
    764 }
    765 
    766 TypeFloatInst *GlobalSection::getFloatType(int bits) {
    767   switch (bits) {
    768 #define HANDLE_FLOAT_SIZE(BITS)                                                \
    769   case BITS: {                                                                 \
    770     return findOrCreate<TypeFloatInst>(                                        \
    771         [=](TypeFloatInst *floatTy) -> bool {                                  \
    772           return floatTy->mOperand1 == BITS;                                   \
    773         },                                                                     \
    774         [=]() -> TypeFloatInst * { return mBuilder->MakeTypeFloat(BITS); },    \
    775         &mGlobalDefs);                                                         \
    776   }
    777     HANDLE_FLOAT_SIZE(16);
    778     HANDLE_FLOAT_SIZE(32);
    779     HANDLE_FLOAT_SIZE(64);
    780   default:
    781     Module::errs() << "unexpeced floating point type";
    782   }
    783 #undef HANDLE_FLOAT_SIZE
    784   return nullptr;
    785 }
    786 
    787 TypeVectorInst *GlobalSection::getVectorType(Instruction *componentType,
    788                                              int width) {
    789   // TODO: verify that componentType is basic numeric types
    790 
    791   return findOrCreate<TypeVectorInst>(
    792       [=](TypeVectorInst *vecTy) -> bool {
    793         return vecTy->mOperand1.mInstruction == componentType &&
    794                vecTy->mOperand2 == width;
    795       },
    796       [=]() -> TypeVectorInst * {
    797         return mBuilder->MakeTypeVector(componentType, width);
    798       },
    799       &mGlobalDefs);
    800 }
    801 
    802 TypePointerInst *GlobalSection::getPointerType(StorageClass storage,
    803                                                Instruction *pointeeType) {
    804   return findOrCreate<TypePointerInst>(
    805       [=](TypePointerInst *type) -> bool {
    806         return type->mOperand1 == storage &&
    807                type->mOperand2.mInstruction == pointeeType;
    808       },
    809       [=]() -> TypePointerInst * {
    810         return mBuilder->MakeTypePointer(storage, pointeeType);
    811       },
    812       &mGlobalDefs);
    813 }
    814 
    815 TypeRuntimeArrayInst *
    816 GlobalSection::getRuntimeArrayType(Instruction *elemType) {
    817   return findOrCreate<TypeRuntimeArrayInst>(
    818       [=](TypeRuntimeArrayInst * /*type*/) -> bool {
    819         // return type->mOperand1.mInstruction == elemType;
    820         return false;
    821       },
    822       [=]() -> TypeRuntimeArrayInst * {
    823         return mBuilder->MakeTypeRuntimeArray(elemType);
    824       },
    825       &mGlobalDefs);
    826 }
    827 
    828 TypeStructInst *GlobalSection::getStructType(Instruction *fieldType[],
    829                                              int numField) {
    830   TypeStructInst *structTy = mBuilder->MakeTypeStruct();
    831   for (int i = 0; i < numField; i++) {
    832     structTy->mOperand1.push_back(fieldType[i]);
    833   }
    834   mGlobalDefs.push_back(structTy);
    835   return structTy;
    836 }
    837 
    838 TypeFunctionInst *GlobalSection::getFunctionType(Instruction *retType,
    839                                                  Instruction *const argType[],
    840                                                  size_t numArg) {
    841   return findOrCreate<TypeFunctionInst>(
    842       [=](TypeFunctionInst *type) -> bool {
    843         if (type->mOperand1.mInstruction != retType ||
    844             type->mOperand2.size() != numArg) {
    845           return false;
    846         }
    847         for (size_t i = 0; i < numArg; i++) {
    848           if (type->mOperand2[i].mInstruction != argType[i]) {
    849             return false;
    850           }
    851         }
    852         return true;
    853       },
    854       [=]() -> TypeFunctionInst * {
    855         TypeFunctionInst *funcTy = mBuilder->MakeTypeFunction(retType);
    856         for (size_t i = 0; i < numArg; i++) {
    857           funcTy->mOperand2.push_back(argType[i]);
    858         }
    859         return funcTy;
    860       },
    861       &mGlobalDefs);
    862 }
    863 
    864 GlobalSection *GlobalSection::addStructType(TypeStructInst *structType) {
    865   mGlobalDefs.push_back(structType);
    866   return this;
    867 }
    868 
    869 GlobalSection *GlobalSection::addVariable(VariableInst *var) {
    870   mGlobalDefs.push_back(var);
    871   return this;
    872 }
    873 
    874 VariableInst *GlobalSection::getInvocationId() {
    875   if (mInvocationId) {
    876     return mInvocationId.get();
    877   }
    878 
    879   TypeIntInst *UIntTy = getIntType(32, false);
    880   TypeVectorInst *V3UIntTy = getVectorType(UIntTy, 3);
    881   TypePointerInst *V3UIntPtrTy = getPointerType(StorageClass::Input, V3UIntTy);
    882 
    883   VariableInst *InvocationId =
    884       mBuilder->MakeVariable(V3UIntPtrTy, StorageClass::Input);
    885   InvocationId->decorate(Decoration::BuiltIn)
    886       ->addExtraOperand(static_cast<uint32_t>(BuiltIn::GlobalInvocationId));
    887 
    888   mInvocationId.reset(InvocationId);
    889 
    890   return InvocationId;
    891 }
    892 
    893 VariableInst *GlobalSection::getNumWorkgroups() {
    894   if (mNumWorkgroups) {
    895     return mNumWorkgroups.get();
    896   }
    897 
    898   TypeIntInst *UIntTy = getIntType(32, false);
    899   TypeVectorInst *V3UIntTy = getVectorType(UIntTy, 3);
    900   TypePointerInst *V3UIntPtrTy = getPointerType(StorageClass::Input, V3UIntTy);
    901 
    902   VariableInst *GNum = mBuilder->MakeVariable(V3UIntPtrTy, StorageClass::Input);
    903   GNum->decorate(Decoration::BuiltIn)
    904       ->addExtraOperand(static_cast<uint32_t>(BuiltIn::NumWorkgroups));
    905 
    906   mNumWorkgroups.reset(GNum);
    907 
    908   return GNum;
    909 }
    910 
    911 bool FunctionDeclaration::DeserializeInternal(InputWordStream &IS) {
    912   if (!(mFunc = Deserialize<FunctionInst>(IS))) {
    913     return false;
    914   }
    915 
    916   DeserializeZeroOrMore<FunctionParameterInst>(IS, mParams);
    917 
    918   if (!(mFuncEnd = Deserialize<FunctionEndInst>(IS))) {
    919     return false;
    920   }
    921 
    922   return true;
    923 }
    924 
    925 template <> Instruction *Deserialize(InputWordStream &IS) {
    926   Instruction *inst;
    927 
    928   switch ((*IS) & 0xFFFF) {
    929 #define HANDLE_INSTRUCTION(OPCODE, INST_CLASS)                                 \
    930   case OPCODE:                                                                 \
    931     inst = Deserialize<INST_CLASS>(IS);                                        \
    932     break;
    933 #include "instruction_dispatches_generated.h"
    934 #undef HANDLE_INSTRUCTION
    935   default:
    936     Module::errs() << "unrecognized instruction";
    937     inst = nullptr;
    938   }
    939 
    940   return inst;
    941 }
    942 
    943 bool Block::DeserializeInternal(InputWordStream &IS) {
    944   Instruction *inst;
    945   while (((*IS) & 0xFFFF) != OpFunctionEnd &&
    946          (inst = Deserialize<Instruction>(IS))) {
    947     mInsts.push_back(inst);
    948     if (inst->getOpCode() == OpBranch ||
    949         inst->getOpCode() == OpBranchConditional ||
    950         inst->getOpCode() == OpSwitch || inst->getOpCode() == OpKill ||
    951         inst->getOpCode() == OpReturn || inst->getOpCode() == OpReturnValue ||
    952         inst->getOpCode() == OpUnreachable) {
    953       break;
    954     }
    955   }
    956   return !mInsts.empty();
    957 }
    958 
    959 FunctionDefinition::FunctionDefinition()
    960     : mParamsDeleter(mParams), mBlocksDeleter(mBlocks) {}
    961 
    962 FunctionDefinition::FunctionDefinition(Builder *builder, FunctionInst *func,
    963                                        FunctionEndInst *end)
    964     : Entity(builder), mFunc(func), mFuncEnd(end), mParamsDeleter(mParams),
    965       mBlocksDeleter(mBlocks) {}
    966 
    967 bool FunctionDefinition::DeserializeInternal(InputWordStream &IS) {
    968   mFunc.reset(Deserialize<FunctionInst>(IS));
    969   if (!mFunc) {
    970     return false;
    971   }
    972 
    973   DeserializeZeroOrMore<FunctionParameterInst>(IS, mParams);
    974   DeserializeZeroOrMore<Block>(IS, mBlocks);
    975 
    976   mFuncEnd.reset(Deserialize<FunctionEndInst>(IS));
    977   if (!mFuncEnd) {
    978     return false;
    979   }
    980 
    981   return true;
    982 }
    983 
    984 Instruction *FunctionDefinition::getReturnType() const {
    985   return mFunc->mResultType.mInstruction;
    986 }
    987 
    988 } // namespace spirit
    989 } // namespace android
    990