Home | History | Annotate | Download | only in Transforms
      1 /*
      2  * Copyright 2012, The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *     http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "Config.h"
     18 #include "bcc/bcc_assert.h"
     19 
     20 #include "DebugHelper.h"
     21 
     22 #include "llvm/DerivedTypes.h"
     23 #include "llvm/Function.h"
     24 #include "llvm/Instructions.h"
     25 #include "llvm/Module.h"
     26 #include "llvm/Pass.h"
     27 #include "llvm/Type.h"
     28 #include "llvm/Support/IRBuilder.h"
     29 
     30 namespace {
     31   /* ForEachExpandPass - This pass operates on functions that are able to be
     32    * called via rsForEach() or "foreach_<NAME>". We create an inner loop for
     33    * the ForEach-able function to be invoked over the appropriate data cells
     34    * of the input/output allocations (adjusting other relevant parameters as
     35    * we go). We support doing this for any ForEach-able compute kernels.
     36    * The new function name is the original function name followed by
     37    * ".expand". Note that we still generate code for the original function.
     38    */
     39   class ForEachExpandPass : public llvm::ModulePass {
     40   private:
     41   static char ID;
     42 
     43   llvm::Module *M;
     44   llvm::LLVMContext *C;
     45 
     46   std::vector<std::string>& mNames;
     47   std::vector<uint32_t>& mSignatures;
     48 
     49   uint32_t getRootSignature(llvm::Function *F) {
     50     const llvm::NamedMDNode *ExportForEachMetadata =
     51         M->getNamedMetadata("#rs_export_foreach");
     52 
     53     if (!ExportForEachMetadata) {
     54       llvm::SmallVector<llvm::Type*, 8> RootArgTys;
     55       for (llvm::Function::arg_iterator B = F->arg_begin(),
     56                                         E = F->arg_end();
     57            B != E;
     58            ++B) {
     59         RootArgTys.push_back(B->getType());
     60       }
     61 
     62       // For pre-ICS bitcode, we may not have signature information. In that
     63       // case, we use the size of the RootArgTys to select the number of
     64       // arguments.
     65       return (1 << RootArgTys.size()) - 1;
     66     }
     67 
     68     bccAssert(ExportForEachMetadata->getNumOperands() > 0);
     69 
     70     // We only handle the case for legacy root() functions here, so this is
     71     // hard-coded to look at only the first such function.
     72     llvm::MDNode *SigNode = ExportForEachMetadata->getOperand(0);
     73     if (SigNode != NULL && SigNode->getNumOperands() == 1) {
     74       llvm::Value *SigVal = SigNode->getOperand(0);
     75       if (SigVal->getValueID() == llvm::Value::MDStringVal) {
     76         llvm::StringRef SigString =
     77             static_cast<llvm::MDString*>(SigVal)->getString();
     78         uint32_t Signature = 0;
     79         if (SigString.getAsInteger(10, Signature)) {
     80           ALOGE("Non-integer signature value '%s'", SigString.str().c_str());
     81           return 0;
     82         }
     83         return Signature;
     84       }
     85     }
     86 
     87     return 0;
     88   }
     89 
     90   static bool hasIn(uint32_t Signature) {
     91     return Signature & 1;
     92   }
     93 
     94   static bool hasOut(uint32_t Signature) {
     95     return Signature & 2;
     96   }
     97 
     98   static bool hasUsrData(uint32_t Signature) {
     99     return Signature & 4;
    100   }
    101 
    102   static bool hasX(uint32_t Signature) {
    103     return Signature & 8;
    104   }
    105 
    106   static bool hasY(uint32_t Signature) {
    107     return Signature & 16;
    108   }
    109 
    110   public:
    111   ForEachExpandPass(std::vector<std::string>& Names,
    112                     std::vector<uint32_t>& Signatures)
    113       : ModulePass(ID), M(NULL), C(NULL), mNames(Names),
    114         mSignatures(Signatures) {
    115   }
    116 
    117   /* Performs the actual optimization on a selected function. On success, the
    118    * Module will contain a new function of the name "<NAME>.expand" that
    119    * invokes <NAME>() in a loop with the appropriate parameters.
    120    */
    121   bool ExpandFunction(llvm::Function *F, uint32_t Signature) {
    122     ALOGV("Expanding ForEach-able Function %s", F->getName().str().c_str());
    123 
    124     if (!Signature) {
    125       Signature = getRootSignature(F);
    126       if (!Signature) {
    127         // We couldn't determine how to expand this function based on its
    128         // function signature.
    129         return false;
    130       }
    131     }
    132 
    133     llvm::Type *VoidPtrTy = llvm::Type::getInt8PtrTy(*C);
    134     llvm::Type *Int32Ty = llvm::Type::getInt32Ty(*C);
    135     llvm::Type *SizeTy = Int32Ty;
    136 
    137     /* Defined in frameworks/base/libs/rs/rs_hal.h:
    138      *
    139      * struct RsForEachStubParamStruct {
    140      *   const void *in;
    141      *   void *out;
    142      *   const void *usr;
    143      *   size_t usr_len;
    144      *   uint32_t x;
    145      *   uint32_t y;
    146      *   uint32_t z;
    147      *   uint32_t lod;
    148      *   enum RsAllocationCubemapFace face;
    149      *   uint32_t ar[16];
    150      * };
    151      */
    152     llvm::SmallVector<llvm::Type*, 9> StructTys;
    153     StructTys.push_back(VoidPtrTy);  // const void *in
    154     StructTys.push_back(VoidPtrTy);  // void *out
    155     StructTys.push_back(VoidPtrTy);  // const void *usr
    156     StructTys.push_back(SizeTy);     // size_t usr_len
    157     StructTys.push_back(Int32Ty);    // uint32_t x
    158     StructTys.push_back(Int32Ty);    // uint32_t y
    159     StructTys.push_back(Int32Ty);    // uint32_t z
    160     StructTys.push_back(Int32Ty);    // uint32_t lod
    161     StructTys.push_back(Int32Ty);    // enum RsAllocationCubemapFace
    162     StructTys.push_back(llvm::ArrayType::get(Int32Ty, 16));  // uint32_t ar[16]
    163 
    164     llvm::Type *ForEachStubPtrTy = llvm::StructType::create(
    165         StructTys, "RsForEachStubParamStruct")->getPointerTo();
    166 
    167     /* Create the function signature for our expanded function.
    168      * void (const RsForEachStubParamStruct *p, uint32_t x1, uint32_t x2,
    169      *       uint32_t instep, uint32_t outstep)
    170      */
    171     llvm::SmallVector<llvm::Type*, 8> ParamTys;
    172     ParamTys.push_back(ForEachStubPtrTy);  // const RsForEachStubParamStruct *p
    173     ParamTys.push_back(Int32Ty);           // uint32_t x1
    174     ParamTys.push_back(Int32Ty);           // uint32_t x2
    175     ParamTys.push_back(Int32Ty);           // uint32_t instep
    176     ParamTys.push_back(Int32Ty);           // uint32_t outstep
    177 
    178     llvm::FunctionType *FT =
    179         llvm::FunctionType::get(llvm::Type::getVoidTy(*C), ParamTys, false);
    180     llvm::Function *ExpandedFunc =
    181         llvm::Function::Create(FT,
    182                                llvm::GlobalValue::ExternalLinkage,
    183                                F->getName() + ".expand", M);
    184 
    185     // Create and name the actual arguments to this expanded function.
    186     llvm::SmallVector<llvm::Argument*, 8> ArgVec;
    187     for (llvm::Function::arg_iterator B = ExpandedFunc->arg_begin(),
    188                                       E = ExpandedFunc->arg_end();
    189          B != E;
    190          ++B) {
    191       ArgVec.push_back(B);
    192     }
    193 
    194     if (ArgVec.size() != 5) {
    195       ALOGE("Incorrect number of arguments to function: %zu",
    196             ArgVec.size());
    197       return false;
    198     }
    199     llvm::Value *Arg_p = ArgVec[0];
    200     llvm::Value *Arg_x1 = ArgVec[1];
    201     llvm::Value *Arg_x2 = ArgVec[2];
    202     llvm::Value *Arg_instep = ArgVec[3];
    203     llvm::Value *Arg_outstep = ArgVec[4];
    204 
    205     Arg_p->setName("p");
    206     Arg_x1->setName("x1");
    207     Arg_x2->setName("x2");
    208     Arg_instep->setName("instep");
    209     Arg_outstep->setName("outstep");
    210 
    211     // Construct the actual function body.
    212     llvm::BasicBlock *Begin =
    213         llvm::BasicBlock::Create(*C, "Begin", ExpandedFunc);
    214     llvm::IRBuilder<> Builder(Begin);
    215 
    216     // uint32_t X = x1;
    217     llvm::AllocaInst *AX = Builder.CreateAlloca(Int32Ty, 0, "AX");
    218     Builder.CreateStore(Arg_x1, AX);
    219 
    220     // Collect and construct the arguments for the kernel().
    221     // Note that we load any loop-invariant arguments before entering the Loop.
    222     llvm::Function::arg_iterator Args = F->arg_begin();
    223 
    224     llvm::Type *InTy = NULL;
    225     llvm::AllocaInst *AIn = NULL;
    226     if (hasIn(Signature)) {
    227       InTy = Args->getType();
    228       AIn = Builder.CreateAlloca(InTy, 0, "AIn");
    229       Builder.CreateStore(Builder.CreatePointerCast(Builder.CreateLoad(
    230           Builder.CreateStructGEP(Arg_p, 0)), InTy), AIn);
    231       Args++;
    232     }
    233 
    234     llvm::Type *OutTy = NULL;
    235     llvm::AllocaInst *AOut = NULL;
    236     if (hasOut(Signature)) {
    237       OutTy = Args->getType();
    238       AOut = Builder.CreateAlloca(OutTy, 0, "AOut");
    239       Builder.CreateStore(Builder.CreatePointerCast(Builder.CreateLoad(
    240           Builder.CreateStructGEP(Arg_p, 1)), OutTy), AOut);
    241       Args++;
    242     }
    243 
    244     llvm::Value *UsrData = NULL;
    245     if (hasUsrData(Signature)) {
    246       llvm::Type *UsrDataTy = Args->getType();
    247       UsrData = Builder.CreatePointerCast(Builder.CreateLoad(
    248           Builder.CreateStructGEP(Arg_p, 2)), UsrDataTy);
    249       UsrData->setName("UsrData");
    250       Args++;
    251     }
    252 
    253     if (hasX(Signature)) {
    254       Args++;
    255     }
    256 
    257     llvm::Value *Y = NULL;
    258     if (hasY(Signature)) {
    259       Y = Builder.CreateLoad(Builder.CreateStructGEP(Arg_p, 5), "Y");
    260       Args++;
    261     }
    262 
    263     bccAssert(Args == F->arg_end());
    264 
    265     llvm::BasicBlock *Loop = llvm::BasicBlock::Create(*C, "Loop", ExpandedFunc);
    266     llvm::BasicBlock *Exit = llvm::BasicBlock::Create(*C, "Exit", ExpandedFunc);
    267 
    268     // if (x1 < x2) goto Loop; else goto Exit;
    269     llvm::Value *Cond = Builder.CreateICmpSLT(Arg_x1, Arg_x2);
    270     Builder.CreateCondBr(Cond, Loop, Exit);
    271 
    272     // Loop:
    273     Builder.SetInsertPoint(Loop);
    274 
    275     // Populate the actual call to kernel().
    276     llvm::SmallVector<llvm::Value*, 8> RootArgs;
    277 
    278     llvm::Value *In = NULL;
    279     llvm::Value *Out = NULL;
    280 
    281     if (AIn) {
    282       In = Builder.CreateLoad(AIn, "In");
    283       RootArgs.push_back(In);
    284     }
    285 
    286     if (AOut) {
    287       Out = Builder.CreateLoad(AOut, "Out");
    288       RootArgs.push_back(Out);
    289     }
    290 
    291     if (UsrData) {
    292       RootArgs.push_back(UsrData);
    293     }
    294 
    295     // We always have to load X, since it is used to iterate through the loop.
    296     llvm::Value *X = Builder.CreateLoad(AX, "X");
    297     if (hasX(Signature)) {
    298       RootArgs.push_back(X);
    299     }
    300 
    301     if (Y) {
    302       RootArgs.push_back(Y);
    303     }
    304 
    305     Builder.CreateCall(F, RootArgs);
    306 
    307     if (In) {
    308       // In += instep
    309       llvm::Value *NewIn = Builder.CreateIntToPtr(Builder.CreateNUWAdd(
    310           Builder.CreatePtrToInt(In, Int32Ty), Arg_instep), InTy);
    311       Builder.CreateStore(NewIn, AIn);
    312     }
    313 
    314     if (Out) {
    315       // Out += outstep
    316       llvm::Value *NewOut = Builder.CreateIntToPtr(Builder.CreateNUWAdd(
    317           Builder.CreatePtrToInt(Out, Int32Ty), Arg_outstep), OutTy);
    318       Builder.CreateStore(NewOut, AOut);
    319     }
    320 
    321     // X++;
    322     llvm::Value *XPlusOne =
    323         Builder.CreateNUWAdd(X, llvm::ConstantInt::get(Int32Ty, 1));
    324     Builder.CreateStore(XPlusOne, AX);
    325 
    326     // If (X < x2) goto Loop; else goto Exit;
    327     Cond = Builder.CreateICmpSLT(XPlusOne, Arg_x2);
    328     Builder.CreateCondBr(Cond, Loop, Exit);
    329 
    330     // Exit:
    331     Builder.SetInsertPoint(Exit);
    332     Builder.CreateRetVoid();
    333 
    334     return true;
    335   }
    336 
    337   virtual bool runOnModule(llvm::Module &M) {
    338     bool Changed = false;
    339     this->M = &M;
    340     C = &M.getContext();
    341 
    342     bccAssert(mNames.size() == mSignatures.size());
    343     for (int i = 0, e = mNames.size(); i != e; i++) {
    344       llvm::Function *kernel = M.getFunction(mNames[i]);
    345       if (kernel && kernel->getReturnType()->isVoidTy()) {
    346         Changed |= ExpandFunction(kernel, mSignatures[i]);
    347       }
    348     }
    349 
    350     return Changed;
    351   }
    352 
    353   virtual const char *getPassName() const {
    354     return "ForEach-able Function Expansion";
    355   }
    356 
    357   };
    358 }  // end anonymous namespace
    359 
    360 char ForEachExpandPass::ID = 0;
    361 
    362 namespace bcc {
    363 
    364   llvm::ModulePass *createForEachExpandPass(std::vector<std::string>& Names,
    365                                             std::vector<uint32_t>& Signatures) {
    366     return new ForEachExpandPass(Names, Signatures);
    367   }
    368 
    369 }  // namespace bcc
    370