Home | History | Annotate | Download | only in lib
      1 /*
      2  * Copyright 2015, The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *     http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "RSScriptGroupFusion.h"
     18 
     19 #include "Assert.h"
     20 #include "Log.h"
     21 #include "bcc/BCCContext.h"
     22 #include "bcc/Source.h"
     23 #include "bcinfo/MetadataExtractor.h"
     24 #include "llvm/ADT/StringExtras.h"
     25 #include "llvm/IR/DataLayout.h"
     26 #include "llvm/IR/IRBuilder.h"
     27 #include "llvm/IR/Module.h"
     28 #include "llvm/Support/raw_ostream.h"
     29 
     30 using llvm::Function;
     31 using llvm::Module;
     32 
     33 using std::string;
     34 
     35 namespace bcc {
     36 
     37 namespace {
     38 
     39 const Function* getInvokeFunction(const Source& source, const int slot,
     40                                   Module* newModule) {
     41 
     42   bcinfo::MetadataExtractor &metadata = *source.getMetadata();
     43   const char* functionName = metadata.getExportFuncNameList()[slot];
     44   Function* func = newModule->getFunction(functionName);
     45   // Materialize the function so that later the caller can inspect its argument
     46   // and return types.
     47   newModule->materialize(func);
     48   return func;
     49 }
     50 
     51 const Function*
     52 getFunction(Module* mergedModule, const Source* source, const int slot,
     53             uint32_t* signature) {
     54 
     55   bcinfo::MetadataExtractor &metadata = *source->getMetadata();
     56   const char* functionName = metadata.getExportForEachNameList()[slot];
     57   if (functionName == nullptr || !functionName[0]) {
     58     ALOGE("Kernel fusion (module %s slot %d): failed to find kernel function",
     59           source->getName().c_str(), slot);
     60     return nullptr;
     61   }
     62 
     63   if (metadata.getExportForEachInputCountList()[slot] > 1) {
     64     ALOGE("Kernel fusion (module %s function %s): cannot handle multiple inputs",
     65           source->getName().c_str(), functionName);
     66     return nullptr;
     67   }
     68 
     69   if (signature != nullptr) {
     70     *signature = metadata.getExportForEachSignatureList()[slot];
     71   }
     72 
     73   const Function* function = mergedModule->getFunction(functionName);
     74 
     75   return function;
     76 }
     77 
     78 // The whitelist of supported signature bits. Context or user data arguments are
     79 // not currently supported in kernel fusion. To support them or any new kinds of
     80 // arguments in the future, it requires not only listing the signature bits here,
     81 // but also implementing additional necessary fusion logic in the getFusedFuncSig(),
     82 // getFusedFuncType(), and fuseKernels() functions below.
     83 constexpr uint32_t ExpectedSignatureBits =
     84         bcinfo::MD_SIG_In |
     85         bcinfo::MD_SIG_Out |
     86         bcinfo::MD_SIG_X |
     87         bcinfo::MD_SIG_Y |
     88         bcinfo::MD_SIG_Z |
     89         bcinfo::MD_SIG_Kernel;
     90 
     91 int getFusedFuncSig(const std::vector<Source*>& sources,
     92                     const std::vector<int>& slots,
     93                     uint32_t* retSig) {
     94   *retSig = 0;
     95   uint32_t firstSignature = 0;
     96   uint32_t signature = 0;
     97   auto slotIter = slots.begin();
     98   for (const Source* source : sources) {
     99     const int slot = *slotIter++;
    100     bcinfo::MetadataExtractor &metadata = *source->getMetadata();
    101 
    102     if (metadata.getExportForEachInputCountList()[slot] > 1) {
    103       ALOGE("Kernel fusion (module %s slot %d): cannot handle multiple inputs",
    104             source->getName().c_str(), slot);
    105       return -1;
    106     }
    107 
    108     signature = metadata.getExportForEachSignatureList()[slot];
    109     if (signature & ~ExpectedSignatureBits) {
    110       ALOGE("Kernel fusion (module %s slot %d): Unexpected signature %x",
    111             source->getName().c_str(), slot, signature);
    112       return -1;
    113     }
    114 
    115     if (firstSignature == 0) {
    116       firstSignature = signature;
    117     }
    118 
    119     *retSig |= signature;
    120   }
    121 
    122   if (!bcinfo::MetadataExtractor::hasForEachSignatureIn(firstSignature)) {
    123     *retSig &= ~bcinfo::MD_SIG_In;
    124   }
    125 
    126   if (!bcinfo::MetadataExtractor::hasForEachSignatureOut(signature)) {
    127     *retSig &= ~bcinfo::MD_SIG_Out;
    128   }
    129 
    130   return 0;
    131 }
    132 
    133 llvm::FunctionType* getFusedFuncType(bcc::BCCContext& Context,
    134                                      const std::vector<Source*>& sources,
    135                                      const std::vector<int>& slots,
    136                                      Module* M,
    137                                      uint32_t* signature) {
    138   int error = getFusedFuncSig(sources, slots, signature);
    139 
    140   if (error < 0) {
    141     return nullptr;
    142   }
    143 
    144   const Function* firstF = getFunction(M, sources.front(), slots.front(), nullptr);
    145 
    146   bccAssert (firstF != nullptr);
    147 
    148   llvm::SmallVector<llvm::Type*, 8> ArgTys;
    149 
    150   if (bcinfo::MetadataExtractor::hasForEachSignatureIn(*signature)) {
    151     ArgTys.push_back(firstF->arg_begin()->getType());
    152   }
    153 
    154   llvm::Type* I32Ty = llvm::IntegerType::get(Context.getLLVMContext(), 32);
    155   if (bcinfo::MetadataExtractor::hasForEachSignatureX(*signature)) {
    156     ArgTys.push_back(I32Ty);
    157   }
    158   if (bcinfo::MetadataExtractor::hasForEachSignatureY(*signature)) {
    159     ArgTys.push_back(I32Ty);
    160   }
    161   if (bcinfo::MetadataExtractor::hasForEachSignatureZ(*signature)) {
    162     ArgTys.push_back(I32Ty);
    163   }
    164 
    165   const Function* lastF = getFunction(M, sources.back(), slots.back(), nullptr);
    166 
    167   bccAssert (lastF != nullptr);
    168 
    169   llvm::Type* retTy = lastF->getReturnType();
    170 
    171   return llvm::FunctionType::get(retTy, ArgTys, false);
    172 }
    173 
    174 }  // anonymous namespace
    175 
    176 bool fuseKernels(bcc::BCCContext& Context,
    177                  const std::vector<Source *>& sources,
    178                  const std::vector<int>& slots,
    179                  const std::string& fusedName,
    180                  Module* mergedModule) {
    181   bccAssert(sources.size() == slots.size() && "sources and slots differ in size");
    182 
    183   uint32_t fusedFunctionSignature;
    184 
    185   llvm::FunctionType* fusedType =
    186           getFusedFuncType(Context, sources, slots, mergedModule, &fusedFunctionSignature);
    187 
    188   if (fusedType == nullptr) {
    189     return false;
    190   }
    191 
    192   Function* fusedKernel =
    193           (Function*)(mergedModule->getOrInsertFunction(fusedName, fusedType));
    194 
    195   llvm::LLVMContext& ctxt = Context.getLLVMContext();
    196 
    197   llvm::BasicBlock* block = llvm::BasicBlock::Create(ctxt, "entry", fusedKernel);
    198   llvm::IRBuilder<> builder(block);
    199 
    200   Function::arg_iterator argIter = fusedKernel->arg_begin();
    201 
    202   llvm::Value* dataElement = nullptr;
    203   if (bcinfo::MetadataExtractor::hasForEachSignatureIn(fusedFunctionSignature)) {
    204     dataElement = &*(argIter++);
    205     dataElement->setName("DataIn");
    206   }
    207 
    208   llvm::Value* X = nullptr;
    209   if (bcinfo::MetadataExtractor::hasForEachSignatureX(fusedFunctionSignature)) {
    210     X = &*(argIter++);
    211     X->setName("x");
    212   }
    213 
    214   llvm::Value* Y = nullptr;
    215   if (bcinfo::MetadataExtractor::hasForEachSignatureY(fusedFunctionSignature)) {
    216     Y = &*(argIter++);
    217     Y->setName("y");
    218   }
    219 
    220   llvm::Value* Z = nullptr;
    221   if (bcinfo::MetadataExtractor::hasForEachSignatureZ(fusedFunctionSignature)) {
    222     Z = &*(argIter++);
    223     Z->setName("z");
    224   }
    225 
    226   auto slotIter = slots.begin();
    227   for (const Source* source : sources) {
    228     int slot = *slotIter;
    229 
    230     uint32_t inputFunctionSignature;
    231     const Function* inputFunction =
    232             getFunction(mergedModule, source, slot, &inputFunctionSignature);
    233     if (inputFunction == nullptr) {
    234       // Either failed to find the kernel function, or the function has multiple inputs.
    235       return false;
    236     }
    237 
    238     // Don't try to fuse a non-kernel
    239     if (!bcinfo::MetadataExtractor::hasForEachSignatureKernel(inputFunctionSignature)) {
    240       ALOGE("Kernel fusion (module %s function %s): not a kernel",
    241             source->getName().c_str(), inputFunction->getName().str().c_str());
    242       return false;
    243     }
    244 
    245     std::vector<llvm::Value*> args;
    246 
    247     if (bcinfo::MetadataExtractor::hasForEachSignatureIn(inputFunctionSignature)) {
    248       if (dataElement == nullptr) {
    249         ALOGE("Kernel fusion (module %s function %s): expected input, but got null",
    250               source->getName().c_str(), inputFunction->getName().str().c_str());
    251         return false;
    252       }
    253 
    254       const llvm::FunctionType* funcTy = inputFunction->getFunctionType();
    255       llvm::Type* firstArgType = funcTy->getParamType(0);
    256 
    257       if (dataElement->getType() != firstArgType) {
    258         std::string msg;
    259         llvm::raw_string_ostream rso(msg);
    260         rso << "Mismatching argument type, expected ";
    261         firstArgType->print(rso);
    262         rso << ", received ";
    263         dataElement->getType()->print(rso);
    264         ALOGE("Kernel fusion (module %s function %s): %s", source->getName().c_str(),
    265               inputFunction->getName().str().c_str(), rso.str().c_str());
    266         return false;
    267       }
    268 
    269       args.push_back(dataElement);
    270     } else {
    271       // Only the first kernel in a batch is allowed to have no input
    272       if (slotIter != slots.begin()) {
    273         ALOGE("Kernel fusion (module %s function %s): function not first in batch takes no input",
    274               source->getName().c_str(), inputFunction->getName().str().c_str());
    275         return false;
    276       }
    277     }
    278 
    279     if (bcinfo::MetadataExtractor::hasForEachSignatureX(inputFunctionSignature)) {
    280       args.push_back(X);
    281     }
    282 
    283     if (bcinfo::MetadataExtractor::hasForEachSignatureY(inputFunctionSignature)) {
    284       args.push_back(Y);
    285     }
    286 
    287     if (bcinfo::MetadataExtractor::hasForEachSignatureZ(inputFunctionSignature)) {
    288       args.push_back(Z);
    289     }
    290 
    291     dataElement = builder.CreateCall((llvm::Value*)inputFunction, args);
    292 
    293     slotIter++;
    294   }
    295 
    296   if (fusedKernel->getReturnType()->isVoidTy()) {
    297     builder.CreateRetVoid();
    298   } else {
    299     builder.CreateRet(dataElement);
    300   }
    301 
    302   llvm::NamedMDNode* ExportForEachNameMD =
    303     mergedModule->getOrInsertNamedMetadata("#rs_export_foreach_name");
    304 
    305   llvm::MDString* nameMDStr = llvm::MDString::get(ctxt, fusedName);
    306   llvm::MDNode* nameMDNode = llvm::MDNode::get(ctxt, nameMDStr);
    307   ExportForEachNameMD->addOperand(nameMDNode);
    308 
    309   llvm::NamedMDNode* ExportForEachMD =
    310     mergedModule->getOrInsertNamedMetadata("#rs_export_foreach");
    311   llvm::MDString* sigMDStr = llvm::MDString::get(ctxt,
    312                                                  llvm::utostr(fusedFunctionSignature));
    313   llvm::MDNode* sigMDNode = llvm::MDNode::get(ctxt, sigMDStr);
    314   ExportForEachMD->addOperand(sigMDNode);
    315 
    316   return true;
    317 }
    318 
    319 bool renameInvoke(BCCContext& Context, const Source* source, const int slot,
    320                   const std::string& newName, Module* module) {
    321   const llvm::Function* F = getInvokeFunction(*source, slot, module);
    322   std::vector<llvm::Type*> params;
    323   for (auto I = F->arg_begin(), E = F->arg_end(); I != E; ++I) {
    324     params.push_back(I->getType());
    325   }
    326   llvm::Type* returnTy = F->getReturnType();
    327 
    328   llvm::FunctionType* batchFuncTy =
    329           llvm::FunctionType::get(returnTy, params, false);
    330 
    331   llvm::Function* newF =
    332           llvm::Function::Create(batchFuncTy,
    333                                  llvm::GlobalValue::ExternalLinkage, newName,
    334                                  module);
    335 
    336   llvm::BasicBlock* block = llvm::BasicBlock::Create(Context.getLLVMContext(),
    337                                                      "entry", newF);
    338   llvm::IRBuilder<> builder(block);
    339 
    340   llvm::Function::arg_iterator argIter = newF->arg_begin();
    341   llvm::Value* arg1 = &*(argIter++);
    342   builder.CreateCall((llvm::Value*)F, arg1);
    343 
    344   builder.CreateRetVoid();
    345 
    346   llvm::NamedMDNode* ExportFuncNameMD =
    347           module->getOrInsertNamedMetadata("#rs_export_func");
    348   llvm::MDString* strMD = llvm::MDString::get(module->getContext(), newName);
    349   llvm::MDNode* nodeMD = llvm::MDNode::get(module->getContext(), strMD);
    350   ExportFuncNameMD->addOperand(nodeMD);
    351 
    352   return true;
    353 }
    354 
    355 }  // namespace bcc
    356