Home | History | Annotate | Download | only in compiler
      1 /*
      2  * Copyright 2016-2017, The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *     http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "GlobalMergePass.h"
     18 
     19 #include "llvm/IR/Constants.h"
     20 #include "llvm/IR/DataLayout.h"
     21 #include "llvm/IR/GlobalVariable.h"
     22 #include "llvm/IR/IRBuilder.h"
     23 #include "llvm/IR/Instructions.h"
     24 #include "llvm/IR/Module.h"
     25 #include "llvm/Pass.h"
     26 #include "llvm/Support/Debug.h"
     27 #include "llvm/Support/raw_ostream.h"
     28 
     29 #include "Context.h"
     30 #include "RSAllocationUtils.h"
     31 
     32 #include <functional>
     33 
     34 #define DEBUG_TYPE "rs2spirv-global-merge"
     35 
     36 using namespace llvm;
     37 
     38 namespace rs2spirv {
     39 
     40 namespace {
     41 
     42 class GlobalMergePass : public ModulePass {
     43 public:
     44   static char ID;
     45   GlobalMergePass(bool CPU = false) : ModulePass(ID), mForCPU(CPU) {}
     46   const char *getPassName() const override { return "GlobalMergePass"; }
     47 
     48   bool runOnModule(Module &M) override {
     49     DEBUG(dbgs() << "RS2SPIRVGlobalMergePass\n");
     50 
     51     SmallVector<GlobalVariable *, 8> Globals;
     52     if (!collectGlobals(M, Globals)) {
     53       return false; // Module not modified.
     54     }
     55 
     56     SmallVector<Type *, 8> Tys;
     57     Tys.reserve(Globals.size());
     58 
     59     Context &RS2SPIRVCtxt = Context::getInstance();
     60 
     61     uint32_t index = 0;
     62     for (GlobalVariable *GV : Globals) {
     63       Tys.push_back(GV->getValueType());
     64       const char *name = GV->getName().data();
     65       RS2SPIRVCtxt.addExportVarIndex(name, index);
     66       index++;
     67     }
     68 
     69     LLVMContext &LLVMCtxt = M.getContext();
     70 
     71     StructType *MergedTy = StructType::create(LLVMCtxt, "struct.__GPUBuffer");
     72     MergedTy->setBody(Tys, false);
     73 
     74     // Size calculation has to consider data layout
     75     const DataLayout &DL = M.getDataLayout();
     76     const uint64_t BufferSize = DL.getTypeAllocSize(MergedTy);
     77     RS2SPIRVCtxt.setGlobalSize(BufferSize);
     78 
     79     Type *BufferVarTy = mForCPU ? static_cast<Type *>(PointerType::getUnqual(
     80                                       Type::getInt8Ty(M.getContext())))
     81                                 : static_cast<Type *>(MergedTy);
     82     GlobalVariable *MergedGV =
     83         new GlobalVariable(M, BufferVarTy, false, GlobalValue::ExternalLinkage,
     84                            nullptr, "__GPUBlock");
     85 
     86     // For CPU, create a constant struct for initial values, which has each of
     87     // its fields initialized to the original value of the corresponding global
     88     // variable.
     89     // During the script initialization, the driver should copy these initial
     90     // values to the global buffer.
     91     if (mForCPU) {
     92       CreateInitFunction(LLVMCtxt, M, MergedGV, MergedTy, BufferSize, Globals);
     93     }
     94 
     95     const bool forCPU = mForCPU;
     96     IntegerType *const Int32Ty = Type::getInt32Ty(LLVMCtxt);
     97     ConstantInt *const Zero = ConstantInt::get(Int32Ty, 0);
     98     Value *Idx[] = {Zero, nullptr};
     99 
    100     auto InstMaker = [forCPU, MergedGV, MergedTy,
    101                       &Idx](Instruction *InsertBefore) {
    102       Value *Base = MergedGV;
    103       if (forCPU) {
    104         LoadInst *Load = new LoadInst(MergedGV, "", InsertBefore);
    105         DEBUG(Load->dump());
    106         Base = new BitCastInst(Load, PointerType::getUnqual(MergedTy), "",
    107                                InsertBefore);
    108         DEBUG(Base->dump());
    109       }
    110       GetElementPtrInst *GEP = GetElementPtrInst::CreateInBounds(
    111           MergedTy, Base, Idx, "", InsertBefore);
    112       DEBUG(GEP->dump());
    113       return GEP;
    114     };
    115 
    116     for (size_t i = 0, e = Globals.size(); i != e; ++i) {
    117       GlobalVariable *G = Globals[i];
    118       Idx[1] = ConstantInt::get(Int32Ty, i);
    119       ReplaceAllUsesWithNewInstructions(G, std::cref(InstMaker));
    120       G->eraseFromParent();
    121     }
    122 
    123     // Return true, as the pass modifies module.
    124     return true;
    125   }
    126 
    127 private:
    128   // In the User of Value Old, replaces all references of Old with Value New
    129   static inline void ReplaceUse(User *U, Value *Old, Value *New) {
    130     for (unsigned i = 0, n = U->getNumOperands(); i < n; ++i) {
    131       if (U->getOperand(i) == Old) {
    132         U->getOperandUse(i) = New;
    133       }
    134     }
    135   }
    136 
    137   // Replaces each use of V with new instructions created by
    138   // funcCreateAndInsert and inserted right before that use. In the cases where
    139   // the use is not an instruction, but a constant expression, recursively
    140   // replaces that constant expression with a newly constructed equivalent
    141   // instruction, before replacing V in that new instruction.
    142   static inline void ReplaceAllUsesWithNewInstructions(
    143       Value *V,
    144       std::function<Instruction *(Instruction *)> funcCreateAndInsert) {
    145     SmallVector<User *, 8> Users(V->user_begin(), V->user_end());
    146     for (User *U : Users) {
    147       if (Instruction *Inst = dyn_cast<Instruction>(U)) {
    148         DEBUG(dbgs() << "\nBefore replacement:\n");
    149         DEBUG(Inst->dump());
    150         DEBUG(dbgs() << "----\n");
    151 
    152         ReplaceUse(U, V, funcCreateAndInsert(Inst));
    153 
    154         DEBUG(Inst->dump());
    155       } else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(U)) {
    156         auto InstMaker([CE, V, &funcCreateAndInsert](Instruction *UserOfU) {
    157           Instruction *Inst = CE->getAsInstruction();
    158           Inst->insertBefore(UserOfU);
    159           ReplaceUse(Inst, V, funcCreateAndInsert(Inst));
    160 
    161           DEBUG(Inst->dump());
    162           return Inst;
    163         });
    164         ReplaceAllUsesWithNewInstructions(U, InstMaker);
    165       } else {
    166         DEBUG(U->dump());
    167         llvm_unreachable("Expecting only Instruction or ConstantExpr");
    168       }
    169     }
    170   }
    171 
    172   static inline void
    173   CreateInitFunction(LLVMContext &LLVMCtxt, Module &M, GlobalVariable *MergedGV,
    174                      StructType *MergedTy, const uint64_t BufferSize,
    175                      const SmallVectorImpl<GlobalVariable *> &Globals) {
    176     SmallVector<Constant *, 8> Initializers;
    177     Initializers.reserve(Globals.size());
    178     for (size_t i = 0, e = Globals.size(); i != e; ++i) {
    179       GlobalVariable *G = Globals[i];
    180       Initializers.push_back(G->getInitializer());
    181     }
    182     ArrayRef<Constant *> ArrInit(Initializers.begin(), Initializers.end());
    183     Constant *MergedInitializer = ConstantStruct::get(MergedTy, ArrInit);
    184     GlobalVariable *MergedInit =
    185         new GlobalVariable(M, MergedTy, true, GlobalValue::InternalLinkage,
    186                            MergedInitializer, "__GPUBlock0");
    187 
    188     Function *UserInit = M.getFunction("init");
    189     // If there is no user-defined init() function, make the new global
    190     // initialization function the init().
    191     StringRef FName(UserInit ? ".rsov.global_init" : "init");
    192     Function *Func;
    193     FunctionType *FTy = FunctionType::get(Type::getVoidTy(LLVMCtxt), false);
    194     Func = Function::Create(FTy, GlobalValue::ExternalLinkage, FName, &M);
    195     BasicBlock *Blk = BasicBlock::Create(LLVMCtxt, "entry", Func);
    196     IRBuilder<> LLVMIRBuilder(Blk);
    197     LoadInst *Load = LLVMIRBuilder.CreateLoad(MergedGV);
    198     LLVMIRBuilder.CreateMemCpy(Load, MergedInit, BufferSize, 0);
    199     LLVMIRBuilder.CreateRetVoid();
    200 
    201     // If there is a user-defined init() function, add a call to the global
    202     // initialization function in the beginning of that function.
    203     if (UserInit) {
    204       BasicBlock &EntryBlk = UserInit->getEntryBlock();
    205       CallInst::Create(Func, {}, "", &EntryBlk.front());
    206     }
    207   }
    208 
    209   bool collectGlobals(Module &M, SmallVectorImpl<GlobalVariable *> &Globals) {
    210     for (GlobalVariable &GV : M.globals()) {
    211       assert(!GV.hasComdat() && "global variable has a comdat section");
    212       assert(!GV.hasSection() && "global variable has a non-default section");
    213       assert(!GV.isDeclaration() && "global variable is only a declaration");
    214       assert(!GV.isThreadLocal() && "global variable is thread-local");
    215       assert(GV.getType()->getAddressSpace() == 0 &&
    216              "global variable has non-default address space");
    217 
    218       // TODO: Constants accessed by kernels should be handled differently
    219       if (GV.isConstant()) {
    220         continue;
    221       }
    222 
    223       // Global Allocations are handled differently in separate passes
    224       if (isRSAllocation(GV)) {
    225         continue;
    226       }
    227 
    228       Globals.push_back(&GV);
    229     }
    230 
    231     return !Globals.empty();
    232   }
    233 
    234   bool mForCPU;
    235 };
    236 
    237 } // namespace
    238 
    239 char GlobalMergePass::ID = 0;
    240 
    241 ModulePass *createGlobalMergePass(bool CPU) { return new GlobalMergePass(CPU); }
    242 
    243 } // namespace rs2spirv
    244