1 /* 2 * Copyright 2016-2017, The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "GlobalMergePass.h" 18 19 #include "llvm/IR/Constants.h" 20 #include "llvm/IR/DataLayout.h" 21 #include "llvm/IR/GlobalVariable.h" 22 #include "llvm/IR/IRBuilder.h" 23 #include "llvm/IR/Instructions.h" 24 #include "llvm/IR/Module.h" 25 #include "llvm/Pass.h" 26 #include "llvm/Support/Debug.h" 27 #include "llvm/Support/raw_ostream.h" 28 29 #include "Context.h" 30 #include "RSAllocationUtils.h" 31 32 #include <functional> 33 34 #define DEBUG_TYPE "rs2spirv-global-merge" 35 36 using namespace llvm; 37 38 namespace rs2spirv { 39 40 namespace { 41 42 class GlobalMergePass : public ModulePass { 43 public: 44 static char ID; 45 GlobalMergePass(bool CPU = false) : ModulePass(ID), mForCPU(CPU) {} 46 const char *getPassName() const override { return "GlobalMergePass"; } 47 48 bool runOnModule(Module &M) override { 49 DEBUG(dbgs() << "RS2SPIRVGlobalMergePass\n"); 50 51 SmallVector<GlobalVariable *, 8> Globals; 52 if (!collectGlobals(M, Globals)) { 53 return false; // Module not modified. 54 } 55 56 SmallVector<Type *, 8> Tys; 57 Tys.reserve(Globals.size()); 58 59 Context &RS2SPIRVCtxt = Context::getInstance(); 60 61 uint32_t index = 0; 62 for (GlobalVariable *GV : Globals) { 63 Tys.push_back(GV->getValueType()); 64 const char *name = GV->getName().data(); 65 RS2SPIRVCtxt.addExportVarIndex(name, index); 66 index++; 67 } 68 69 LLVMContext &LLVMCtxt = M.getContext(); 70 71 StructType *MergedTy = StructType::create(LLVMCtxt, "struct.__GPUBuffer"); 72 MergedTy->setBody(Tys, false); 73 74 // Size calculation has to consider data layout 75 const DataLayout &DL = M.getDataLayout(); 76 const uint64_t BufferSize = DL.getTypeAllocSize(MergedTy); 77 RS2SPIRVCtxt.setGlobalSize(BufferSize); 78 79 Type *BufferVarTy = mForCPU ? static_cast<Type *>(PointerType::getUnqual( 80 Type::getInt8Ty(M.getContext()))) 81 : static_cast<Type *>(MergedTy); 82 GlobalVariable *MergedGV = 83 new GlobalVariable(M, BufferVarTy, false, GlobalValue::ExternalLinkage, 84 nullptr, "__GPUBlock"); 85 86 // For CPU, create a constant struct for initial values, which has each of 87 // its fields initialized to the original value of the corresponding global 88 // variable. 89 // During the script initialization, the driver should copy these initial 90 // values to the global buffer. 91 if (mForCPU) { 92 CreateInitFunction(LLVMCtxt, M, MergedGV, MergedTy, BufferSize, Globals); 93 } 94 95 const bool forCPU = mForCPU; 96 IntegerType *const Int32Ty = Type::getInt32Ty(LLVMCtxt); 97 ConstantInt *const Zero = ConstantInt::get(Int32Ty, 0); 98 Value *Idx[] = {Zero, nullptr}; 99 100 auto InstMaker = [forCPU, MergedGV, MergedTy, 101 &Idx](Instruction *InsertBefore) { 102 Value *Base = MergedGV; 103 if (forCPU) { 104 LoadInst *Load = new LoadInst(MergedGV, "", InsertBefore); 105 DEBUG(Load->dump()); 106 Base = new BitCastInst(Load, PointerType::getUnqual(MergedTy), "", 107 InsertBefore); 108 DEBUG(Base->dump()); 109 } 110 GetElementPtrInst *GEP = GetElementPtrInst::CreateInBounds( 111 MergedTy, Base, Idx, "", InsertBefore); 112 DEBUG(GEP->dump()); 113 return GEP; 114 }; 115 116 for (size_t i = 0, e = Globals.size(); i != e; ++i) { 117 GlobalVariable *G = Globals[i]; 118 Idx[1] = ConstantInt::get(Int32Ty, i); 119 ReplaceAllUsesWithNewInstructions(G, std::cref(InstMaker)); 120 G->eraseFromParent(); 121 } 122 123 // Return true, as the pass modifies module. 124 return true; 125 } 126 127 private: 128 // In the User of Value Old, replaces all references of Old with Value New 129 static inline void ReplaceUse(User *U, Value *Old, Value *New) { 130 for (unsigned i = 0, n = U->getNumOperands(); i < n; ++i) { 131 if (U->getOperand(i) == Old) { 132 U->getOperandUse(i) = New; 133 } 134 } 135 } 136 137 // Replaces each use of V with new instructions created by 138 // funcCreateAndInsert and inserted right before that use. In the cases where 139 // the use is not an instruction, but a constant expression, recursively 140 // replaces that constant expression with a newly constructed equivalent 141 // instruction, before replacing V in that new instruction. 142 static inline void ReplaceAllUsesWithNewInstructions( 143 Value *V, 144 std::function<Instruction *(Instruction *)> funcCreateAndInsert) { 145 SmallVector<User *, 8> Users(V->user_begin(), V->user_end()); 146 for (User *U : Users) { 147 if (Instruction *Inst = dyn_cast<Instruction>(U)) { 148 DEBUG(dbgs() << "\nBefore replacement:\n"); 149 DEBUG(Inst->dump()); 150 DEBUG(dbgs() << "----\n"); 151 152 ReplaceUse(U, V, funcCreateAndInsert(Inst)); 153 154 DEBUG(Inst->dump()); 155 } else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(U)) { 156 auto InstMaker([CE, V, &funcCreateAndInsert](Instruction *UserOfU) { 157 Instruction *Inst = CE->getAsInstruction(); 158 Inst->insertBefore(UserOfU); 159 ReplaceUse(Inst, V, funcCreateAndInsert(Inst)); 160 161 DEBUG(Inst->dump()); 162 return Inst; 163 }); 164 ReplaceAllUsesWithNewInstructions(U, InstMaker); 165 } else { 166 DEBUG(U->dump()); 167 llvm_unreachable("Expecting only Instruction or ConstantExpr"); 168 } 169 } 170 } 171 172 static inline void 173 CreateInitFunction(LLVMContext &LLVMCtxt, Module &M, GlobalVariable *MergedGV, 174 StructType *MergedTy, const uint64_t BufferSize, 175 const SmallVectorImpl<GlobalVariable *> &Globals) { 176 SmallVector<Constant *, 8> Initializers; 177 Initializers.reserve(Globals.size()); 178 for (size_t i = 0, e = Globals.size(); i != e; ++i) { 179 GlobalVariable *G = Globals[i]; 180 Initializers.push_back(G->getInitializer()); 181 } 182 ArrayRef<Constant *> ArrInit(Initializers.begin(), Initializers.end()); 183 Constant *MergedInitializer = ConstantStruct::get(MergedTy, ArrInit); 184 GlobalVariable *MergedInit = 185 new GlobalVariable(M, MergedTy, true, GlobalValue::InternalLinkage, 186 MergedInitializer, "__GPUBlock0"); 187 188 Function *UserInit = M.getFunction("init"); 189 // If there is no user-defined init() function, make the new global 190 // initialization function the init(). 191 StringRef FName(UserInit ? ".rsov.global_init" : "init"); 192 Function *Func; 193 FunctionType *FTy = FunctionType::get(Type::getVoidTy(LLVMCtxt), false); 194 Func = Function::Create(FTy, GlobalValue::ExternalLinkage, FName, &M); 195 BasicBlock *Blk = BasicBlock::Create(LLVMCtxt, "entry", Func); 196 IRBuilder<> LLVMIRBuilder(Blk); 197 LoadInst *Load = LLVMIRBuilder.CreateLoad(MergedGV); 198 LLVMIRBuilder.CreateMemCpy(Load, MergedInit, BufferSize, 0); 199 LLVMIRBuilder.CreateRetVoid(); 200 201 // If there is a user-defined init() function, add a call to the global 202 // initialization function in the beginning of that function. 203 if (UserInit) { 204 BasicBlock &EntryBlk = UserInit->getEntryBlock(); 205 CallInst::Create(Func, {}, "", &EntryBlk.front()); 206 } 207 } 208 209 bool collectGlobals(Module &M, SmallVectorImpl<GlobalVariable *> &Globals) { 210 for (GlobalVariable &GV : M.globals()) { 211 assert(!GV.hasComdat() && "global variable has a comdat section"); 212 assert(!GV.hasSection() && "global variable has a non-default section"); 213 assert(!GV.isDeclaration() && "global variable is only a declaration"); 214 assert(!GV.isThreadLocal() && "global variable is thread-local"); 215 assert(GV.getType()->getAddressSpace() == 0 && 216 "global variable has non-default address space"); 217 218 // TODO: Constants accessed by kernels should be handled differently 219 if (GV.isConstant()) { 220 continue; 221 } 222 223 // Global Allocations are handled differently in separate passes 224 if (isRSAllocation(GV)) { 225 continue; 226 } 227 228 Globals.push_back(&GV); 229 } 230 231 return !Globals.empty(); 232 } 233 234 bool mForCPU; 235 }; 236 237 } // namespace 238 239 char GlobalMergePass::ID = 0; 240 241 ModulePass *createGlobalMergePass(bool CPU) { return new GlobalMergePass(CPU); } 242 243 } // namespace rs2spirv 244