Home | History | Annotate | Download | only in compiler
      1 /*
      2  * Copyright 2017, The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *     http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "Wrapper.h"
     18 
     19 #include "llvm/IR/Module.h"
     20 
     21 #include "Builtin.h"
     22 #include "Context.h"
     23 #include "GlobalAllocSPIRITPass.h"
     24 #include "RSAllocationUtils.h"
     25 #include "bcinfo/MetadataExtractor.h"
     26 #include "builder.h"
     27 #include "instructions.h"
     28 #include "module.h"
     29 #include "pass.h"
     30 
#include <algorithm>
#include <cassert>
#include <cstring>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
     33 
     34 using bcinfo::MetadataExtractor;
     35 
     36 namespace android {
     37 namespace spirit {
     38 
     39 VariableInst *AddBuffer(Instruction *elementType, uint32_t binding, Builder &b,
     40                         Module *m) {
     41   auto ArrTy = m->getRuntimeArrayType(elementType);
     42   const size_t stride = m->getSize(elementType);
     43   ArrTy->decorate(Decoration::ArrayStride)->addExtraOperand(stride);
     44   auto StructTy = m->getStructType(ArrTy);
     45   StructTy->decorate(Decoration::BufferBlock);
     46   StructTy->memberDecorate(0, Decoration::Offset)->addExtraOperand(0);
     47 
     48   auto StructPtrTy = m->getPointerType(StorageClass::Uniform, StructTy);
     49 
     50   VariableInst *bufferVar = b.MakeVariable(StructPtrTy, StorageClass::Uniform);
     51   bufferVar->decorate(Decoration::DescriptorSet)->addExtraOperand(0);
     52   bufferVar->decorate(Decoration::Binding)->addExtraOperand(binding);
     53   m->addVariable(bufferVar);
     54 
     55   return bufferVar;
     56 }
     57 
     58 bool AddWrapper(const char *name, const uint32_t signature,
     59                 const uint32_t numInput, Builder &b, Module *m) {
     60   FunctionDefinition *kernel = m->lookupFunctionDefinitionByName(name);
     61   if (kernel == nullptr) {
     62     // In the metadata for RenderScript LLVM bitcode, the first foreach kernel
     63     // is always reserved for the root kernel, even though in the most recent RS
     64     // apps it does not exist. Simply bypass wrapper generation here, and return
     65     // true for this case.
     66     // Otherwise, if a non-root kernel function cannot be found, it is a
     67     // fatal internal error which is really unexpected.
     68     return (strncmp(name, "root", 4) == 0);
     69   }
     70 
     71   // The following three cases are not supported
     72   if (!MetadataExtractor::hasForEachSignatureKernel(signature)) {
     73     // Not handling old-style kernel
     74     return false;
     75   }
     76 
     77   if (MetadataExtractor::hasForEachSignatureUsrData(signature)) {
     78     // Not handling the user argument
     79     return false;
     80   }
     81 
     82   if (MetadataExtractor::hasForEachSignatureCtxt(signature)) {
     83     // Not handling the context argument
     84     return false;
     85   }
     86 
     87   TypeVoidInst *VoidTy = m->getVoidType();
     88   TypeFunctionInst *FuncTy = m->getFunctionType(VoidTy, nullptr, 0);
     89   FunctionDefinition *Func =
     90       b.MakeFunctionDefinition(VoidTy, FunctionControl::None, FuncTy);
     91   m->addFunctionDefinition(Func);
     92 
     93   Block *Blk = b.MakeBlock();
     94   Func->addBlock(Blk);
     95 
     96   Blk->addInstruction(b.MakeLabel());
     97 
     98   TypeIntInst *UIntTy = m->getUnsignedIntType(32);
     99 
    100   Instruction *XValue = nullptr;
    101   Instruction *YValue = nullptr;
    102   Instruction *ZValue = nullptr;
    103   Instruction *Index = nullptr;
    104   VariableInst *InvocationId = nullptr;
    105   VariableInst *NumWorkgroups = nullptr;
    106 
    107   if (MetadataExtractor::hasForEachSignatureIn(signature) ||
    108       MetadataExtractor::hasForEachSignatureOut(signature) ||
    109       MetadataExtractor::hasForEachSignatureX(signature) ||
    110       MetadataExtractor::hasForEachSignatureY(signature) ||
    111       MetadataExtractor::hasForEachSignatureZ(signature)) {
    112     TypeVectorInst *V3UIntTy = m->getVectorType(UIntTy, 3);
    113     InvocationId = m->getInvocationId();
    114     auto IID = b.MakeLoad(V3UIntTy, InvocationId);
    115     Blk->addInstruction(IID);
    116 
    117     XValue = b.MakeCompositeExtract(UIntTy, IID, {0});
    118     Blk->addInstruction(XValue);
    119 
    120     YValue = b.MakeCompositeExtract(UIntTy, IID, {1});
    121     Blk->addInstruction(YValue);
    122 
    123     ZValue = b.MakeCompositeExtract(UIntTy, IID, {2});
    124     Blk->addInstruction(ZValue);
    125 
    126     // TODO: Use SpecConstant for workgroup size
    127     auto ConstOne = m->getConstant(UIntTy, 1U);
    128     auto GroupSize =
    129         m->getConstantComposite(V3UIntTy, ConstOne, ConstOne, ConstOne);
    130 
    131     auto GroupSizeX = b.MakeCompositeExtract(UIntTy, GroupSize, {0});
    132     Blk->addInstruction(GroupSizeX);
    133 
    134     auto GroupSizeY = b.MakeCompositeExtract(UIntTy, GroupSize, {1});
    135     Blk->addInstruction(GroupSizeY);
    136 
    137     NumWorkgroups = m->getNumWorkgroups();
    138     auto NumGroup = b.MakeLoad(V3UIntTy, NumWorkgroups);
    139     Blk->addInstruction(NumGroup);
    140 
    141     auto NumGroupX = b.MakeCompositeExtract(UIntTy, NumGroup, {0});
    142     Blk->addInstruction(NumGroupX);
    143 
    144     auto NumGroupY = b.MakeCompositeExtract(UIntTy, NumGroup, {1});
    145     Blk->addInstruction(NumGroupY);
    146 
    147     auto GlobalSizeX = b.MakeIMul(UIntTy, GroupSizeX, NumGroupX);
    148     Blk->addInstruction(GlobalSizeX);
    149 
    150     auto GlobalSizeY = b.MakeIMul(UIntTy, GroupSizeY, NumGroupY);
    151     Blk->addInstruction(GlobalSizeY);
    152 
    153     auto RowsAlongZ = b.MakeIMul(UIntTy, GlobalSizeY, ZValue);
    154     Blk->addInstruction(RowsAlongZ);
    155 
    156     auto NumRows = b.MakeIAdd(UIntTy, YValue, RowsAlongZ);
    157     Blk->addInstruction(NumRows);
    158 
    159     auto NumCellsFromYZ = b.MakeIMul(UIntTy, GlobalSizeX, NumRows);
    160     Blk->addInstruction(NumCellsFromYZ);
    161 
    162     Index = b.MakeIAdd(UIntTy, NumCellsFromYZ, XValue);
    163     Blk->addInstruction(Index);
    164   }
    165 
    166   std::vector<IdRef> inputs;
    167 
    168   ConstantInst *ConstZero = m->getConstant(UIntTy, 0);
    169 
    170   for (uint32_t i = 0; i < numInput; i++) {
    171     FunctionParameterInst *param = kernel->getParameter(i);
    172     Instruction *elementType = param->mResultType.mInstruction;
    173     VariableInst *inputBuffer = AddBuffer(elementType, i + 3, b, m);
    174 
    175     TypePointerInst *PtrTy =
    176         m->getPointerType(StorageClass::Function, elementType);
    177     AccessChainInst *Ptr =
    178         b.MakeAccessChain(PtrTy, inputBuffer, {ConstZero, Index});
    179     Blk->addInstruction(Ptr);
    180 
    181     Instruction *input = b.MakeLoad(elementType, Ptr);
    182     Blk->addInstruction(input);
    183 
    184     inputs.push_back(IdRef(input));
    185   }
    186 
    187   // TODO: Convert from unsigned int to signed int if that is what the kernel
    188   // function takes for the coordinate parameters
    189   if (MetadataExtractor::hasForEachSignatureX(signature)) {
    190     inputs.push_back(XValue);
    191     if (MetadataExtractor::hasForEachSignatureY(signature)) {
    192       inputs.push_back(YValue);
    193       if (MetadataExtractor::hasForEachSignatureZ(signature)) {
    194         inputs.push_back(ZValue);
    195       }
    196     }
    197   }
    198 
    199   auto resultType = kernel->getReturnType();
    200   auto kernelCall =
    201       b.MakeFunctionCall(resultType, kernel->getInstruction(), inputs);
    202   Blk->addInstruction(kernelCall);
    203 
    204   if (MetadataExtractor::hasForEachSignatureOut(signature)) {
    205     VariableInst *OutputBuffer = AddBuffer(resultType, 2, b, m);
    206     auto resultPtrType = m->getPointerType(StorageClass::Function, resultType);
    207     AccessChainInst *OutPtr =
    208         b.MakeAccessChain(resultPtrType, OutputBuffer, {ConstZero, Index});
    209     Blk->addInstruction(OutPtr);
    210     Blk->addInstruction(b.MakeStore(OutPtr, kernelCall));
    211   }
    212 
    213   Blk->addInstruction(b.MakeReturn());
    214 
    215   std::string wrapperName("entry_");
    216   wrapperName.append(name);
    217 
    218   EntryPointDefinition *entry = b.MakeEntryPointDefinition(
    219       ExecutionModel::GLCompute, Func, wrapperName.c_str());
    220 
    221   entry->setLocalSize(1, 1, 1);
    222 
    223   if (Index != nullptr) {
    224     entry->addToInterface(InvocationId);
    225     entry->addToInterface(NumWorkgroups);
    226   }
    227 
    228   m->addEntryPoint(entry);
    229 
    230   return true;
    231 }
    232 
    233 bool DecorateGlobalBuffer(llvm::Module &LM, Builder &b, Module *m) {
    234   Instruction *inst = m->lookupByName("__GPUBlock");
    235   if (inst == nullptr) {
    236     return true;
    237   }
    238 
    239   VariableInst *bufferVar = static_cast<VariableInst *>(inst);
    240   bufferVar->decorate(Decoration::DescriptorSet)->addExtraOperand(0);
    241   bufferVar->decorate(Decoration::Binding)->addExtraOperand(0);
    242 
    243   TypePointerInst *StructPtrTy =
    244       static_cast<TypePointerInst *>(bufferVar->mResultType.mInstruction);
    245   TypeStructInst *StructTy =
    246       static_cast<TypeStructInst *>(StructPtrTy->mOperand2.mInstruction);
    247   StructTy->decorate(Decoration::BufferBlock);
    248 
    249   // Decorate each member with proper offsets
    250 
    251   const auto GlobalsB = LM.globals().begin();
    252   const auto GlobalsE = LM.globals().end();
    253   const auto Found =
    254       std::find_if(GlobalsB, GlobalsE, [](const llvm::GlobalVariable &GV) {
    255         return GV.getName() == "__GPUBlock";
    256       });
    257 
    258   if (Found == GlobalsE) {
    259     return true; // GPUBlock not found - not an error by itself.
    260   }
    261 
    262   const llvm::GlobalVariable &G = *Found;
    263 
    264   rs2spirv::Context &Ctxt = rs2spirv::Context::getInstance();
    265   bool IsCorrectTy = false;
    266   if (const auto *LPtrTy = llvm::dyn_cast<llvm::PointerType>(G.getType())) {
    267     if (auto *LStructTy =
    268             llvm::dyn_cast<llvm::StructType>(LPtrTy->getElementType())) {
    269       IsCorrectTy = true;
    270 
    271       const auto &DLayout = LM.getDataLayout();
    272       const auto *SLayout = DLayout.getStructLayout(LStructTy);
    273       assert(SLayout);
    274       if (SLayout == nullptr) {
    275         std::cerr << "struct layout is null" << std::endl;
    276         return false;
    277       }
    278       std::vector<uint32_t> offsets;
    279       for (uint32_t i = 0, e = LStructTy->getNumElements(); i != e; ++i) {
    280         auto decor = StructTy->memberDecorate(i, Decoration::Offset);
    281         if (!decor) {
    282           std::cerr << "failed creating member decoration for field " << i
    283                     << std::endl;
    284           return false;
    285         }
    286         const uint32_t offset = (uint32_t)SLayout->getElementOffset(i);
    287         decor->addExtraOperand(offset);
    288         offsets.push_back(offset);
    289       }
    290       std::stringstream ssOffsets;
    291       // TODO: define this string in a central place
    292       ssOffsets << ".rsov.ExportedVars:";
    293       for(uint32_t slot = 0; slot < Ctxt.getNumExportVar(); slot++) {
    294         const uint32_t index = Ctxt.getExportVarIndex(slot);
    295         const uint32_t offset = offsets[index];
    296         ssOffsets << offset << ';';
    297       }
    298       m->addString(ssOffsets.str().c_str());
    299 
    300       std::stringstream ssGlobalSize;
    301       ssGlobalSize << ".rsov.GlobalSize:" << Ctxt.getGlobalSize();
    302       m->addString(ssGlobalSize.str().c_str());
    303     }
    304   }
    305 
    306   if (!IsCorrectTy) {
    307     return false;
    308   }
    309 
    310   llvm::SmallVector<rs2spirv::RSAllocationInfo, 2> RSAllocs;
    311   if (!getRSAllocationInfo(LM, RSAllocs)) {
    312     // llvm::errs() << "Extracting rs_allocation info failed\n";
    313     return true;
    314   }
    315 
    316   // TODO: clean up the binding number assignment
    317   size_t BindingNum = 3;
    318   for (const auto &A : RSAllocs) {
    319     Instruction *inst = m->lookupByName(A.VarName.c_str());
    320     if (inst == nullptr) {
    321       return false;
    322     }
    323     VariableInst *bufferVar = static_cast<VariableInst *>(inst);
    324     bufferVar->decorate(Decoration::DescriptorSet)->addExtraOperand(0);
    325     bufferVar->decorate(Decoration::Binding)->addExtraOperand(BindingNum++);
    326   }
    327 
    328   return true;
    329 }
    330 
    331 void AddHeader(Module *m) {
    332   m->addCapability(Capability::Shader);
    333   m->setMemoryModel(AddressingModel::Logical, MemoryModel::GLSL450);
    334 
    335   m->addSource(SourceLanguage::GLSL, 450);
    336   m->addSourceExtension("GL_ARB_separate_shader_objects");
    337   m->addSourceExtension("GL_ARB_shading_language_420pack");
    338   m->addSourceExtension("GL_GOOGLE_cpp_style_line_directive");
    339   m->addSourceExtension("GL_GOOGLE_include_directive");
    340 }
    341 
    342 namespace {
    343 
    344 class StorageClassVisitor : public DoNothingVisitor {
    345 public:
    346   void visit(TypePointerInst *inst) override {
    347     matchAndReplace(inst->mOperand1);
    348   }
    349 
    350   void visit(TypeForwardPointerInst *inst) override {
    351     matchAndReplace(inst->mOperand2);
    352   }
    353 
    354   void visit(VariableInst *inst) override { matchAndReplace(inst->mOperand1); }
    355 
    356 private:
    357   void matchAndReplace(StorageClass &storage) {
    358     if (storage == StorageClass::Function) {
    359       storage = StorageClass::Uniform;
    360     }
    361   }
    362 };
    363 
    364 void FixGlobalStorageClass(Module *m) {
    365   StorageClassVisitor v;
    366   m->getGlobalSection()->accept(&v);
    367 }
    368 
    369 } // anonymous namespace
    370 
    371 bool AddWrappers(llvm::Module &LM,
    372                  android::spirit::Module *m) {
    373   rs2spirv::Context &Ctxt = rs2spirv::Context::getInstance();
    374   const bcinfo::MetadataExtractor &metadata = Ctxt.getMetadata();
    375   android::spirit::Builder b;
    376 
    377   m->setBuilder(&b);
    378 
    379   FixGlobalStorageClass(m);
    380 
    381   AddHeader(m);
    382 
    383   DecorateGlobalBuffer(LM, b, m);
    384 
    385   const size_t numKernel = metadata.getExportForEachSignatureCount();
    386   const char **kernelName = metadata.getExportForEachNameList();
    387   const uint32_t *kernelSigature = metadata.getExportForEachSignatureList();
    388   const uint32_t *inputCount = metadata.getExportForEachInputCountList();
    389 
    390   for (size_t i = 0; i < numKernel; i++) {
    391     bool success =
    392         AddWrapper(kernelName[i], kernelSigature[i], inputCount[i], b, m);
    393     if (!success) {
    394       return false;
    395     }
    396   }
    397 
    398   m->consolidateAnnotations();
    399   return true;
    400 }
    401 
    402 class WrapperPass : public Pass {
    403 public:
    404   WrapperPass(const llvm::Module &LM) : mLLVMModule(const_cast<llvm::Module&>(LM)) {}
    405 
    406   Module *run(Module *m, int *error) override {
    407     bool success = AddWrappers(mLLVMModule, m);
    408     if (error) {
    409       *error = success ? 0 : -1;
    410     }
    411     return m;
    412   }
    413 
    414 private:
    415   llvm::Module &mLLVMModule;
    416 };
    417 
    418 } // namespace spirit
    419 } // namespace android
    420 
    421 namespace rs2spirv {
    422 
    423 android::spirit::Pass* CreateWrapperPass(const llvm::Module &LLVMModule) {
    424   return new android::spirit::WrapperPass(LLVMModule);
    425 }
    426 
    427 } // namespace rs2spirv
    428