1 /* 2 * Copyright 2017, The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "Wrapper.h" 18 19 #include "llvm/IR/Module.h" 20 21 #include "Builtin.h" 22 #include "Context.h" 23 #include "GlobalAllocSPIRITPass.h" 24 #include "RSAllocationUtils.h" 25 #include "bcinfo/MetadataExtractor.h" 26 #include "builder.h" 27 #include "instructions.h" 28 #include "module.h" 29 #include "pass.h" 30 31 #include <sstream> 32 #include <vector> 33 34 using bcinfo::MetadataExtractor; 35 36 namespace android { 37 namespace spirit { 38 39 VariableInst *AddBuffer(Instruction *elementType, uint32_t binding, Builder &b, 40 Module *m) { 41 auto ArrTy = m->getRuntimeArrayType(elementType); 42 const size_t stride = m->getSize(elementType); 43 ArrTy->decorate(Decoration::ArrayStride)->addExtraOperand(stride); 44 auto StructTy = m->getStructType(ArrTy); 45 StructTy->decorate(Decoration::BufferBlock); 46 StructTy->memberDecorate(0, Decoration::Offset)->addExtraOperand(0); 47 48 auto StructPtrTy = m->getPointerType(StorageClass::Uniform, StructTy); 49 50 VariableInst *bufferVar = b.MakeVariable(StructPtrTy, StorageClass::Uniform); 51 bufferVar->decorate(Decoration::DescriptorSet)->addExtraOperand(0); 52 bufferVar->decorate(Decoration::Binding)->addExtraOperand(binding); 53 m->addVariable(bufferVar); 54 55 return bufferVar; 56 } 57 58 bool AddWrapper(const char *name, const uint32_t signature, 59 const uint32_t numInput, Builder &b, Module *m) { 60 FunctionDefinition *kernel = m->lookupFunctionDefinitionByName(name); 61 if (kernel == nullptr) { 62 // In the metadata for RenderScript LLVM bitcode, the first foreach kernel 63 // is always reserved for the root kernel, even though in the most recent RS 64 // apps it does not exist. Simply bypass wrapper generation here, and return 65 // true for this case. 66 // Otherwise, if a non-root kernel function cannot be found, it is a 67 // fatal internal error which is really unexpected. 68 return (strncmp(name, "root", 4) == 0); 69 } 70 71 // The following three cases are not supported 72 if (!MetadataExtractor::hasForEachSignatureKernel(signature)) { 73 // Not handling old-style kernel 74 return false; 75 } 76 77 if (MetadataExtractor::hasForEachSignatureUsrData(signature)) { 78 // Not handling the user argument 79 return false; 80 } 81 82 if (MetadataExtractor::hasForEachSignatureCtxt(signature)) { 83 // Not handling the context argument 84 return false; 85 } 86 87 TypeVoidInst *VoidTy = m->getVoidType(); 88 TypeFunctionInst *FuncTy = m->getFunctionType(VoidTy, nullptr, 0); 89 FunctionDefinition *Func = 90 b.MakeFunctionDefinition(VoidTy, FunctionControl::None, FuncTy); 91 m->addFunctionDefinition(Func); 92 93 Block *Blk = b.MakeBlock(); 94 Func->addBlock(Blk); 95 96 Blk->addInstruction(b.MakeLabel()); 97 98 TypeIntInst *UIntTy = m->getUnsignedIntType(32); 99 100 Instruction *XValue = nullptr; 101 Instruction *YValue = nullptr; 102 Instruction *ZValue = nullptr; 103 Instruction *Index = nullptr; 104 VariableInst *InvocationId = nullptr; 105 VariableInst *NumWorkgroups = nullptr; 106 107 if (MetadataExtractor::hasForEachSignatureIn(signature) || 108 MetadataExtractor::hasForEachSignatureOut(signature) || 109 MetadataExtractor::hasForEachSignatureX(signature) || 110 MetadataExtractor::hasForEachSignatureY(signature) || 111 MetadataExtractor::hasForEachSignatureZ(signature)) { 112 TypeVectorInst *V3UIntTy = m->getVectorType(UIntTy, 3); 113 InvocationId = m->getInvocationId(); 114 auto IID = b.MakeLoad(V3UIntTy, InvocationId); 115 Blk->addInstruction(IID); 116 117 XValue = b.MakeCompositeExtract(UIntTy, IID, {0}); 118 Blk->addInstruction(XValue); 119 120 YValue = b.MakeCompositeExtract(UIntTy, IID, {1}); 121 Blk->addInstruction(YValue); 122 123 ZValue = b.MakeCompositeExtract(UIntTy, IID, {2}); 124 Blk->addInstruction(ZValue); 125 126 // TODO: Use SpecConstant for workgroup size 127 auto ConstOne = m->getConstant(UIntTy, 1U); 128 auto GroupSize = 129 m->getConstantComposite(V3UIntTy, ConstOne, ConstOne, ConstOne); 130 131 auto GroupSizeX = b.MakeCompositeExtract(UIntTy, GroupSize, {0}); 132 Blk->addInstruction(GroupSizeX); 133 134 auto GroupSizeY = b.MakeCompositeExtract(UIntTy, GroupSize, {1}); 135 Blk->addInstruction(GroupSizeY); 136 137 NumWorkgroups = m->getNumWorkgroups(); 138 auto NumGroup = b.MakeLoad(V3UIntTy, NumWorkgroups); 139 Blk->addInstruction(NumGroup); 140 141 auto NumGroupX = b.MakeCompositeExtract(UIntTy, NumGroup, {0}); 142 Blk->addInstruction(NumGroupX); 143 144 auto NumGroupY = b.MakeCompositeExtract(UIntTy, NumGroup, {1}); 145 Blk->addInstruction(NumGroupY); 146 147 auto GlobalSizeX = b.MakeIMul(UIntTy, GroupSizeX, NumGroupX); 148 Blk->addInstruction(GlobalSizeX); 149 150 auto GlobalSizeY = b.MakeIMul(UIntTy, GroupSizeY, NumGroupY); 151 Blk->addInstruction(GlobalSizeY); 152 153 auto RowsAlongZ = b.MakeIMul(UIntTy, GlobalSizeY, ZValue); 154 Blk->addInstruction(RowsAlongZ); 155 156 auto NumRows = b.MakeIAdd(UIntTy, YValue, RowsAlongZ); 157 Blk->addInstruction(NumRows); 158 159 auto NumCellsFromYZ = b.MakeIMul(UIntTy, GlobalSizeX, NumRows); 160 Blk->addInstruction(NumCellsFromYZ); 161 162 Index = b.MakeIAdd(UIntTy, NumCellsFromYZ, XValue); 163 Blk->addInstruction(Index); 164 } 165 166 std::vector<IdRef> inputs; 167 168 ConstantInst *ConstZero = m->getConstant(UIntTy, 0); 169 170 for (uint32_t i = 0; i < numInput; i++) { 171 FunctionParameterInst *param = kernel->getParameter(i); 172 Instruction *elementType = param->mResultType.mInstruction; 173 VariableInst *inputBuffer = AddBuffer(elementType, i + 2, b, m); 174 175 TypePointerInst *PtrTy = 176 m->getPointerType(StorageClass::Function, elementType); 177 AccessChainInst *Ptr = 178 b.MakeAccessChain(PtrTy, inputBuffer, {ConstZero, Index}); 179 Blk->addInstruction(Ptr); 180 181 Instruction *input = b.MakeLoad(elementType, Ptr); 182 Blk->addInstruction(input); 183 184 inputs.push_back(IdRef(input)); 185 } 186 187 // TODO: Convert from unsigned int to signed int if that is what the kernel 188 // function takes for the coordinate parameters 189 if (MetadataExtractor::hasForEachSignatureX(signature)) { 190 inputs.push_back(XValue); 191 if (MetadataExtractor::hasForEachSignatureY(signature)) { 192 inputs.push_back(YValue); 193 if (MetadataExtractor::hasForEachSignatureZ(signature)) { 194 inputs.push_back(ZValue); 195 } 196 } 197 } 198 199 auto resultType = kernel->getReturnType(); 200 auto kernelCall = 201 b.MakeFunctionCall(resultType, kernel->getInstruction(), inputs); 202 Blk->addInstruction(kernelCall); 203 204 if (MetadataExtractor::hasForEachSignatureOut(signature)) { 205 VariableInst *OutputBuffer = AddBuffer(resultType, 1, b, m); 206 auto resultPtrType = m->getPointerType(StorageClass::Function, resultType); 207 AccessChainInst *OutPtr = 208 b.MakeAccessChain(resultPtrType, OutputBuffer, {ConstZero, Index}); 209 Blk->addInstruction(OutPtr); 210 Blk->addInstruction(b.MakeStore(OutPtr, kernelCall)); 211 } 212 213 Blk->addInstruction(b.MakeReturn()); 214 215 std::string wrapperName("entry_"); 216 wrapperName.append(name); 217 218 EntryPointDefinition *entry = b.MakeEntryPointDefinition( 219 ExecutionModel::GLCompute, Func, wrapperName.c_str()); 220 221 entry->setLocalSize(1, 1, 1); 222 223 if (Index != nullptr) { 224 entry->addToInterface(InvocationId); 225 entry->addToInterface(NumWorkgroups); 226 } 227 228 m->addEntryPoint(entry); 229 230 return true; 231 } 232 233 bool DecorateGlobalBuffer(llvm::Module &LM, Builder &b, Module *m) { 234 Instruction *inst = m->lookupByName("__GPUBlock"); 235 if (inst == nullptr) { 236 return true; 237 } 238 239 VariableInst *bufferVar = static_cast<VariableInst *>(inst); 240 bufferVar->decorate(Decoration::DescriptorSet)->addExtraOperand(0); 241 bufferVar->decorate(Decoration::Binding)->addExtraOperand(0); 242 243 TypePointerInst *StructPtrTy = 244 static_cast<TypePointerInst *>(bufferVar->mResultType.mInstruction); 245 TypeStructInst *StructTy = 246 static_cast<TypeStructInst *>(StructPtrTy->mOperand2.mInstruction); 247 StructTy->decorate(Decoration::BufferBlock); 248 249 // Decorate each member with proper offsets 250 251 const auto GlobalsB = LM.globals().begin(); 252 const auto GlobalsE = LM.globals().end(); 253 const auto Found = 254 std::find_if(GlobalsB, GlobalsE, [](const llvm::GlobalVariable &GV) { 255 return GV.getName() == "__GPUBlock"; 256 }); 257 258 if (Found == GlobalsE) { 259 return true; // GPUBlock not found - not an error by itself. 260 } 261 262 const llvm::GlobalVariable &G = *Found; 263 264 bool IsCorrectTy = false; 265 if (const auto *LPtrTy = llvm::dyn_cast<llvm::PointerType>(G.getType())) { 266 if (auto *LStructTy = 267 llvm::dyn_cast<llvm::StructType>(LPtrTy->getElementType())) { 268 IsCorrectTy = true; 269 270 const auto &DLayout = LM.getDataLayout(); 271 const auto *SLayout = DLayout.getStructLayout(LStructTy); 272 assert(SLayout); 273 if (SLayout == nullptr) { 274 std::cerr << "struct layout is null" << std::endl; 275 return false; 276 } 277 for (uint32_t i = 0, e = LStructTy->getNumElements(); i != e; ++i) { 278 auto decor = StructTy->memberDecorate(i, Decoration::Offset); 279 if (!decor) { 280 std::cerr << "failed creating member decoration for field " << i 281 << std::endl; 282 return false; 283 } 284 const uint32_t offset = (uint32_t)SLayout->getElementOffset(i); 285 decor->addExtraOperand(offset); 286 } 287 } 288 } 289 290 if (!IsCorrectTy) { 291 return false; 292 } 293 294 llvm::SmallVector<rs2spirv::RSAllocationInfo, 2> RSAllocs; 295 if (!getRSAllocationInfo(LM, RSAllocs)) { 296 // llvm::errs() << "Extracting rs_allocation info failed\n"; 297 return true; 298 } 299 300 // TODO: clean up the binding number assignment 301 size_t BindingNum = 3; 302 for (const auto &A : RSAllocs) { 303 Instruction *inst = m->lookupByName(A.VarName.c_str()); 304 if (inst == nullptr) { 305 return false; 306 } 307 VariableInst *bufferVar = static_cast<VariableInst *>(inst); 308 bufferVar->decorate(Decoration::DescriptorSet)->addExtraOperand(0); 309 bufferVar->decorate(Decoration::Binding)->addExtraOperand(BindingNum++); 310 } 311 312 return true; 313 } 314 315 void AddHeader(Module *m) { 316 m->addCapability(Capability::Shader); 317 // TODO: avoid duplicated capability 318 // m->addCapability(Capability::Addresses); 319 m->setMemoryModel(AddressingModel::Physical32, MemoryModel::GLSL450); 320 321 m->addSource(SourceLanguage::GLSL, 450); 322 m->addSourceExtension("GL_ARB_separate_shader_objects"); 323 m->addSourceExtension("GL_ARB_shading_language_420pack"); 324 m->addSourceExtension("GL_GOOGLE_cpp_style_line_directive"); 325 m->addSourceExtension("GL_GOOGLE_include_directive"); 326 } 327 328 namespace { 329 330 class StorageClassVisitor : public DoNothingVisitor { 331 public: 332 void visit(TypePointerInst *inst) override { 333 matchAndReplace(inst->mOperand1); 334 } 335 336 void visit(TypeForwardPointerInst *inst) override { 337 matchAndReplace(inst->mOperand2); 338 } 339 340 void visit(VariableInst *inst) override { matchAndReplace(inst->mOperand1); } 341 342 private: 343 void matchAndReplace(StorageClass &storage) { 344 if (storage == StorageClass::Function) { 345 storage = StorageClass::Uniform; 346 } 347 } 348 }; 349 350 void FixGlobalStorageClass(Module *m) { 351 StorageClassVisitor v; 352 m->getGlobalSection()->accept(&v); 353 } 354 355 } // anonymous namespace 356 357 bool AddWrappers(llvm::Module &LM, 358 android::spirit::Module *m) { 359 rs2spirv::Context &Ctxt = rs2spirv::Context::getInstance(); 360 const bcinfo::MetadataExtractor &metadata = Ctxt.getMetadata(); 361 android::spirit::Builder b; 362 363 m->setBuilder(&b); 364 365 FixGlobalStorageClass(m); 366 367 AddHeader(m); 368 369 DecorateGlobalBuffer(LM, b, m); 370 371 const size_t numKernel = metadata.getExportForEachSignatureCount(); 372 const char **kernelName = metadata.getExportForEachNameList(); 373 const uint32_t *kernelSigature = metadata.getExportForEachSignatureList(); 374 const uint32_t *inputCount = metadata.getExportForEachInputCountList(); 375 376 for (size_t i = 0; i < numKernel; i++) { 377 bool success = 378 AddWrapper(kernelName[i], kernelSigature[i], inputCount[i], b, m); 379 if (!success) { 380 return false; 381 } 382 } 383 384 m->consolidateAnnotations(); 385 return true; 386 } 387 388 class WrapperPass : public Pass { 389 public: 390 WrapperPass(const llvm::Module &LM) : mLLVMModule(const_cast<llvm::Module&>(LM)) {} 391 392 Module *run(Module *m, int *error) override { 393 bool success = AddWrappers(mLLVMModule, m); 394 if (error) { 395 *error = success ? 0 : -1; 396 } 397 return m; 398 } 399 400 private: 401 llvm::Module &mLLVMModule; 402 }; 403 404 } // namespace spirit 405 } // namespace android 406 407 namespace rs2spirv { 408 409 android::spirit::Pass* CreateWrapperPass(const llvm::Module &LLVMModule) { 410 return new android::spirit::WrapperPass(LLVMModule); 411 } 412 413 } // namespace rs2spirv 414