1 /* 2 * Copyright 2017, The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "module.h" 18 19 #include <set> 20 21 #include "builder.h" 22 #include "core_defs.h" 23 #include "instructions.h" 24 #include "types_generated.h" 25 #include "word_stream.h" 26 27 namespace android { 28 namespace spirit { 29 30 Module *Module::mInstance = nullptr; 31 32 Module *Module::getCurrentModule() { 33 if (mInstance == nullptr) { 34 return mInstance = new Module(); 35 } 36 return mInstance; 37 } 38 39 Module::Module() 40 : mNextId(1), mCapabilitiesDeleter(mCapabilities), 41 mExtensionsDeleter(mExtensions), mExtInstImportsDeleter(mExtInstImports), 42 mEntryPointInstsDeleter(mEntryPointInsts), 43 mExecutionModesDeleter(mExecutionModes), 44 mEntryPointsDeleter(mEntryPoints), 45 mFunctionDefinitionsDeleter(mFunctionDefinitions) { 46 mInstance = this; 47 } 48 49 Module::Module(Builder *b) 50 : Entity(b), mNextId(1), mCapabilitiesDeleter(mCapabilities), 51 mExtensionsDeleter(mExtensions), mExtInstImportsDeleter(mExtInstImports), 52 mEntryPointInstsDeleter(mEntryPointInsts), 53 mExecutionModesDeleter(mExecutionModes), 54 mEntryPointsDeleter(mEntryPoints), 55 mFunctionDefinitionsDeleter(mFunctionDefinitions) { 56 mInstance = this; 57 } 58 59 bool Module::resolveIds() { 60 auto &table = mIdTable; 61 62 std::unique_ptr<IVisitor> v0( 63 CreateInstructionVisitor([&table](Instruction *inst) { 64 if (inst->hasResult()) { 65 table.insert(std::make_pair(inst->getId(), inst)); 66 } 67 })); 68 v0->visit(this); 69 70 mNextId = mIdTable.rbegin()->first + 1; 71 72 int err = 0; 73 std::unique_ptr<IVisitor> v( 74 CreateInstructionVisitor([&table, &err](Instruction *inst) { 75 for (auto ref : inst->getAllIdRefs()) { 76 if (ref) { 77 auto it = table.find(ref->mId); 78 if (it != table.end()) { 79 ref->mInstruction = it->second; 80 } else { 81 std::cout << "Found no instruction for id " << ref->mId 82 << std::endl; 83 err++; 84 } 85 } 86 } 87 })); 88 v->visit(this); 89 return err == 0; 90 } 91 92 bool Module::DeserializeInternal(InputWordStream &IS) { 93 if (IS.empty()) { 94 return false; 95 } 96 97 IS >> &mMagicNumber; 98 if (mMagicNumber != 0x07230203) { 99 errs() << "Wrong Magic Number: " << mMagicNumber; 100 return false; 101 } 102 103 if (IS.empty()) { 104 return false; 105 } 106 107 IS >> &mVersion.mWord; 108 if (mVersion.mBytes[0] != 0 || mVersion.mBytes[3] != 0) { 109 return false; 110 } 111 112 if (IS.empty()) { 113 return false; 114 } 115 116 IS >> &mGeneratorMagicNumber >> &mBound >> &mReserved; 117 118 DeserializeZeroOrMore<CapabilityInst>(IS, mCapabilities); 119 DeserializeZeroOrMore<ExtensionInst>(IS, mExtensions); 120 DeserializeZeroOrMore<ExtInstImportInst>(IS, mExtInstImports); 121 122 mMemoryModel.reset(Deserialize<MemoryModelInst>(IS)); 123 if (!mMemoryModel) { 124 errs() << "Missing memory model specification.\n"; 125 return false; 126 } 127 128 DeserializeZeroOrMore<EntryPointDefinition>(IS, mEntryPoints); 129 DeserializeZeroOrMore<ExecutionModeInst>(IS, mExecutionModes); 130 for (auto entry : mEntryPoints) { 131 mEntryPointInsts.push_back(entry->getInstruction()); 132 for (auto mode : mExecutionModes) { 133 entry->applyExecutionMode(mode); 134 } 135 } 136 137 mDebugInfo.reset(Deserialize<DebugInfoSection>(IS)); 138 mAnnotations.reset(Deserialize<AnnotationSection>(IS)); 139 mGlobals.reset(Deserialize<GlobalSection>(IS)); 140 141 DeserializeZeroOrMore<FunctionDefinition>(IS, mFunctionDefinitions); 142 143 if (mFunctionDefinitions.empty()) { 144 errs() << "Missing function definitions.\n"; 145 for (int i = 0; i < 4; i++) { 146 uint32_t w; 147 IS >> &w; 148 std::cout << std::hex << w << " "; 149 } 150 std::cout << std::endl; 151 return false; 152 } 153 154 return true; 155 } 156 157 void Module::initialize() { 158 mMagicNumber = 0x07230203; 159 mVersion.mMajorMinor = {.mMinorNumber = 1, .mMajorNumber = 1}; 160 mGeneratorMagicNumber = 0x00070000; 161 mBound = 0; 162 mReserved = 0; 163 mAnnotations.reset(new AnnotationSection()); 164 } 165 166 void Module::SerializeHeader(OutputWordStream &OS) const { 167 OS << mMagicNumber; 168 OS << mVersion.mWord << mGeneratorMagicNumber; 169 if (mBound == 0) { 170 OS << mIdTable.end()->first + 1; 171 } else { 172 OS << std::max(mBound, mNextId); 173 } 174 OS << mReserved; 175 } 176 177 void Module::Serialize(OutputWordStream &OS) const { 178 SerializeHeader(OS); 179 Entity::Serialize(OS); 180 } 181 182 Module *Module::addCapability(Capability cap) { 183 mCapabilities.push_back(mBuilder->MakeCapability(cap)); 184 return this; 185 } 186 187 Module *Module::setMemoryModel(AddressingModel am, MemoryModel mm) { 188 mMemoryModel.reset(mBuilder->MakeMemoryModel(am, mm)); 189 return this; 190 } 191 192 Module *Module::addExtInstImport(const char *extName) { 193 ExtInstImportInst *extInst = mBuilder->MakeExtInstImport(extName); 194 mExtInstImports.push_back(extInst); 195 if (strcmp(extName, "GLSL.std.450") == 0) { 196 mGLExt = extInst; 197 } 198 return this; 199 } 200 201 Module *Module::addSource(SourceLanguage lang, int version) { 202 if (!mDebugInfo) { 203 mDebugInfo.reset(mBuilder->MakeDebugInfoSection()); 204 } 205 mDebugInfo->addSource(lang, version); 206 return this; 207 } 208 209 Module *Module::addSourceExtension(const char *ext) { 210 if (!mDebugInfo) { 211 mDebugInfo.reset(mBuilder->MakeDebugInfoSection()); 212 } 213 mDebugInfo->addSourceExtension(ext); 214 return this; 215 } 216 217 Module *Module::addString(const char *str) { 218 if (!mDebugInfo) { 219 mDebugInfo.reset(mBuilder->MakeDebugInfoSection()); 220 } 221 mDebugInfo->addString(str); 222 return this; 223 } 224 225 Module *Module::addEntryPoint(EntryPointDefinition *entry) { 226 mEntryPoints.push_back(entry); 227 auto newModes = entry->getExecutionModes(); 228 mExecutionModes.insert(mExecutionModes.end(), newModes.begin(), 229 newModes.end()); 230 return this; 231 } 232 233 GlobalSection *Module::getGlobalSection() { 234 if (!mGlobals) { 235 mGlobals.reset(new GlobalSection()); 236 } 237 return mGlobals.get(); 238 } 239 240 ConstantInst *Module::getConstant(TypeIntInst *type, int32_t value) { 241 return getGlobalSection()->getConstant(type, value); 242 } 243 244 ConstantInst *Module::getConstant(TypeIntInst *type, uint32_t value) { 245 return getGlobalSection()->getConstant(type, value); 246 } 247 248 ConstantInst *Module::getConstant(TypeFloatInst *type, float value) { 249 return getGlobalSection()->getConstant(type, value); 250 } 251 252 ConstantCompositeInst *Module::getConstantComposite(TypeVectorInst *type, 253 ConstantInst *components[], 254 size_t width) { 255 return getGlobalSection()->getConstantComposite(type, components, width); 256 } 257 258 ConstantCompositeInst *Module::getConstantComposite(TypeVectorInst *type, 259 ConstantInst *comp0, 260 ConstantInst *comp1, 261 ConstantInst *comp2) { 262 // TODO: verify that component types are the same and consistent with the 263 // resulting vector type 264 ConstantInst *comps[] = {comp0, comp1, comp2}; 265 return getConstantComposite(type, comps, 3); 266 } 267 268 ConstantCompositeInst *Module::getConstantComposite(TypeVectorInst *type, 269 ConstantInst *comp0, 270 ConstantInst *comp1, 271 ConstantInst *comp2, 272 ConstantInst *comp3) { 273 // TODO: verify that component types are the same and consistent with the 274 // resulting vector type 275 ConstantInst *comps[] = {comp0, comp1, comp2, comp3}; 276 return getConstantComposite(type, comps, 4); 277 } 278 279 TypeVoidInst *Module::getVoidType() { 280 return getGlobalSection()->getVoidType(); 281 } 282 283 TypeIntInst *Module::getIntType(int bits, bool isSigned) { 284 return getGlobalSection()->getIntType(bits, isSigned); 285 } 286 287 TypeIntInst *Module::getUnsignedIntType(int bits) { 288 return getIntType(bits, false); 289 } 290 291 TypeFloatInst *Module::getFloatType(int bits) { 292 return getGlobalSection()->getFloatType(bits); 293 } 294 295 TypeVectorInst *Module::getVectorType(Instruction *componentType, int width) { 296 return getGlobalSection()->getVectorType(componentType, width); 297 } 298 299 TypePointerInst *Module::getPointerType(StorageClass storage, 300 Instruction *pointeeType) { 301 return getGlobalSection()->getPointerType(storage, pointeeType); 302 } 303 304 TypeRuntimeArrayInst *Module::getRuntimeArrayType(Instruction *elementType) { 305 return getGlobalSection()->getRuntimeArrayType(elementType); 306 } 307 308 TypeStructInst *Module::getStructType(Instruction *fieldType[], int numField) { 309 return getGlobalSection()->getStructType(fieldType, numField); 310 } 311 312 TypeStructInst *Module::getStructType(Instruction *fieldType) { 313 return getStructType(&fieldType, 1); 314 } 315 316 TypeFunctionInst *Module::getFunctionType(Instruction *retType, 317 Instruction *const argType[], 318 size_t numArg) { 319 return getGlobalSection()->getFunctionType(retType, argType, numArg); 320 } 321 322 TypeFunctionInst * 323 Module::getFunctionType(Instruction *retType, 324 const std::vector<Instruction *> &argTypes) { 325 return getGlobalSection()->getFunctionType(retType, argTypes.data(), 326 argTypes.size()); 327 } 328 329 size_t Module::getSize(TypeVoidInst *) { return 0; } 330 331 size_t Module::getSize(TypeIntInst *intTy) { return intTy->mOperand1 / 8; } 332 333 size_t Module::getSize(TypeFloatInst *fpTy) { return fpTy->mOperand1 / 8; } 334 335 size_t Module::getSize(TypeVectorInst *vTy) { 336 return getSize(vTy->mOperand1.mInstruction) * vTy->mOperand2; 337 } 338 339 size_t Module::getSize(TypePointerInst *) { 340 return 4; // TODO: or 8? 341 } 342 343 size_t Module::getSize(TypeStructInst *structTy) { 344 size_t sz = 0; 345 for (auto ty : structTy->mOperand1) { 346 sz += getSize(ty.mInstruction); 347 } 348 return sz; 349 } 350 351 size_t Module::getSize(TypeFunctionInst *) { 352 return 4; // TODO: or 8? Is this just the size of a pointer? 353 } 354 355 size_t Module::getSize(Instruction *inst) { 356 switch (inst->getOpCode()) { 357 case OpTypeVoid: 358 return getSize(static_cast<TypeVoidInst *>(inst)); 359 case OpTypeInt: 360 return getSize(static_cast<TypeIntInst *>(inst)); 361 case OpTypeFloat: 362 return getSize(static_cast<TypeFloatInst *>(inst)); 363 case OpTypeVector: 364 return getSize(static_cast<TypeVectorInst *>(inst)); 365 case OpTypeStruct: 366 return getSize(static_cast<TypeStructInst *>(inst)); 367 case OpTypeFunction: 368 return getSize(static_cast<TypeFunctionInst *>(inst)); 369 default: 370 return 0; 371 } 372 } 373 374 Module *Module::addFunctionDefinition(FunctionDefinition *func) { 375 mFunctionDefinitions.push_back(func); 376 return this; 377 } 378 379 Instruction *Module::lookupByName(const char *name) const { 380 return mDebugInfo->lookupByName(name); 381 } 382 383 FunctionDefinition * 384 Module::getFunctionDefinitionFromInstruction(FunctionInst *inst) const { 385 for (auto fdef : mFunctionDefinitions) { 386 if (fdef->getInstruction() == inst) { 387 return fdef; 388 } 389 } 390 return nullptr; 391 } 392 393 FunctionDefinition * 394 Module::lookupFunctionDefinitionByName(const char *name) const { 395 FunctionInst *inst = static_cast<FunctionInst *>(lookupByName(name)); 396 return getFunctionDefinitionFromInstruction(inst); 397 } 398 399 const char *Module::lookupNameByInstruction(const Instruction *inst) const { 400 return mDebugInfo->lookupNameByInstruction(inst); 401 } 402 403 VariableInst *Module::getInvocationId() { 404 return getGlobalSection()->getInvocationId(); 405 } 406 407 VariableInst *Module::getNumWorkgroups() { 408 return getGlobalSection()->getNumWorkgroups(); 409 } 410 411 Module *Module::addStructType(TypeStructInst *structType) { 412 getGlobalSection()->addStructType(structType); 413 return this; 414 } 415 416 Module *Module::addVariable(VariableInst *var) { 417 getGlobalSection()->addVariable(var); 418 return this; 419 } 420 421 void Module::consolidateAnnotations() { 422 std::vector<Instruction *> annotations(mAnnotations->begin(), 423 mAnnotations->end()); 424 std::unique_ptr<IVisitor> v( 425 CreateInstructionVisitor([&annotations](Instruction *inst) -> void { 426 const auto &ann = inst->getAnnotations(); 427 annotations.insert(annotations.end(), ann.begin(), ann.end()); 428 })); 429 v->visit(this); 430 mAnnotations->clear(); 431 mAnnotations->addAnnotations(annotations.begin(), annotations.end()); 432 } 433 434 EntryPointDefinition::EntryPointDefinition(Builder *builder, 435 ExecutionModel execModel, 436 FunctionDefinition *func, 437 const char *name) 438 : Entity(builder), mFunction(func->getInstruction()), 439 mExecutionModel(execModel) { 440 mName = strndup(name, strlen(name)); 441 mEntryPointInst = mBuilder->MakeEntryPoint(execModel, mFunction, mName); 442 } 443 444 bool EntryPointDefinition::DeserializeInternal(InputWordStream &IS) { 445 if (IS.empty()) { 446 return false; 447 } 448 449 if ((mEntryPointInst = Deserialize<EntryPointInst>(IS))) { 450 return true; 451 } 452 453 return false; 454 } 455 456 EntryPointDefinition * 457 EntryPointDefinition::applyExecutionMode(ExecutionModeInst *mode) { 458 if (mode->mOperand1.mInstruction == mFunction) { 459 addExecutionMode(mode); 460 } 461 return this; 462 } 463 464 EntryPointDefinition *EntryPointDefinition::addToInterface(VariableInst *var) { 465 mInterface.push_back(var); 466 mEntryPointInst->mOperand4.push_back(var); 467 return this; 468 } 469 470 EntryPointDefinition *EntryPointDefinition::setLocalSize(uint32_t width, 471 uint32_t height, 472 uint32_t depth) { 473 mLocalSize.mWidth = width; 474 mLocalSize.mHeight = height; 475 mLocalSize.mDepth = depth; 476 477 auto mode = mBuilder->MakeExecutionMode(mFunction, ExecutionMode::LocalSize); 478 mode->addExtraOperand(width)->addExtraOperand(height)->addExtraOperand(depth); 479 480 addExecutionMode(mode); 481 482 return this; 483 } 484 485 bool DebugInfoSection::DeserializeInternal(InputWordStream &IS) { 486 while (true) { 487 if (auto str = Deserialize<StringInst>(IS)) { 488 mSources.push_back(str); 489 } else if (auto src = Deserialize<SourceInst>(IS)) { 490 mSources.push_back(src); 491 } else if (auto srcExt = Deserialize<SourceExtensionInst>(IS)) { 492 mSources.push_back(srcExt); 493 } else if (auto srcCont = Deserialize<SourceContinuedInst>(IS)) { 494 mSources.push_back(srcCont); 495 } else { 496 break; 497 } 498 } 499 500 while (true) { 501 if (auto name = Deserialize<NameInst>(IS)) { 502 mNames.push_back(name); 503 } else if (auto memName = Deserialize<MemberNameInst>(IS)) { 504 mNames.push_back(memName); 505 } else { 506 break; 507 } 508 } 509 510 return true; 511 } 512 513 DebugInfoSection *DebugInfoSection::addSource(SourceLanguage lang, 514 int version) { 515 SourceInst *source = mBuilder->MakeSource(lang, version); 516 mSources.push_back(source); 517 return this; 518 } 519 520 DebugInfoSection *DebugInfoSection::addSourceExtension(const char *ext) { 521 SourceExtensionInst *inst = mBuilder->MakeSourceExtension(ext); 522 mSources.push_back(inst); 523 return this; 524 } 525 526 DebugInfoSection *DebugInfoSection::addString(const char *str) { 527 StringInst *source = mBuilder->MakeString(str); 528 mSources.push_back(source); 529 return this; 530 } 531 532 Instruction *DebugInfoSection::lookupByName(const char *name) const { 533 for (auto inst : mNames) { 534 if (inst->getOpCode() == OpName) { 535 NameInst *nameInst = static_cast<NameInst *>(inst); 536 if (nameInst->mOperand2.compare(name) == 0) { 537 return nameInst->mOperand1.mInstruction; 538 } 539 } 540 // Ignore member names 541 } 542 return nullptr; 543 } 544 545 const char * 546 DebugInfoSection::lookupNameByInstruction(const Instruction *target) const { 547 for (auto inst : mNames) { 548 if (inst->getOpCode() == OpName) { 549 NameInst *nameInst = static_cast<NameInst *>(inst); 550 if (nameInst->mOperand1.mInstruction == target) { 551 return nameInst->mOperand2.c_str(); 552 } 553 } 554 // Ignore member names 555 } 556 return nullptr; 557 } 558 559 AnnotationSection::AnnotationSection() : mAnnotationsDeleter(mAnnotations) {} 560 561 AnnotationSection::AnnotationSection(Builder *b) 562 : Entity(b), mAnnotationsDeleter(mAnnotations) {} 563 564 bool AnnotationSection::DeserializeInternal(InputWordStream &IS) { 565 while (true) { 566 if (auto decor = Deserialize<DecorateInst>(IS)) { 567 mAnnotations.push_back(decor); 568 } else if (auto decor = Deserialize<MemberDecorateInst>(IS)) { 569 mAnnotations.push_back(decor); 570 } else if (auto decor = Deserialize<GroupDecorateInst>(IS)) { 571 mAnnotations.push_back(decor); 572 } else if (auto decor = Deserialize<GroupMemberDecorateInst>(IS)) { 573 mAnnotations.push_back(decor); 574 } else if (auto decor = Deserialize<DecorationGroupInst>(IS)) { 575 mAnnotations.push_back(decor); 576 } else { 577 break; 578 } 579 } 580 return true; 581 } 582 583 GlobalSection::GlobalSection() : mGlobalDefsDeleter(mGlobalDefs) {} 584 585 GlobalSection::GlobalSection(Builder *builder) 586 : Entity(builder), mGlobalDefsDeleter(mGlobalDefs) {} 587 588 namespace { 589 590 template <typename T> 591 T *findOrCreate(std::function<bool(T *)> criteria, std::function<T *()> factory, 592 std::vector<Instruction *> *globals) { 593 T *derived; 594 for (auto inst : *globals) { 595 if (inst->getOpCode() == T::mOpCode) { 596 T *derived = static_cast<T *>(inst); 597 if (criteria(derived)) { 598 return derived; 599 } 600 } 601 } 602 derived = factory(); 603 globals->push_back(derived); 604 return derived; 605 } 606 607 } // anonymous namespace 608 609 bool GlobalSection::DeserializeInternal(InputWordStream &IS) { 610 while (true) { 611 #define HANDLE_INSTRUCTION(OPCODE, INST_CLASS) \ 612 if (auto typeInst = Deserialize<INST_CLASS>(IS)) { \ 613 mGlobalDefs.push_back(typeInst); \ 614 continue; \ 615 } 616 #include "const_inst_dispatches_generated.h" 617 #include "type_inst_dispatches_generated.h" 618 #undef HANDLE_INSTRUCTION 619 620 if (auto globalInst = Deserialize<VariableInst>(IS)) { 621 // Check if this is function scoped 622 if (globalInst->mOperand1 == StorageClass::Function) { 623 Module::errs() << "warning: Variable (id = " << globalInst->mResult; 624 Module::errs() << ") has function scope in global section.\n"; 625 // Khronos LLVM-SPIRV convertor emits "Function" storage-class globals. 626 // As a workaround, accept such SPIR-V code here, and fix it up later 627 // in the rs2spirv compiler by correcting the storage class. 628 // In a stricter deserializer, such code should be rejected, and we 629 // should return false here. 630 } 631 mGlobalDefs.push_back(globalInst); 632 continue; 633 } 634 635 if (auto globalInst = Deserialize<UndefInst>(IS)) { 636 mGlobalDefs.push_back(globalInst); 637 continue; 638 } 639 break; 640 } 641 return true; 642 } 643 644 ConstantInst *GlobalSection::getConstant(TypeIntInst *type, int32_t value) { 645 return findOrCreate<ConstantInst>( 646 [=](ConstantInst *c) { return c->mOperand1.intValue == value; }, 647 [=]() -> ConstantInst * { 648 LiteralContextDependentNumber cdn = {.intValue = value}; 649 return mBuilder->MakeConstant(type, cdn); 650 }, 651 &mGlobalDefs); 652 } 653 654 ConstantInst *GlobalSection::getConstant(TypeIntInst *type, uint32_t value) { 655 return findOrCreate<ConstantInst>( 656 [=](ConstantInst *c) { return c->mOperand1.intValue == (int)value; }, 657 [=]() -> ConstantInst * { 658 LiteralContextDependentNumber cdn = {.intValue = (int)value}; 659 return mBuilder->MakeConstant(type, cdn); 660 }, 661 &mGlobalDefs); 662 } 663 664 ConstantInst *GlobalSection::getConstant(TypeFloatInst *type, float value) { 665 return findOrCreate<ConstantInst>( 666 [=](ConstantInst *c) { return c->mOperand1.floatValue == value; }, 667 [=]() -> ConstantInst * { 668 LiteralContextDependentNumber cdn = {.floatValue = value}; 669 return mBuilder->MakeConstant(type, cdn); 670 }, 671 &mGlobalDefs); 672 } 673 674 ConstantCompositeInst * 675 GlobalSection::getConstantComposite(TypeVectorInst *type, 676 ConstantInst *components[], size_t width) { 677 return findOrCreate<ConstantCompositeInst>( 678 [=](ConstantCompositeInst *c) { 679 if (c->mOperand1.size() != width) { 680 return false; 681 } 682 for (size_t i = 0; i < width; i++) { 683 if (c->mOperand1[i].mInstruction != components[i]) { 684 return false; 685 } 686 } 687 return true; 688 }, 689 [=]() -> ConstantCompositeInst * { 690 ConstantCompositeInst *c = mBuilder->MakeConstantComposite(type); 691 for (size_t i = 0; i < width; i++) { 692 c->mOperand1.push_back(components[i]); 693 } 694 return c; 695 }, 696 &mGlobalDefs); 697 } 698 699 TypeVoidInst *GlobalSection::getVoidType() { 700 return findOrCreate<TypeVoidInst>( 701 [=](TypeVoidInst *) -> bool { return true; }, 702 [=]() -> TypeVoidInst * { return mBuilder->MakeTypeVoid(); }, 703 &mGlobalDefs); 704 } 705 706 TypeIntInst *GlobalSection::getIntType(int bits, bool isSigned) { 707 if (isSigned) { 708 switch (bits) { 709 #define HANDLE_INT_SIZE(INT_TYPE, BITS, SIGNED) \ 710 case BITS: { \ 711 return findOrCreate<TypeIntInst>( \ 712 [=](TypeIntInst *intTy) -> bool { \ 713 return intTy->mOperand1 == BITS && intTy->mOperand2 == SIGNED; \ 714 }, \ 715 [=]() -> TypeIntInst * { \ 716 return mBuilder->MakeTypeInt(BITS, SIGNED); \ 717 }, \ 718 &mGlobalDefs); \ 719 } 720 HANDLE_INT_SIZE(Int, 8, 1); 721 HANDLE_INT_SIZE(Int, 16, 1); 722 HANDLE_INT_SIZE(Int, 32, 1); 723 HANDLE_INT_SIZE(Int, 64, 1); 724 default: 725 Module::errs() << "unexpected int type"; 726 } 727 } else { 728 switch (bits) { 729 HANDLE_INT_SIZE(UInt, 8, 0); 730 HANDLE_INT_SIZE(UInt, 16, 0); 731 HANDLE_INT_SIZE(UInt, 32, 0); 732 HANDLE_INT_SIZE(UInt, 64, 0); 733 default: 734 Module::errs() << "unexpected int type"; 735 } 736 } 737 #undef HANDLE_INT_SIZE 738 return nullptr; 739 } 740 741 TypeFloatInst *GlobalSection::getFloatType(int bits) { 742 switch (bits) { 743 #define HANDLE_FLOAT_SIZE(BITS) \ 744 case BITS: { \ 745 return findOrCreate<TypeFloatInst>( \ 746 [=](TypeFloatInst *floatTy) -> bool { \ 747 return floatTy->mOperand1 == BITS; \ 748 }, \ 749 [=]() -> TypeFloatInst * { return mBuilder->MakeTypeFloat(BITS); }, \ 750 &mGlobalDefs); \ 751 } 752 HANDLE_FLOAT_SIZE(16); 753 HANDLE_FLOAT_SIZE(32); 754 HANDLE_FLOAT_SIZE(64); 755 default: 756 Module::errs() << "unexpeced floating point type"; 757 } 758 #undef HANDLE_FLOAT_SIZE 759 return nullptr; 760 } 761 762 TypeVectorInst *GlobalSection::getVectorType(Instruction *componentType, 763 int width) { 764 // TODO: verify that componentType is basic numeric types 765 766 return findOrCreate<TypeVectorInst>( 767 [=](TypeVectorInst *vecTy) -> bool { 768 return vecTy->mOperand1.mInstruction == componentType && 769 vecTy->mOperand2 == width; 770 }, 771 [=]() -> TypeVectorInst * { 772 return mBuilder->MakeTypeVector(componentType, width); 773 }, 774 &mGlobalDefs); 775 } 776 777 TypePointerInst *GlobalSection::getPointerType(StorageClass storage, 778 Instruction *pointeeType) { 779 return findOrCreate<TypePointerInst>( 780 [=](TypePointerInst *type) -> bool { 781 return type->mOperand1 == storage && 782 type->mOperand2.mInstruction == pointeeType; 783 }, 784 [=]() -> TypePointerInst * { 785 return mBuilder->MakeTypePointer(storage, pointeeType); 786 }, 787 &mGlobalDefs); 788 } 789 790 TypeRuntimeArrayInst * 791 GlobalSection::getRuntimeArrayType(Instruction *elemType) { 792 return findOrCreate<TypeRuntimeArrayInst>( 793 [=](TypeRuntimeArrayInst * /*type*/) -> bool { 794 // return type->mOperand1.mInstruction == elemType; 795 return false; 796 }, 797 [=]() -> TypeRuntimeArrayInst * { 798 return mBuilder->MakeTypeRuntimeArray(elemType); 799 }, 800 &mGlobalDefs); 801 } 802 803 TypeStructInst *GlobalSection::getStructType(Instruction *fieldType[], 804 int numField) { 805 TypeStructInst *structTy = mBuilder->MakeTypeStruct(); 806 for (int i = 0; i < numField; i++) { 807 structTy->mOperand1.push_back(fieldType[i]); 808 } 809 mGlobalDefs.push_back(structTy); 810 return structTy; 811 } 812 813 TypeFunctionInst *GlobalSection::getFunctionType(Instruction *retType, 814 Instruction *const argType[], 815 size_t numArg) { 816 return findOrCreate<TypeFunctionInst>( 817 [=](TypeFunctionInst *type) -> bool { 818 if (type->mOperand1.mInstruction != retType || 819 type->mOperand2.size() != numArg) { 820 return false; 821 } 822 for (size_t i = 0; i < numArg; i++) { 823 if (type->mOperand2[i].mInstruction != argType[i]) { 824 return false; 825 } 826 } 827 return true; 828 }, 829 [=]() -> TypeFunctionInst * { 830 TypeFunctionInst *funcTy = mBuilder->MakeTypeFunction(retType); 831 for (size_t i = 0; i < numArg; i++) { 832 funcTy->mOperand2.push_back(argType[i]); 833 } 834 return funcTy; 835 }, 836 &mGlobalDefs); 837 } 838 839 GlobalSection *GlobalSection::addStructType(TypeStructInst *structType) { 840 mGlobalDefs.push_back(structType); 841 return this; 842 } 843 844 GlobalSection *GlobalSection::addVariable(VariableInst *var) { 845 mGlobalDefs.push_back(var); 846 return this; 847 } 848 849 VariableInst *GlobalSection::getInvocationId() { 850 if (mInvocationId) { 851 return mInvocationId.get(); 852 } 853 854 TypeIntInst *UIntTy = getIntType(32, false); 855 TypeVectorInst *V3UIntTy = getVectorType(UIntTy, 3); 856 TypePointerInst *V3UIntPtrTy = getPointerType(StorageClass::Input, V3UIntTy); 857 858 VariableInst *InvocationId = 859 mBuilder->MakeVariable(V3UIntPtrTy, StorageClass::Input); 860 InvocationId->decorate(Decoration::BuiltIn) 861 ->addExtraOperand(static_cast<uint32_t>(BuiltIn::GlobalInvocationId)); 862 863 mInvocationId.reset(InvocationId); 864 865 return InvocationId; 866 } 867 868 VariableInst *GlobalSection::getNumWorkgroups() { 869 if (mNumWorkgroups) { 870 return mNumWorkgroups.get(); 871 } 872 873 TypeIntInst *UIntTy = getIntType(32, false); 874 TypeVectorInst *V3UIntTy = getVectorType(UIntTy, 3); 875 TypePointerInst *V3UIntPtrTy = getPointerType(StorageClass::Input, V3UIntTy); 876 877 VariableInst *GNum = mBuilder->MakeVariable(V3UIntPtrTy, StorageClass::Input); 878 GNum->decorate(Decoration::BuiltIn) 879 ->addExtraOperand(static_cast<uint32_t>(BuiltIn::NumWorkgroups)); 880 881 mNumWorkgroups.reset(GNum); 882 883 return GNum; 884 } 885 886 bool FunctionDeclaration::DeserializeInternal(InputWordStream &IS) { 887 if (!Deserialize<FunctionInst>(IS)) { 888 return false; 889 } 890 891 DeserializeZeroOrMore<FunctionParameterInst>(IS, mParams); 892 893 if (!Deserialize<FunctionEndInst>(IS)) { 894 return false; 895 } 896 897 return true; 898 } 899 900 template <> Instruction *Deserialize(InputWordStream &IS) { 901 Instruction *inst; 902 903 switch ((*IS) & 0xFFFF) { 904 #define HANDLE_INSTRUCTION(OPCODE, INST_CLASS) \ 905 case OPCODE: \ 906 inst = Deserialize<INST_CLASS>(IS); \ 907 break; 908 #include "instruction_dispatches_generated.h" 909 #undef HANDLE_INSTRUCTION 910 default: 911 Module::errs() << "unrecognized instruction"; 912 inst = nullptr; 913 } 914 915 return inst; 916 } 917 918 bool Block::DeserializeInternal(InputWordStream &IS) { 919 Instruction *inst; 920 while (((*IS) & 0xFFFF) != OpFunctionEnd && 921 (inst = Deserialize<Instruction>(IS))) { 922 mInsts.push_back(inst); 923 if (inst->getOpCode() == OpBranch || 924 inst->getOpCode() == OpBranchConditional || 925 inst->getOpCode() == OpSwitch || inst->getOpCode() == OpKill || 926 inst->getOpCode() == OpReturn || inst->getOpCode() == OpReturnValue || 927 inst->getOpCode() == OpUnreachable) { 928 break; 929 } 930 } 931 return !mInsts.empty(); 932 } 933 934 FunctionDefinition::FunctionDefinition() 935 : mParamsDeleter(mParams), mBlocksDeleter(mBlocks) {} 936 937 FunctionDefinition::FunctionDefinition(Builder *builder, FunctionInst *func, 938 FunctionEndInst *end) 939 : Entity(builder), mFunc(func), mFuncEnd(end), mParamsDeleter(mParams), 940 mBlocksDeleter(mBlocks) {} 941 942 bool FunctionDefinition::DeserializeInternal(InputWordStream &IS) { 943 mFunc.reset(Deserialize<FunctionInst>(IS)); 944 if (!mFunc) { 945 return false; 946 } 947 948 DeserializeZeroOrMore<FunctionParameterInst>(IS, mParams); 949 DeserializeZeroOrMore<Block>(IS, mBlocks); 950 951 mFuncEnd.reset(Deserialize<FunctionEndInst>(IS)); 952 if (!mFuncEnd) { 953 return false; 954 } 955 956 return true; 957 } 958 959 Instruction *FunctionDefinition::getReturnType() const { 960 return mFunc->mResultType.mInstruction; 961 } 962 963 } // namespace spirit 964 } // namespace android 965