Home | History | Annotate | Download | only in compiler
      1 /*
      2  * Copyright 2017, The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *     http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "Wrapper.h"
     18 
     19 #include "llvm/IR/Module.h"
     20 
     21 #include "Builtin.h"
     22 #include "Context.h"
     23 #include "GlobalAllocSPIRITPass.h"
     24 #include "RSAllocationUtils.h"
     25 #include "bcinfo/MetadataExtractor.h"
     26 #include "builder.h"
     27 #include "instructions.h"
     28 #include "module.h"
     29 #include "pass.h"
     30 
      #include <cstring>
      #include <sstream>
      #include <vector>
     33 
     34 using bcinfo::MetadataExtractor;
     35 
     36 namespace android {
     37 namespace spirit {
     38 
     39 VariableInst *AddBuffer(Instruction *elementType, uint32_t binding, Builder &b,
     40                         Module *m) {
     41   auto ArrTy = m->getRuntimeArrayType(elementType);
     42   const size_t stride = m->getSize(elementType);
     43   ArrTy->decorate(Decoration::ArrayStride)->addExtraOperand(stride);
     44   auto StructTy = m->getStructType(ArrTy);
     45   StructTy->decorate(Decoration::BufferBlock);
     46   StructTy->memberDecorate(0, Decoration::Offset)->addExtraOperand(0);
     47 
     48   auto StructPtrTy = m->getPointerType(StorageClass::Uniform, StructTy);
     49 
     50   VariableInst *bufferVar = b.MakeVariable(StructPtrTy, StorageClass::Uniform);
     51   bufferVar->decorate(Decoration::DescriptorSet)->addExtraOperand(0);
     52   bufferVar->decorate(Decoration::Binding)->addExtraOperand(binding);
     53   m->addVariable(bufferVar);
     54 
     55   return bufferVar;
     56 }
     57 
     58 bool AddWrapper(const char *name, const uint32_t signature,
     59                 const uint32_t numInput, Builder &b, Module *m) {
     60   FunctionDefinition *kernel = m->lookupFunctionDefinitionByName(name);
     61   if (kernel == nullptr) {
     62     // In the metadata for RenderScript LLVM bitcode, the first foreach kernel
     63     // is always reserved for the root kernel, even though in the most recent RS
     64     // apps it does not exist. Simply bypass wrapper generation here, and return
     65     // true for this case.
     66     // Otherwise, if a non-root kernel function cannot be found, it is a
     67     // fatal internal error which is really unexpected.
     68     return (strncmp(name, "root", 4) == 0);
     69   }
     70 
     71   // The following three cases are not supported
     72   if (!MetadataExtractor::hasForEachSignatureKernel(signature)) {
     73     // Not handling old-style kernel
     74     return false;
     75   }
     76 
     77   if (MetadataExtractor::hasForEachSignatureUsrData(signature)) {
     78     // Not handling the user argument
     79     return false;
     80   }
     81 
     82   if (MetadataExtractor::hasForEachSignatureCtxt(signature)) {
     83     // Not handling the context argument
     84     return false;
     85   }
     86 
     87   TypeVoidInst *VoidTy = m->getVoidType();
     88   TypeFunctionInst *FuncTy = m->getFunctionType(VoidTy, nullptr, 0);
     89   FunctionDefinition *Func =
     90       b.MakeFunctionDefinition(VoidTy, FunctionControl::None, FuncTy);
     91   m->addFunctionDefinition(Func);
     92 
     93   Block *Blk = b.MakeBlock();
     94   Func->addBlock(Blk);
     95 
     96   Blk->addInstruction(b.MakeLabel());
     97 
     98   TypeIntInst *UIntTy = m->getUnsignedIntType(32);
     99 
    100   Instruction *XValue = nullptr;
    101   Instruction *YValue = nullptr;
    102   Instruction *ZValue = nullptr;
    103   Instruction *Index = nullptr;
    104   VariableInst *InvocationId = nullptr;
    105   VariableInst *NumWorkgroups = nullptr;
    106 
    107   if (MetadataExtractor::hasForEachSignatureIn(signature) ||
    108       MetadataExtractor::hasForEachSignatureOut(signature) ||
    109       MetadataExtractor::hasForEachSignatureX(signature) ||
    110       MetadataExtractor::hasForEachSignatureY(signature) ||
    111       MetadataExtractor::hasForEachSignatureZ(signature)) {
    112     TypeVectorInst *V3UIntTy = m->getVectorType(UIntTy, 3);
    113     InvocationId = m->getInvocationId();
    114     auto IID = b.MakeLoad(V3UIntTy, InvocationId);
    115     Blk->addInstruction(IID);
    116 
    117     XValue = b.MakeCompositeExtract(UIntTy, IID, {0});
    118     Blk->addInstruction(XValue);
    119 
    120     YValue = b.MakeCompositeExtract(UIntTy, IID, {1});
    121     Blk->addInstruction(YValue);
    122 
    123     ZValue = b.MakeCompositeExtract(UIntTy, IID, {2});
    124     Blk->addInstruction(ZValue);
    125 
    126     // TODO: Use SpecConstant for workgroup size
    127     auto ConstOne = m->getConstant(UIntTy, 1U);
    128     auto GroupSize =
    129         m->getConstantComposite(V3UIntTy, ConstOne, ConstOne, ConstOne);
    130 
    131     auto GroupSizeX = b.MakeCompositeExtract(UIntTy, GroupSize, {0});
    132     Blk->addInstruction(GroupSizeX);
    133 
    134     auto GroupSizeY = b.MakeCompositeExtract(UIntTy, GroupSize, {1});
    135     Blk->addInstruction(GroupSizeY);
    136 
    137     NumWorkgroups = m->getNumWorkgroups();
    138     auto NumGroup = b.MakeLoad(V3UIntTy, NumWorkgroups);
    139     Blk->addInstruction(NumGroup);
    140 
    141     auto NumGroupX = b.MakeCompositeExtract(UIntTy, NumGroup, {0});
    142     Blk->addInstruction(NumGroupX);
    143 
    144     auto NumGroupY = b.MakeCompositeExtract(UIntTy, NumGroup, {1});
    145     Blk->addInstruction(NumGroupY);
    146 
    147     auto GlobalSizeX = b.MakeIMul(UIntTy, GroupSizeX, NumGroupX);
    148     Blk->addInstruction(GlobalSizeX);
    149 
    150     auto GlobalSizeY = b.MakeIMul(UIntTy, GroupSizeY, NumGroupY);
    151     Blk->addInstruction(GlobalSizeY);
    152 
    153     auto RowsAlongZ = b.MakeIMul(UIntTy, GlobalSizeY, ZValue);
    154     Blk->addInstruction(RowsAlongZ);
    155 
    156     auto NumRows = b.MakeIAdd(UIntTy, YValue, RowsAlongZ);
    157     Blk->addInstruction(NumRows);
    158 
    159     auto NumCellsFromYZ = b.MakeIMul(UIntTy, GlobalSizeX, NumRows);
    160     Blk->addInstruction(NumCellsFromYZ);
    161 
    162     Index = b.MakeIAdd(UIntTy, NumCellsFromYZ, XValue);
    163     Blk->addInstruction(Index);
    164   }
    165 
    166   std::vector<IdRef> inputs;
    167 
    168   ConstantInst *ConstZero = m->getConstant(UIntTy, 0);
    169 
    170   for (uint32_t i = 0; i < numInput; i++) {
    171     FunctionParameterInst *param = kernel->getParameter(i);
    172     Instruction *elementType = param->mResultType.mInstruction;
    173     VariableInst *inputBuffer = AddBuffer(elementType, i + 2, b, m);
    174 
    175     TypePointerInst *PtrTy =
    176         m->getPointerType(StorageClass::Function, elementType);
    177     AccessChainInst *Ptr =
    178         b.MakeAccessChain(PtrTy, inputBuffer, {ConstZero, Index});
    179     Blk->addInstruction(Ptr);
    180 
    181     Instruction *input = b.MakeLoad(elementType, Ptr);
    182     Blk->addInstruction(input);
    183 
    184     inputs.push_back(IdRef(input));
    185   }
    186 
    187   // TODO: Convert from unsigned int to signed int if that is what the kernel
    188   // function takes for the coordinate parameters
    189   if (MetadataExtractor::hasForEachSignatureX(signature)) {
    190     inputs.push_back(XValue);
    191     if (MetadataExtractor::hasForEachSignatureY(signature)) {
    192       inputs.push_back(YValue);
    193       if (MetadataExtractor::hasForEachSignatureZ(signature)) {
    194         inputs.push_back(ZValue);
    195       }
    196     }
    197   }
    198 
    199   auto resultType = kernel->getReturnType();
    200   auto kernelCall =
    201       b.MakeFunctionCall(resultType, kernel->getInstruction(), inputs);
    202   Blk->addInstruction(kernelCall);
    203 
    204   if (MetadataExtractor::hasForEachSignatureOut(signature)) {
    205     VariableInst *OutputBuffer = AddBuffer(resultType, 1, b, m);
    206     auto resultPtrType = m->getPointerType(StorageClass::Function, resultType);
    207     AccessChainInst *OutPtr =
    208         b.MakeAccessChain(resultPtrType, OutputBuffer, {ConstZero, Index});
    209     Blk->addInstruction(OutPtr);
    210     Blk->addInstruction(b.MakeStore(OutPtr, kernelCall));
    211   }
    212 
    213   Blk->addInstruction(b.MakeReturn());
    214 
    215   std::string wrapperName("entry_");
    216   wrapperName.append(name);
    217 
    218   EntryPointDefinition *entry = b.MakeEntryPointDefinition(
    219       ExecutionModel::GLCompute, Func, wrapperName.c_str());
    220 
    221   entry->setLocalSize(1, 1, 1);
    222 
    223   if (Index != nullptr) {
    224     entry->addToInterface(InvocationId);
    225     entry->addToInterface(NumWorkgroups);
    226   }
    227 
    228   m->addEntryPoint(entry);
    229 
    230   return true;
    231 }
    232 
    233 bool DecorateGlobalBuffer(llvm::Module &LM, Builder &b, Module *m) {
    234   Instruction *inst = m->lookupByName("__GPUBlock");
    235   if (inst == nullptr) {
    236     return true;
    237   }
    238 
    239   VariableInst *bufferVar = static_cast<VariableInst *>(inst);
    240   bufferVar->decorate(Decoration::DescriptorSet)->addExtraOperand(0);
    241   bufferVar->decorate(Decoration::Binding)->addExtraOperand(0);
    242 
    243   TypePointerInst *StructPtrTy =
    244       static_cast<TypePointerInst *>(bufferVar->mResultType.mInstruction);
    245   TypeStructInst *StructTy =
    246       static_cast<TypeStructInst *>(StructPtrTy->mOperand2.mInstruction);
    247   StructTy->decorate(Decoration::BufferBlock);
    248 
    249   // Decorate each member with proper offsets
    250 
    251   const auto GlobalsB = LM.globals().begin();
    252   const auto GlobalsE = LM.globals().end();
    253   const auto Found =
    254       std::find_if(GlobalsB, GlobalsE, [](const llvm::GlobalVariable &GV) {
    255         return GV.getName() == "__GPUBlock";
    256       });
    257 
    258   if (Found == GlobalsE) {
    259     return true; // GPUBlock not found - not an error by itself.
    260   }
    261 
    262   const llvm::GlobalVariable &G = *Found;
    263 
    264   bool IsCorrectTy = false;
    265   if (const auto *LPtrTy = llvm::dyn_cast<llvm::PointerType>(G.getType())) {
    266     if (auto *LStructTy =
    267             llvm::dyn_cast<llvm::StructType>(LPtrTy->getElementType())) {
    268       IsCorrectTy = true;
    269 
    270       const auto &DLayout = LM.getDataLayout();
    271       const auto *SLayout = DLayout.getStructLayout(LStructTy);
    272       assert(SLayout);
    273       if (SLayout == nullptr) {
    274         std::cerr << "struct layout is null" << std::endl;
    275         return false;
    276       }
    277       for (uint32_t i = 0, e = LStructTy->getNumElements(); i != e; ++i) {
    278         auto decor = StructTy->memberDecorate(i, Decoration::Offset);
    279         if (!decor) {
    280           std::cerr << "failed creating member decoration for field " << i
    281                     << std::endl;
    282           return false;
    283         }
    284         const uint32_t offset = (uint32_t)SLayout->getElementOffset(i);
    285         decor->addExtraOperand(offset);
    286       }
    287     }
    288   }
    289 
    290   if (!IsCorrectTy) {
    291     return false;
    292   }
    293 
    294   llvm::SmallVector<rs2spirv::RSAllocationInfo, 2> RSAllocs;
    295   if (!getRSAllocationInfo(LM, RSAllocs)) {
    296     // llvm::errs() << "Extracting rs_allocation info failed\n";
    297     return true;
    298   }
    299 
    300   // TODO: clean up the binding number assignment
    301   size_t BindingNum = 3;
    302   for (const auto &A : RSAllocs) {
    303     Instruction *inst = m->lookupByName(A.VarName.c_str());
    304     if (inst == nullptr) {
    305       return false;
    306     }
    307     VariableInst *bufferVar = static_cast<VariableInst *>(inst);
    308     bufferVar->decorate(Decoration::DescriptorSet)->addExtraOperand(0);
    309     bufferVar->decorate(Decoration::Binding)->addExtraOperand(BindingNum++);
    310   }
    311 
    312   return true;
    313 }
    314 
    315 void AddHeader(Module *m) {
    316   m->addCapability(Capability::Shader);
    317   // TODO: avoid duplicated capability
    318   // m->addCapability(Capability::Addresses);
    319   m->setMemoryModel(AddressingModel::Physical32, MemoryModel::GLSL450);
    320 
    321   m->addSource(SourceLanguage::GLSL, 450);
    322   m->addSourceExtension("GL_ARB_separate_shader_objects");
    323   m->addSourceExtension("GL_ARB_shading_language_420pack");
    324   m->addSourceExtension("GL_GOOGLE_cpp_style_line_directive");
    325   m->addSourceExtension("GL_GOOGLE_include_directive");
    326 }
    327 
    328 namespace {
    329 
    330 class StorageClassVisitor : public DoNothingVisitor {
    331 public:
    332   void visit(TypePointerInst *inst) override {
    333     matchAndReplace(inst->mOperand1);
    334   }
    335 
    336   void visit(TypeForwardPointerInst *inst) override {
    337     matchAndReplace(inst->mOperand2);
    338   }
    339 
    340   void visit(VariableInst *inst) override { matchAndReplace(inst->mOperand1); }
    341 
    342 private:
    343   void matchAndReplace(StorageClass &storage) {
    344     if (storage == StorageClass::Function) {
    345       storage = StorageClass::Uniform;
    346     }
    347   }
    348 };
    349 
    350 void FixGlobalStorageClass(Module *m) {
    351   StorageClassVisitor v;
    352   m->getGlobalSection()->accept(&v);
    353 }
    354 
    355 } // anonymous namespace
    356 
    357 bool AddWrappers(llvm::Module &LM,
    358                  android::spirit::Module *m) {
    359   rs2spirv::Context &Ctxt = rs2spirv::Context::getInstance();
    360   const bcinfo::MetadataExtractor &metadata = Ctxt.getMetadata();
    361   android::spirit::Builder b;
    362 
    363   m->setBuilder(&b);
    364 
    365   FixGlobalStorageClass(m);
    366 
    367   AddHeader(m);
    368 
    369   DecorateGlobalBuffer(LM, b, m);
    370 
    371   const size_t numKernel = metadata.getExportForEachSignatureCount();
    372   const char **kernelName = metadata.getExportForEachNameList();
    373   const uint32_t *kernelSigature = metadata.getExportForEachSignatureList();
    374   const uint32_t *inputCount = metadata.getExportForEachInputCountList();
    375 
    376   for (size_t i = 0; i < numKernel; i++) {
    377     bool success =
    378         AddWrapper(kernelName[i], kernelSigature[i], inputCount[i], b, m);
    379     if (!success) {
    380       return false;
    381     }
    382   }
    383 
    384   m->consolidateAnnotations();
    385   return true;
    386 }
    387 
    388 class WrapperPass : public Pass {
    389 public:
    390   WrapperPass(const llvm::Module &LM) : mLLVMModule(const_cast<llvm::Module&>(LM)) {}
    391 
    392   Module *run(Module *m, int *error) override {
    393     bool success = AddWrappers(mLLVMModule, m);
    394     if (error) {
    395       *error = success ? 0 : -1;
    396     }
    397     return m;
    398   }
    399 
    400 private:
    401   llvm::Module &mLLVMModule;
    402 };
    403 
    404 } // namespace spirit
    405 } // namespace android
    406 
    407 namespace rs2spirv {
    408 
    409 android::spirit::Pass* CreateWrapperPass(const llvm::Module &LLVMModule) {
    410   return new android::spirit::WrapperPass(LLVMModule);
    411 }
    412 
    413 } // namespace rs2spirv
    414