1 /* 2 * Copyright 2017, The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "module.h" 18 19 #include <set> 20 21 #include "builder.h" 22 #include "core_defs.h" 23 #include "instructions.h" 24 #include "types_generated.h" 25 #include "word_stream.h" 26 27 namespace android { 28 namespace spirit { 29 30 Module *Module::mInstance = nullptr; 31 32 Module *Module::getCurrentModule() { 33 if (mInstance == nullptr) { 34 return mInstance = new Module(); 35 } 36 return mInstance; 37 } 38 39 Module::Module() 40 : mNextId(1), mCapabilitiesDeleter(mCapabilities), 41 mExtensionsDeleter(mExtensions), mExtInstImportsDeleter(mExtInstImports), 42 mEntryPointInstsDeleter(mEntryPointInsts), 43 mExecutionModesDeleter(mExecutionModes), 44 mEntryPointsDeleter(mEntryPoints), 45 mFunctionDefinitionsDeleter(mFunctionDefinitions) { 46 mInstance = this; 47 } 48 49 Module::Module(Builder *b) 50 : Entity(b), mNextId(1), mCapabilitiesDeleter(mCapabilities), 51 mExtensionsDeleter(mExtensions), mExtInstImportsDeleter(mExtInstImports), 52 mEntryPointInstsDeleter(mEntryPointInsts), 53 mExecutionModesDeleter(mExecutionModes), 54 mEntryPointsDeleter(mEntryPoints), 55 mFunctionDefinitionsDeleter(mFunctionDefinitions) { 56 mInstance = this; 57 } 58 59 bool Module::resolveIds() { 60 auto &table = mIdTable; 61 62 std::unique_ptr<IVisitor> v0( 63 CreateInstructionVisitor([&table](Instruction *inst) { 64 if (inst->hasResult()) { 65 table.insert(std::make_pair(inst->getId(), inst)); 66 } 67 })); 68 v0->visit(this); 69 70 mNextId = mIdTable.rbegin()->first + 1; 71 72 int err = 0; 73 std::unique_ptr<IVisitor> v( 74 CreateInstructionVisitor([&table, &err](Instruction *inst) { 75 for (auto ref : inst->getAllIdRefs()) { 76 if (ref) { 77 auto it = table.find(ref->mId); 78 if (it != table.end()) { 79 ref->mInstruction = it->second; 80 } else { 81 std::cout << "Found no instruction for id " << ref->mId 82 << std::endl; 83 err++; 84 } 85 } 86 } 87 })); 88 v->visit(this); 89 return err == 0; 90 } 91 92 bool Module::DeserializeInternal(InputWordStream &IS) { 93 if (IS.empty()) { 94 return false; 95 } 96 97 IS >> &mMagicNumber; 98 if (mMagicNumber != 0x07230203) { 99 errs() << "Wrong Magic Number: " << mMagicNumber; 100 return false; 101 } 102 103 if (IS.empty()) { 104 return false; 105 } 106 107 IS >> &mVersion.mWord; 108 if (mVersion.mBytes[0] != 0 || mVersion.mBytes[3] != 0) { 109 return false; 110 } 111 112 if (IS.empty()) { 113 return false; 114 } 115 116 IS >> &mGeneratorMagicNumber >> &mBound >> &mReserved; 117 118 DeserializeZeroOrMore<CapabilityInst>(IS, mCapabilities); 119 DeserializeZeroOrMore<ExtensionInst>(IS, mExtensions); 120 DeserializeZeroOrMore<ExtInstImportInst>(IS, mExtInstImports); 121 122 mMemoryModel.reset(Deserialize<MemoryModelInst>(IS)); 123 if (!mMemoryModel) { 124 errs() << "Missing memory model specification.\n"; 125 return false; 126 } 127 128 DeserializeZeroOrMore<EntryPointDefinition>(IS, mEntryPoints); 129 DeserializeZeroOrMore<ExecutionModeInst>(IS, mExecutionModes); 130 for (auto entry : mEntryPoints) { 131 mEntryPointInsts.push_back(entry->getInstruction()); 132 for (auto mode : mExecutionModes) { 133 entry->applyExecutionMode(mode); 134 } 135 } 136 137 mDebugInfo.reset(Deserialize<DebugInfoSection>(IS)); 138 mAnnotations.reset(Deserialize<AnnotationSection>(IS)); 139 mGlobals.reset(Deserialize<GlobalSection>(IS)); 140 141 DeserializeZeroOrMore<FunctionDefinition>(IS, mFunctionDefinitions); 142 143 if (mFunctionDefinitions.empty()) { 144 errs() << "Missing function definitions.\n"; 145 for (int i = 0; i < 4; i++) { 146 uint32_t w; 147 IS >> &w; 148 std::cout << std::hex << w << " "; 149 } 150 std::cout << std::endl; 151 return false; 152 } 153 154 return true; 155 } 156 157 void Module::initialize() { 158 mMagicNumber = 0x07230203; 159 mVersion.mMajorMinor = {.mMinorNumber = 1, .mMajorNumber = 1}; 160 mGeneratorMagicNumber = 0x00070000; 161 mBound = 0; 162 mReserved = 0; 163 mAnnotations.reset(new AnnotationSection()); 164 } 165 166 void Module::SerializeHeader(OutputWordStream &OS) const { 167 OS << mMagicNumber; 168 OS << mVersion.mWord << mGeneratorMagicNumber; 169 if (mBound == 0) { 170 OS << mIdTable.end()->first + 1; 171 } else { 172 OS << std::max(mBound, mNextId); 173 } 174 OS << mReserved; 175 } 176 177 void Module::Serialize(OutputWordStream &OS) const { 178 SerializeHeader(OS); 179 Entity::Serialize(OS); 180 } 181 182 Module *Module::addCapability(Capability cap) { 183 mCapabilities.push_back(mBuilder->MakeCapability(cap)); 184 return this; 185 } 186 187 Module *Module::setMemoryModel(AddressingModel am, MemoryModel mm) { 188 mMemoryModel.reset(mBuilder->MakeMemoryModel(am, mm)); 189 return this; 190 } 191 192 Module *Module::addExtInstImport(const char *extName) { 193 ExtInstImportInst *extInst = mBuilder->MakeExtInstImport(extName); 194 mExtInstImports.push_back(extInst); 195 if (strcmp(extName, "GLSL.std.450") == 0) { 196 mGLExt = extInst; 197 } 198 return this; 199 } 200 201 Module *Module::addSource(SourceLanguage lang, int version) { 202 if (!mDebugInfo) { 203 mDebugInfo.reset(mBuilder->MakeDebugInfoSection()); 204 } 205 mDebugInfo->addSource(lang, version); 206 return this; 207 } 208 209 Module *Module::addSourceExtension(const char *ext) { 210 if (!mDebugInfo) { 211 mDebugInfo.reset(mBuilder->MakeDebugInfoSection()); 212 } 213 mDebugInfo->addSourceExtension(ext); 214 return this; 215 } 216 217 Module *Module::addString(const char *str) { 218 if (!mDebugInfo) { 219 mDebugInfo.reset(mBuilder->MakeDebugInfoSection()); 220 } 221 mDebugInfo->addString(str); 222 return this; 223 } 224 225 Module *Module::addEntryPoint(EntryPointDefinition *entry) { 226 mEntryPoints.push_back(entry); 227 auto newModes = entry->getExecutionModes(); 228 mExecutionModes.insert(mExecutionModes.end(), newModes.begin(), 229 newModes.end()); 230 return this; 231 } 232 233 const std::string Module::findStringOfPrefix(const char *prefix) const { 234 if (!mDebugInfo) { 235 return std::string(); 236 } 237 return mDebugInfo->findStringOfPrefix(prefix); 238 } 239 240 GlobalSection *Module::getGlobalSection() { 241 if (!mGlobals) { 242 mGlobals.reset(new GlobalSection()); 243 } 244 return mGlobals.get(); 245 } 246 247 ConstantInst *Module::getConstant(TypeIntInst *type, int32_t value) { 248 return getGlobalSection()->getConstant(type, value); 249 } 250 251 ConstantInst *Module::getConstant(TypeIntInst *type, uint32_t value) { 252 return getGlobalSection()->getConstant(type, value); 253 } 254 255 ConstantInst *Module::getConstant(TypeFloatInst *type, float value) { 256 return getGlobalSection()->getConstant(type, value); 257 } 258 259 ConstantCompositeInst *Module::getConstantComposite(TypeVectorInst *type, 260 ConstantInst *components[], 261 size_t width) { 262 return getGlobalSection()->getConstantComposite(type, components, width); 263 } 264 265 ConstantCompositeInst *Module::getConstantComposite(TypeVectorInst *type, 266 ConstantInst *comp0, 267 ConstantInst *comp1, 268 ConstantInst *comp2) { 269 // TODO: verify that component types are the same and consistent with the 270 // resulting vector type 271 ConstantInst *comps[] = {comp0, comp1, comp2}; 272 return getConstantComposite(type, comps, 3); 273 } 274 275 ConstantCompositeInst *Module::getConstantComposite(TypeVectorInst *type, 276 ConstantInst *comp0, 277 ConstantInst *comp1, 278 ConstantInst *comp2, 279 ConstantInst *comp3) { 280 // TODO: verify that component types are the same and consistent with the 281 // resulting vector type 282 ConstantInst *comps[] = {comp0, comp1, comp2, comp3}; 283 return getConstantComposite(type, comps, 4); 284 } 285 286 TypeVoidInst *Module::getVoidType() { 287 return getGlobalSection()->getVoidType(); 288 } 289 290 TypeIntInst *Module::getIntType(int bits, bool isSigned) { 291 return getGlobalSection()->getIntType(bits, isSigned); 292 } 293 294 TypeIntInst *Module::getUnsignedIntType(int bits) { 295 return getIntType(bits, false); 296 } 297 298 TypeFloatInst *Module::getFloatType(int bits) { 299 return getGlobalSection()->getFloatType(bits); 300 } 301 302 TypeVectorInst *Module::getVectorType(Instruction *componentType, int width) { 303 return getGlobalSection()->getVectorType(componentType, width); 304 } 305 306 TypePointerInst *Module::getPointerType(StorageClass storage, 307 Instruction *pointeeType) { 308 return getGlobalSection()->getPointerType(storage, pointeeType); 309 } 310 311 TypeRuntimeArrayInst *Module::getRuntimeArrayType(Instruction *elementType) { 312 return getGlobalSection()->getRuntimeArrayType(elementType); 313 } 314 315 TypeStructInst *Module::getStructType(Instruction *fieldType[], int numField) { 316 return getGlobalSection()->getStructType(fieldType, numField); 317 } 318 319 TypeStructInst *Module::getStructType(Instruction *fieldType) { 320 return getStructType(&fieldType, 1); 321 } 322 323 TypeFunctionInst *Module::getFunctionType(Instruction *retType, 324 Instruction *const argType[], 325 size_t numArg) { 326 return getGlobalSection()->getFunctionType(retType, argType, numArg); 327 } 328 329 TypeFunctionInst * 330 Module::getFunctionType(Instruction *retType, 331 const std::vector<Instruction *> &argTypes) { 332 return getGlobalSection()->getFunctionType(retType, argTypes.data(), 333 argTypes.size()); 334 } 335 336 size_t Module::getSize(TypeVoidInst *) { return 0; } 337 338 size_t Module::getSize(TypeIntInst *intTy) { return intTy->mOperand1 / 8; } 339 340 size_t Module::getSize(TypeFloatInst *fpTy) { return fpTy->mOperand1 / 8; } 341 342 size_t Module::getSize(TypeVectorInst *vTy) { 343 return getSize(vTy->mOperand1.mInstruction) * vTy->mOperand2; 344 } 345 346 size_t Module::getSize(TypePointerInst *) { 347 return 4; // TODO: or 8? 348 } 349 350 size_t Module::getSize(TypeStructInst *structTy) { 351 size_t sz = 0; 352 for (auto ty : structTy->mOperand1) { 353 sz += getSize(ty.mInstruction); 354 } 355 return sz; 356 } 357 358 size_t Module::getSize(TypeFunctionInst *) { 359 return 4; // TODO: or 8? Is this just the size of a pointer? 360 } 361 362 size_t Module::getSize(Instruction *inst) { 363 switch (inst->getOpCode()) { 364 case OpTypeVoid: 365 return getSize(static_cast<TypeVoidInst *>(inst)); 366 case OpTypeInt: 367 return getSize(static_cast<TypeIntInst *>(inst)); 368 case OpTypeFloat: 369 return getSize(static_cast<TypeFloatInst *>(inst)); 370 case OpTypeVector: 371 return getSize(static_cast<TypeVectorInst *>(inst)); 372 case OpTypeStruct: 373 return getSize(static_cast<TypeStructInst *>(inst)); 374 case OpTypeFunction: 375 return getSize(static_cast<TypeFunctionInst *>(inst)); 376 default: 377 return 0; 378 } 379 } 380 381 Module *Module::addFunctionDefinition(FunctionDefinition *func) { 382 mFunctionDefinitions.push_back(func); 383 return this; 384 } 385 386 Instruction *Module::lookupByName(const char *name) const { 387 return mDebugInfo->lookupByName(name); 388 } 389 390 FunctionDefinition * 391 Module::getFunctionDefinitionFromInstruction(FunctionInst *inst) const { 392 for (auto fdef : mFunctionDefinitions) { 393 if (fdef->getInstruction() == inst) { 394 return fdef; 395 } 396 } 397 return nullptr; 398 } 399 400 FunctionDefinition * 401 Module::lookupFunctionDefinitionByName(const char *name) const { 402 FunctionInst *inst = static_cast<FunctionInst *>(lookupByName(name)); 403 return getFunctionDefinitionFromInstruction(inst); 404 } 405 406 const char *Module::lookupNameByInstruction(const Instruction *inst) const { 407 return mDebugInfo->lookupNameByInstruction(inst); 408 } 409 410 VariableInst *Module::getInvocationId() { 411 return getGlobalSection()->getInvocationId(); 412 } 413 414 VariableInst *Module::getNumWorkgroups() { 415 return getGlobalSection()->getNumWorkgroups(); 416 } 417 418 Module *Module::addStructType(TypeStructInst *structType) { 419 getGlobalSection()->addStructType(structType); 420 return this; 421 } 422 423 Module *Module::addVariable(VariableInst *var) { 424 getGlobalSection()->addVariable(var); 425 return this; 426 } 427 428 void Module::consolidateAnnotations() { 429 std::vector<Instruction *> annotations(mAnnotations->begin(), 430 mAnnotations->end()); 431 std::unique_ptr<IVisitor> v( 432 CreateInstructionVisitor([&annotations](Instruction *inst) -> void { 433 const auto &ann = inst->getAnnotations(); 434 annotations.insert(annotations.end(), ann.begin(), ann.end()); 435 })); 436 v->visit(this); 437 mAnnotations->clear(); 438 mAnnotations->addAnnotations(annotations.begin(), annotations.end()); 439 } 440 441 EntryPointDefinition::EntryPointDefinition(Builder *builder, 442 ExecutionModel execModel, 443 FunctionDefinition *func, 444 const char *name) 445 : Entity(builder), mFunction(func->getInstruction()), 446 mExecutionModel(execModel) { 447 mName = strndup(name, strlen(name)); 448 mEntryPointInst = mBuilder->MakeEntryPoint(execModel, mFunction, mName); 449 (void)mExecutionModel; // suppress unused private field warning 450 } 451 452 bool EntryPointDefinition::DeserializeInternal(InputWordStream &IS) { 453 if (IS.empty()) { 454 return false; 455 } 456 457 if ((mEntryPointInst = Deserialize<EntryPointInst>(IS))) { 458 return true; 459 } 460 461 return false; 462 } 463 464 EntryPointDefinition * 465 EntryPointDefinition::applyExecutionMode(ExecutionModeInst *mode) { 466 if (mode->mOperand1.mInstruction == mFunction) { 467 addExecutionMode(mode); 468 } 469 return this; 470 } 471 472 EntryPointDefinition *EntryPointDefinition::addToInterface(VariableInst *var) { 473 mInterface.push_back(var); 474 mEntryPointInst->mOperand4.push_back(var); 475 return this; 476 } 477 478 EntryPointDefinition *EntryPointDefinition::setLocalSize(uint32_t width, 479 uint32_t height, 480 uint32_t depth) { 481 mLocalSize.mWidth = width; 482 mLocalSize.mHeight = height; 483 mLocalSize.mDepth = depth; 484 485 auto mode = mBuilder->MakeExecutionMode(mFunction, ExecutionMode::LocalSize); 486 mode->addExtraOperand(width)->addExtraOperand(height)->addExtraOperand(depth); 487 488 addExecutionMode(mode); 489 490 return this; 491 } 492 493 bool DebugInfoSection::DeserializeInternal(InputWordStream &IS) { 494 while (true) { 495 if (auto str = Deserialize<StringInst>(IS)) { 496 mSources.push_back(str); 497 } else if (auto src = Deserialize<SourceInst>(IS)) { 498 mSources.push_back(src); 499 } else if (auto srcExt = Deserialize<SourceExtensionInst>(IS)) { 500 mSources.push_back(srcExt); 501 } else if (auto srcCont = Deserialize<SourceContinuedInst>(IS)) { 502 mSources.push_back(srcCont); 503 } else { 504 break; 505 } 506 } 507 508 while (true) { 509 if (auto name = Deserialize<NameInst>(IS)) { 510 mNames.push_back(name); 511 } else if (auto memName = Deserialize<MemberNameInst>(IS)) { 512 mNames.push_back(memName); 513 } else { 514 break; 515 } 516 } 517 518 return true; 519 } 520 521 DebugInfoSection *DebugInfoSection::addSource(SourceLanguage lang, 522 int version) { 523 SourceInst *source = mBuilder->MakeSource(lang, version); 524 mSources.push_back(source); 525 return this; 526 } 527 528 DebugInfoSection *DebugInfoSection::addSourceExtension(const char *ext) { 529 SourceExtensionInst *inst = mBuilder->MakeSourceExtension(ext); 530 mSources.push_back(inst); 531 return this; 532 } 533 534 DebugInfoSection *DebugInfoSection::addString(const char *str) { 535 StringInst *source = mBuilder->MakeString(str); 536 mSources.push_back(source); 537 return this; 538 } 539 540 std::string DebugInfoSection::findStringOfPrefix(const char *prefix) { 541 auto it = std::find_if( 542 mSources.begin(), mSources.end(), [prefix](Instruction *inst) -> bool { 543 if (inst->getOpCode() != OpString) { 544 return false; 545 } 546 const StringInst *strInst = static_cast<const StringInst *>(inst); 547 const std::string &str = strInst->mOperand1; 548 return str.find(prefix) == 0; 549 }); 550 if (it == mSources.end()) { 551 return ""; 552 } 553 StringInst *strInst = static_cast<StringInst *>(*it); 554 return strInst->mOperand1; 555 } 556 557 Instruction *DebugInfoSection::lookupByName(const char *name) const { 558 for (auto inst : mNames) { 559 if (inst->getOpCode() == OpName) { 560 NameInst *nameInst = static_cast<NameInst *>(inst); 561 if (nameInst->mOperand2.compare(name) == 0) { 562 return nameInst->mOperand1.mInstruction; 563 } 564 } 565 // Ignore member names 566 } 567 return nullptr; 568 } 569 570 const char * 571 DebugInfoSection::lookupNameByInstruction(const Instruction *target) const { 572 for (auto inst : mNames) { 573 if (inst->getOpCode() == OpName) { 574 NameInst *nameInst = static_cast<NameInst *>(inst); 575 if (nameInst->mOperand1.mInstruction == target) { 576 return nameInst->mOperand2.c_str(); 577 } 578 } 579 // Ignore member names 580 } 581 return nullptr; 582 } 583 584 AnnotationSection::AnnotationSection() : mAnnotationsDeleter(mAnnotations) {} 585 586 AnnotationSection::AnnotationSection(Builder *b) 587 : Entity(b), mAnnotationsDeleter(mAnnotations) {} 588 589 bool AnnotationSection::DeserializeInternal(InputWordStream &IS) { 590 while (true) { 591 if (auto decor = Deserialize<DecorateInst>(IS)) { 592 mAnnotations.push_back(decor); 593 } else if (auto decor = Deserialize<MemberDecorateInst>(IS)) { 594 mAnnotations.push_back(decor); 595 } else if (auto decor = Deserialize<GroupDecorateInst>(IS)) { 596 mAnnotations.push_back(decor); 597 } else if (auto decor = Deserialize<GroupMemberDecorateInst>(IS)) { 598 mAnnotations.push_back(decor); 599 } else if (auto decor = Deserialize<DecorationGroupInst>(IS)) { 600 mAnnotations.push_back(decor); 601 } else { 602 break; 603 } 604 } 605 return true; 606 } 607 608 GlobalSection::GlobalSection() : mGlobalDefsDeleter(mGlobalDefs) {} 609 610 GlobalSection::GlobalSection(Builder *builder) 611 : Entity(builder), mGlobalDefsDeleter(mGlobalDefs) {} 612 613 namespace { 614 615 template <typename T> 616 T *findOrCreate(std::function<bool(T *)> criteria, std::function<T *()> factory, 617 std::vector<Instruction *> *globals) { 618 T *derived; 619 for (auto inst : *globals) { 620 if (inst->getOpCode() == T::mOpCode) { 621 T *derived = static_cast<T *>(inst); 622 if (criteria(derived)) { 623 return derived; 624 } 625 } 626 } 627 derived = factory(); 628 globals->push_back(derived); 629 return derived; 630 } 631 632 } // anonymous namespace 633 634 bool GlobalSection::DeserializeInternal(InputWordStream &IS) { 635 while (true) { 636 #define HANDLE_INSTRUCTION(OPCODE, INST_CLASS) \ 637 if (auto typeInst = Deserialize<INST_CLASS>(IS)) { \ 638 mGlobalDefs.push_back(typeInst); \ 639 continue; \ 640 } 641 #include "const_inst_dispatches_generated.h" 642 #include "type_inst_dispatches_generated.h" 643 #undef HANDLE_INSTRUCTION 644 645 if (auto globalInst = Deserialize<VariableInst>(IS)) { 646 // Check if this is function scoped 647 if (globalInst->mOperand1 == StorageClass::Function) { 648 Module::errs() << "warning: Variable (id = " << globalInst->mResult; 649 Module::errs() << ") has function scope in global section.\n"; 650 // Khronos LLVM-SPIRV convertor emits "Function" storage-class globals. 651 // As a workaround, accept such SPIR-V code here, and fix it up later 652 // in the rs2spirv compiler by correcting the storage class. 653 // In a stricter deserializer, such code should be rejected, and we 654 // should return false here. 655 } 656 mGlobalDefs.push_back(globalInst); 657 continue; 658 } 659 660 if (auto globalInst = Deserialize<UndefInst>(IS)) { 661 mGlobalDefs.push_back(globalInst); 662 continue; 663 } 664 break; 665 } 666 return true; 667 } 668 669 ConstantInst *GlobalSection::getConstant(TypeIntInst *type, int32_t value) { 670 return findOrCreate<ConstantInst>( 671 [=](ConstantInst *c) { return c->mOperand1.intValue == value; }, 672 [=]() -> ConstantInst * { 673 LiteralContextDependentNumber cdn = {.intValue = value}; 674 return mBuilder->MakeConstant(type, cdn); 675 }, 676 &mGlobalDefs); 677 } 678 679 ConstantInst *GlobalSection::getConstant(TypeIntInst *type, uint32_t value) { 680 return findOrCreate<ConstantInst>( 681 [=](ConstantInst *c) { return c->mOperand1.intValue == (int)value; }, 682 [=]() -> ConstantInst * { 683 LiteralContextDependentNumber cdn = {.intValue = (int)value}; 684 return mBuilder->MakeConstant(type, cdn); 685 }, 686 &mGlobalDefs); 687 } 688 689 ConstantInst *GlobalSection::getConstant(TypeFloatInst *type, float value) { 690 return findOrCreate<ConstantInst>( 691 [=](ConstantInst *c) { return c->mOperand1.floatValue == value; }, 692 [=]() -> ConstantInst * { 693 LiteralContextDependentNumber cdn = {.floatValue = value}; 694 return mBuilder->MakeConstant(type, cdn); 695 }, 696 &mGlobalDefs); 697 } 698 699 ConstantCompositeInst * 700 GlobalSection::getConstantComposite(TypeVectorInst *type, 701 ConstantInst *components[], size_t width) { 702 return findOrCreate<ConstantCompositeInst>( 703 [=](ConstantCompositeInst *c) { 704 if (c->mOperand1.size() != width) { 705 return false; 706 } 707 for (size_t i = 0; i < width; i++) { 708 if (c->mOperand1[i].mInstruction != components[i]) { 709 return false; 710 } 711 } 712 return true; 713 }, 714 [=]() -> ConstantCompositeInst * { 715 ConstantCompositeInst *c = mBuilder->MakeConstantComposite(type); 716 for (size_t i = 0; i < width; i++) { 717 c->mOperand1.push_back(components[i]); 718 } 719 return c; 720 }, 721 &mGlobalDefs); 722 } 723 724 TypeVoidInst *GlobalSection::getVoidType() { 725 return findOrCreate<TypeVoidInst>( 726 [=](TypeVoidInst *) -> bool { return true; }, 727 [=]() -> TypeVoidInst * { return mBuilder->MakeTypeVoid(); }, 728 &mGlobalDefs); 729 } 730 731 TypeIntInst *GlobalSection::getIntType(int bits, bool isSigned) { 732 if (isSigned) { 733 switch (bits) { 734 #define HANDLE_INT_SIZE(INT_TYPE, BITS, SIGNED) \ 735 case BITS: { \ 736 return findOrCreate<TypeIntInst>( \ 737 [=](TypeIntInst *intTy) -> bool { \ 738 return intTy->mOperand1 == BITS && intTy->mOperand2 == SIGNED; \ 739 }, \ 740 [=]() -> TypeIntInst * { \ 741 return mBuilder->MakeTypeInt(BITS, SIGNED); \ 742 }, \ 743 &mGlobalDefs); \ 744 } 745 HANDLE_INT_SIZE(Int, 8, 1); 746 HANDLE_INT_SIZE(Int, 16, 1); 747 HANDLE_INT_SIZE(Int, 32, 1); 748 HANDLE_INT_SIZE(Int, 64, 1); 749 default: 750 Module::errs() << "unexpected int type"; 751 } 752 } else { 753 switch (bits) { 754 HANDLE_INT_SIZE(UInt, 8, 0); 755 HANDLE_INT_SIZE(UInt, 16, 0); 756 HANDLE_INT_SIZE(UInt, 32, 0); 757 HANDLE_INT_SIZE(UInt, 64, 0); 758 default: 759 Module::errs() << "unexpected int type"; 760 } 761 } 762 #undef HANDLE_INT_SIZE 763 return nullptr; 764 } 765 766 TypeFloatInst *GlobalSection::getFloatType(int bits) { 767 switch (bits) { 768 #define HANDLE_FLOAT_SIZE(BITS) \ 769 case BITS: { \ 770 return findOrCreate<TypeFloatInst>( \ 771 [=](TypeFloatInst *floatTy) -> bool { \ 772 return floatTy->mOperand1 == BITS; \ 773 }, \ 774 [=]() -> TypeFloatInst * { return mBuilder->MakeTypeFloat(BITS); }, \ 775 &mGlobalDefs); \ 776 } 777 HANDLE_FLOAT_SIZE(16); 778 HANDLE_FLOAT_SIZE(32); 779 HANDLE_FLOAT_SIZE(64); 780 default: 781 Module::errs() << "unexpeced floating point type"; 782 } 783 #undef HANDLE_FLOAT_SIZE 784 return nullptr; 785 } 786 787 TypeVectorInst *GlobalSection::getVectorType(Instruction *componentType, 788 int width) { 789 // TODO: verify that componentType is basic numeric types 790 791 return findOrCreate<TypeVectorInst>( 792 [=](TypeVectorInst *vecTy) -> bool { 793 return vecTy->mOperand1.mInstruction == componentType && 794 vecTy->mOperand2 == width; 795 }, 796 [=]() -> TypeVectorInst * { 797 return mBuilder->MakeTypeVector(componentType, width); 798 }, 799 &mGlobalDefs); 800 } 801 802 TypePointerInst *GlobalSection::getPointerType(StorageClass storage, 803 Instruction *pointeeType) { 804 return findOrCreate<TypePointerInst>( 805 [=](TypePointerInst *type) -> bool { 806 return type->mOperand1 == storage && 807 type->mOperand2.mInstruction == pointeeType; 808 }, 809 [=]() -> TypePointerInst * { 810 return mBuilder->MakeTypePointer(storage, pointeeType); 811 }, 812 &mGlobalDefs); 813 } 814 815 TypeRuntimeArrayInst * 816 GlobalSection::getRuntimeArrayType(Instruction *elemType) { 817 return findOrCreate<TypeRuntimeArrayInst>( 818 [=](TypeRuntimeArrayInst * /*type*/) -> bool { 819 // return type->mOperand1.mInstruction == elemType; 820 return false; 821 }, 822 [=]() -> TypeRuntimeArrayInst * { 823 return mBuilder->MakeTypeRuntimeArray(elemType); 824 }, 825 &mGlobalDefs); 826 } 827 828 TypeStructInst *GlobalSection::getStructType(Instruction *fieldType[], 829 int numField) { 830 TypeStructInst *structTy = mBuilder->MakeTypeStruct(); 831 for (int i = 0; i < numField; i++) { 832 structTy->mOperand1.push_back(fieldType[i]); 833 } 834 mGlobalDefs.push_back(structTy); 835 return structTy; 836 } 837 838 TypeFunctionInst *GlobalSection::getFunctionType(Instruction *retType, 839 Instruction *const argType[], 840 size_t numArg) { 841 return findOrCreate<TypeFunctionInst>( 842 [=](TypeFunctionInst *type) -> bool { 843 if (type->mOperand1.mInstruction != retType || 844 type->mOperand2.size() != numArg) { 845 return false; 846 } 847 for (size_t i = 0; i < numArg; i++) { 848 if (type->mOperand2[i].mInstruction != argType[i]) { 849 return false; 850 } 851 } 852 return true; 853 }, 854 [=]() -> TypeFunctionInst * { 855 TypeFunctionInst *funcTy = mBuilder->MakeTypeFunction(retType); 856 for (size_t i = 0; i < numArg; i++) { 857 funcTy->mOperand2.push_back(argType[i]); 858 } 859 return funcTy; 860 }, 861 &mGlobalDefs); 862 } 863 864 GlobalSection *GlobalSection::addStructType(TypeStructInst *structType) { 865 mGlobalDefs.push_back(structType); 866 return this; 867 } 868 869 GlobalSection *GlobalSection::addVariable(VariableInst *var) { 870 mGlobalDefs.push_back(var); 871 return this; 872 } 873 874 VariableInst *GlobalSection::getInvocationId() { 875 if (mInvocationId) { 876 return mInvocationId.get(); 877 } 878 879 TypeIntInst *UIntTy = getIntType(32, false); 880 TypeVectorInst *V3UIntTy = getVectorType(UIntTy, 3); 881 TypePointerInst *V3UIntPtrTy = getPointerType(StorageClass::Input, V3UIntTy); 882 883 VariableInst *InvocationId = 884 mBuilder->MakeVariable(V3UIntPtrTy, StorageClass::Input); 885 InvocationId->decorate(Decoration::BuiltIn) 886 ->addExtraOperand(static_cast<uint32_t>(BuiltIn::GlobalInvocationId)); 887 888 mInvocationId.reset(InvocationId); 889 890 return InvocationId; 891 } 892 893 VariableInst *GlobalSection::getNumWorkgroups() { 894 if (mNumWorkgroups) { 895 return mNumWorkgroups.get(); 896 } 897 898 TypeIntInst *UIntTy = getIntType(32, false); 899 TypeVectorInst *V3UIntTy = getVectorType(UIntTy, 3); 900 TypePointerInst *V3UIntPtrTy = getPointerType(StorageClass::Input, V3UIntTy); 901 902 VariableInst *GNum = mBuilder->MakeVariable(V3UIntPtrTy, StorageClass::Input); 903 GNum->decorate(Decoration::BuiltIn) 904 ->addExtraOperand(static_cast<uint32_t>(BuiltIn::NumWorkgroups)); 905 906 mNumWorkgroups.reset(GNum); 907 908 return GNum; 909 } 910 911 bool FunctionDeclaration::DeserializeInternal(InputWordStream &IS) { 912 if (!(mFunc = Deserialize<FunctionInst>(IS))) { 913 return false; 914 } 915 916 DeserializeZeroOrMore<FunctionParameterInst>(IS, mParams); 917 918 if (!(mFuncEnd = Deserialize<FunctionEndInst>(IS))) { 919 return false; 920 } 921 922 return true; 923 } 924 925 template <> Instruction *Deserialize(InputWordStream &IS) { 926 Instruction *inst; 927 928 switch ((*IS) & 0xFFFF) { 929 #define HANDLE_INSTRUCTION(OPCODE, INST_CLASS) \ 930 case OPCODE: \ 931 inst = Deserialize<INST_CLASS>(IS); \ 932 break; 933 #include "instruction_dispatches_generated.h" 934 #undef HANDLE_INSTRUCTION 935 default: 936 Module::errs() << "unrecognized instruction"; 937 inst = nullptr; 938 } 939 940 return inst; 941 } 942 943 bool Block::DeserializeInternal(InputWordStream &IS) { 944 Instruction *inst; 945 while (((*IS) & 0xFFFF) != OpFunctionEnd && 946 (inst = Deserialize<Instruction>(IS))) { 947 mInsts.push_back(inst); 948 if (inst->getOpCode() == OpBranch || 949 inst->getOpCode() == OpBranchConditional || 950 inst->getOpCode() == OpSwitch || inst->getOpCode() == OpKill || 951 inst->getOpCode() == OpReturn || inst->getOpCode() == OpReturnValue || 952 inst->getOpCode() == OpUnreachable) { 953 break; 954 } 955 } 956 return !mInsts.empty(); 957 } 958 959 FunctionDefinition::FunctionDefinition() 960 : mParamsDeleter(mParams), mBlocksDeleter(mBlocks) {} 961 962 FunctionDefinition::FunctionDefinition(Builder *builder, FunctionInst *func, 963 FunctionEndInst *end) 964 : Entity(builder), mFunc(func), mFuncEnd(end), mParamsDeleter(mParams), 965 mBlocksDeleter(mBlocks) {} 966 967 bool FunctionDefinition::DeserializeInternal(InputWordStream &IS) { 968 mFunc.reset(Deserialize<FunctionInst>(IS)); 969 if (!mFunc) { 970 return false; 971 } 972 973 DeserializeZeroOrMore<FunctionParameterInst>(IS, mParams); 974 DeserializeZeroOrMore<Block>(IS, mBlocks); 975 976 mFuncEnd.reset(Deserialize<FunctionEndInst>(IS)); 977 if (!mFuncEnd) { 978 return false; 979 } 980 981 return true; 982 } 983 984 Instruction *FunctionDefinition::getReturnType() const { 985 return mFunc->mResultType.mInstruction; 986 } 987 988 } // namespace spirit 989 } // namespace android 990