1 /* 2 * Copyright 2015, The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "bcc/Renderscript/RSScriptGroupFusion.h" 18 19 #include "bcc/Assert.h" 20 #include "bcc/BCCContext.h" 21 #include "bcc/Source.h" 22 #include "bcc/Support/Log.h" 23 #include "bcinfo/MetadataExtractor.h" 24 #include "llvm/ADT/StringExtras.h" 25 #include "llvm/IR/DataLayout.h" 26 #include "llvm/IR/IRBuilder.h" 27 #include "llvm/IR/Module.h" 28 #include "llvm/Support/raw_ostream.h" 29 30 using llvm::Function; 31 using llvm::Module; 32 33 using std::string; 34 35 namespace bcc { 36 37 namespace { 38 39 const Function* getInvokeFunction(const Source& source, const int slot, 40 Module* newModule) { 41 42 bcinfo::MetadataExtractor &metadata = *source.getMetadata(); 43 const char* functionName = metadata.getExportFuncNameList()[slot]; 44 Function* func = newModule->getFunction(functionName); 45 // Materialize the function so that later the caller can inspect its argument 46 // and return types. 47 newModule->materialize(func); 48 return func; 49 } 50 51 const Function* 52 getFunction(Module* mergedModule, const Source* source, const int slot, 53 uint32_t* signature) { 54 55 bcinfo::MetadataExtractor &metadata = *source->getMetadata(); 56 const char* functionName = metadata.getExportForEachNameList()[slot]; 57 if (functionName == nullptr || !functionName[0]) { 58 ALOGE("Kernel fusion (module %s slot %d): failed to find kernel function", 59 source->getName().c_str(), slot); 60 return nullptr; 61 } 62 63 if (metadata.getExportForEachInputCountList()[slot] > 1) { 64 ALOGE("Kernel fusion (module %s function %s): cannot handle multiple inputs", 65 source->getName().c_str(), functionName); 66 return nullptr; 67 } 68 69 if (signature != nullptr) { 70 *signature = metadata.getExportForEachSignatureList()[slot]; 71 } 72 73 const Function* function = mergedModule->getFunction(functionName); 74 75 return function; 76 } 77 78 // The whitelist of supported signature bits. Context or user data arguments are 79 // not currently supported in kernel fusion. To support them or any new kinds of 80 // arguments in the future, it requires not only listing the signature bits here, 81 // but also implementing additional necessary fusion logic in the getFusedFuncSig(), 82 // getFusedFuncType(), and fuseKernels() functions below. 83 constexpr uint32_t ExpectedSignatureBits = 84 bcinfo::MD_SIG_In | 85 bcinfo::MD_SIG_Out | 86 bcinfo::MD_SIG_X | 87 bcinfo::MD_SIG_Y | 88 bcinfo::MD_SIG_Z | 89 bcinfo::MD_SIG_Kernel; 90 91 int getFusedFuncSig(const std::vector<Source*>& sources, 92 const std::vector<int>& slots, 93 uint32_t* retSig) { 94 *retSig = 0; 95 uint32_t firstSignature = 0; 96 uint32_t signature = 0; 97 auto slotIter = slots.begin(); 98 for (const Source* source : sources) { 99 const int slot = *slotIter++; 100 bcinfo::MetadataExtractor &metadata = *source->getMetadata(); 101 102 if (metadata.getExportForEachInputCountList()[slot] > 1) { 103 ALOGE("Kernel fusion (module %s slot %d): cannot handle multiple inputs", 104 source->getName().c_str(), slot); 105 return -1; 106 } 107 108 signature = metadata.getExportForEachSignatureList()[slot]; 109 if (signature & ~ExpectedSignatureBits) { 110 ALOGE("Kernel fusion (module %s slot %d): Unexpected signature %x", 111 source->getName().c_str(), slot, signature); 112 return -1; 113 } 114 115 if (firstSignature == 0) { 116 firstSignature = signature; 117 } 118 119 *retSig |= signature; 120 } 121 122 if (!bcinfo::MetadataExtractor::hasForEachSignatureIn(firstSignature)) { 123 *retSig &= ~bcinfo::MD_SIG_In; 124 } 125 126 if (!bcinfo::MetadataExtractor::hasForEachSignatureOut(signature)) { 127 *retSig &= ~bcinfo::MD_SIG_Out; 128 } 129 130 return 0; 131 } 132 133 llvm::FunctionType* getFusedFuncType(bcc::BCCContext& Context, 134 const std::vector<Source*>& sources, 135 const std::vector<int>& slots, 136 Module* M, 137 uint32_t* signature) { 138 int error = getFusedFuncSig(sources, slots, signature); 139 140 if (error < 0) { 141 return nullptr; 142 } 143 144 const Function* firstF = getFunction(M, sources.front(), slots.front(), nullptr); 145 146 bccAssert (firstF != nullptr); 147 148 llvm::SmallVector<llvm::Type*, 8> ArgTys; 149 150 if (bcinfo::MetadataExtractor::hasForEachSignatureIn(*signature)) { 151 ArgTys.push_back(firstF->arg_begin()->getType()); 152 } 153 154 llvm::Type* I32Ty = llvm::IntegerType::get(Context.getLLVMContext(), 32); 155 if (bcinfo::MetadataExtractor::hasForEachSignatureX(*signature)) { 156 ArgTys.push_back(I32Ty); 157 } 158 if (bcinfo::MetadataExtractor::hasForEachSignatureY(*signature)) { 159 ArgTys.push_back(I32Ty); 160 } 161 if (bcinfo::MetadataExtractor::hasForEachSignatureZ(*signature)) { 162 ArgTys.push_back(I32Ty); 163 } 164 165 const Function* lastF = getFunction(M, sources.back(), slots.back(), nullptr); 166 167 bccAssert (lastF != nullptr); 168 169 llvm::Type* retTy = lastF->getReturnType(); 170 171 return llvm::FunctionType::get(retTy, ArgTys, false); 172 } 173 174 } // anonymous namespace 175 176 bool fuseKernels(bcc::BCCContext& Context, 177 const std::vector<Source *>& sources, 178 const std::vector<int>& slots, 179 const std::string& fusedName, 180 Module* mergedModule) { 181 bccAssert(sources.size() == slots.size() && "sources and slots differ in size"); 182 183 uint32_t fusedFunctionSignature; 184 185 llvm::FunctionType* fusedType = 186 getFusedFuncType(Context, sources, slots, mergedModule, &fusedFunctionSignature); 187 188 if (fusedType == nullptr) { 189 return false; 190 } 191 192 Function* fusedKernel = 193 (Function*)(mergedModule->getOrInsertFunction(fusedName, fusedType)); 194 195 llvm::LLVMContext& ctxt = Context.getLLVMContext(); 196 197 llvm::BasicBlock* block = llvm::BasicBlock::Create(ctxt, "entry", fusedKernel); 198 llvm::IRBuilder<> builder(block); 199 200 Function::arg_iterator argIter = fusedKernel->arg_begin(); 201 202 llvm::Value* dataElement = nullptr; 203 if (bcinfo::MetadataExtractor::hasForEachSignatureIn(fusedFunctionSignature)) { 204 dataElement = &*(argIter++); 205 dataElement->setName("DataIn"); 206 } 207 208 llvm::Value* X = nullptr; 209 if (bcinfo::MetadataExtractor::hasForEachSignatureX(fusedFunctionSignature)) { 210 X = &*(argIter++); 211 X->setName("x"); 212 } 213 214 llvm::Value* Y = nullptr; 215 if (bcinfo::MetadataExtractor::hasForEachSignatureY(fusedFunctionSignature)) { 216 Y = &*(argIter++); 217 Y->setName("y"); 218 } 219 220 llvm::Value* Z = nullptr; 221 if (bcinfo::MetadataExtractor::hasForEachSignatureZ(fusedFunctionSignature)) { 222 Z = &*(argIter++); 223 Z->setName("z"); 224 } 225 226 auto slotIter = slots.begin(); 227 for (const Source* source : sources) { 228 int slot = *slotIter; 229 230 uint32_t inputFunctionSignature; 231 const Function* inputFunction = 232 getFunction(mergedModule, source, slot, &inputFunctionSignature); 233 if (inputFunction == nullptr) { 234 // Either failed to find the kernel function, or the function has multiple inputs. 235 return false; 236 } 237 238 // Don't try to fuse a non-kernel 239 if (!bcinfo::MetadataExtractor::hasForEachSignatureKernel(inputFunctionSignature)) { 240 ALOGE("Kernel fusion (module %s function %s): not a kernel", 241 source->getName().c_str(), inputFunction->getName().str().c_str()); 242 return false; 243 } 244 245 std::vector<llvm::Value*> args; 246 247 if (bcinfo::MetadataExtractor::hasForEachSignatureIn(inputFunctionSignature)) { 248 if (dataElement == nullptr) { 249 ALOGE("Kernel fusion (module %s function %s): expected input, but got null", 250 source->getName().c_str(), inputFunction->getName().str().c_str()); 251 return false; 252 } 253 254 const llvm::FunctionType* funcTy = inputFunction->getFunctionType(); 255 llvm::Type* firstArgType = funcTy->getParamType(0); 256 257 if (dataElement->getType() != firstArgType) { 258 std::string msg; 259 llvm::raw_string_ostream rso(msg); 260 rso << "Mismatching argument type, expected "; 261 firstArgType->print(rso); 262 rso << ", received "; 263 dataElement->getType()->print(rso); 264 ALOGE("Kernel fusion (module %s function %s): %s", source->getName().c_str(), 265 inputFunction->getName().str().c_str(), rso.str().c_str()); 266 return false; 267 } 268 269 args.push_back(dataElement); 270 } else { 271 // Only the first kernel in a batch is allowed to have no input 272 if (slotIter != slots.begin()) { 273 ALOGE("Kernel fusion (module %s function %s): function not first in batch takes no input", 274 source->getName().c_str(), inputFunction->getName().str().c_str()); 275 return false; 276 } 277 } 278 279 if (bcinfo::MetadataExtractor::hasForEachSignatureX(inputFunctionSignature)) { 280 args.push_back(X); 281 } 282 283 if (bcinfo::MetadataExtractor::hasForEachSignatureY(inputFunctionSignature)) { 284 args.push_back(Y); 285 } 286 287 if (bcinfo::MetadataExtractor::hasForEachSignatureZ(inputFunctionSignature)) { 288 args.push_back(Z); 289 } 290 291 dataElement = builder.CreateCall((llvm::Value*)inputFunction, args); 292 293 slotIter++; 294 } 295 296 if (fusedKernel->getReturnType()->isVoidTy()) { 297 builder.CreateRetVoid(); 298 } else { 299 builder.CreateRet(dataElement); 300 } 301 302 llvm::NamedMDNode* ExportForEachNameMD = 303 mergedModule->getOrInsertNamedMetadata("#rs_export_foreach_name"); 304 305 llvm::MDString* nameMDStr = llvm::MDString::get(ctxt, fusedName); 306 llvm::MDNode* nameMDNode = llvm::MDNode::get(ctxt, nameMDStr); 307 ExportForEachNameMD->addOperand(nameMDNode); 308 309 llvm::NamedMDNode* ExportForEachMD = 310 mergedModule->getOrInsertNamedMetadata("#rs_export_foreach"); 311 llvm::MDString* sigMDStr = llvm::MDString::get(ctxt, 312 llvm::utostr_32(fusedFunctionSignature)); 313 llvm::MDNode* sigMDNode = llvm::MDNode::get(ctxt, sigMDStr); 314 ExportForEachMD->addOperand(sigMDNode); 315 316 return true; 317 } 318 319 bool renameInvoke(BCCContext& Context, const Source* source, const int slot, 320 const std::string& newName, Module* module) { 321 const llvm::Function* F = getInvokeFunction(*source, slot, module); 322 std::vector<llvm::Type*> params; 323 for (auto I = F->arg_begin(), E = F->arg_end(); I != E; ++I) { 324 params.push_back(I->getType()); 325 } 326 llvm::Type* returnTy = F->getReturnType(); 327 328 llvm::FunctionType* batchFuncTy = 329 llvm::FunctionType::get(returnTy, params, false); 330 331 llvm::Function* newF = 332 llvm::Function::Create(batchFuncTy, 333 llvm::GlobalValue::ExternalLinkage, newName, 334 module); 335 336 llvm::BasicBlock* block = llvm::BasicBlock::Create(Context.getLLVMContext(), 337 "entry", newF); 338 llvm::IRBuilder<> builder(block); 339 340 llvm::Function::arg_iterator argIter = newF->arg_begin(); 341 llvm::Value* arg1 = &*(argIter++); 342 builder.CreateCall((llvm::Value*)F, arg1); 343 344 builder.CreateRetVoid(); 345 346 llvm::NamedMDNode* ExportFuncNameMD = 347 module->getOrInsertNamedMetadata("#rs_export_func"); 348 llvm::MDString* strMD = llvm::MDString::get(module->getContext(), newName); 349 llvm::MDNode* nodeMD = llvm::MDNode::get(module->getContext(), strMD); 350 ExportFuncNameMD->addOperand(nodeMD); 351 352 return true; 353 } 354 355 } // namespace bcc 356