//===- SPIRVLowerBool.cpp - Lower instructions with bool operands ---------===//
//
//                     The LLVM/SPIRV Translator
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
// Copyright (c) 2014 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal with the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimers.
// Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimers in the documentation
// and/or other materials provided with the distribution.
// Neither the names of Advanced Micro Devices, Inc., nor the names of its
// contributors may be used to endorse or promote products derived from this
// Software without specific prior written permission.
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS WITH
// THE SOFTWARE.
32 // 33 //===----------------------------------------------------------------------===// 34 // 35 // This file implements lowering instructions with bool operands. 36 // 37 //===----------------------------------------------------------------------===// 38 #define DEBUG_TYPE "spvbool" 39 40 #include "SPIRVInternal.h" 41 #include "llvm/IR/InstVisitor.h" 42 #include "llvm/IR/Instructions.h" 43 #include "llvm/IR/IRBuilder.h" 44 #include "llvm/IR/Verifier.h" 45 #include "llvm/Pass.h" 46 #include "llvm/PassSupport.h" 47 #include "llvm/Support/CommandLine.h" 48 #include "llvm/Support/Debug.h" 49 #include "llvm/Support/raw_ostream.h" 50 51 using namespace llvm; 52 using namespace SPIRV; 53 54 namespace SPIRV { 55 cl::opt<bool> SPIRVLowerBoolValidate("spvbool-validate", 56 cl::desc("Validate module after lowering boolean instructions for SPIR-V")); 57 58 class SPIRVLowerBool: public ModulePass, 59 public InstVisitor<SPIRVLowerBool> { 60 public: 61 SPIRVLowerBool():ModulePass(ID), Context(nullptr) { 62 initializeSPIRVLowerBoolPass(*PassRegistry::getPassRegistry()); 63 } 64 void replace(Instruction *I, Instruction *NewI) { 65 NewI->takeName(I); 66 I->replaceAllUsesWith(NewI); 67 I->dropAllReferences(); 68 I->eraseFromParent(); 69 } 70 bool isBoolType(Type *Ty) { 71 if (Ty->isIntegerTy(1)) 72 return true; 73 if (auto VT = dyn_cast<VectorType>(Ty)) 74 return isBoolType(VT->getElementType()); 75 return false; 76 } 77 virtual void visitTruncInst(TruncInst &I) { 78 if (isBoolType(I.getType())) { 79 auto Op = I.getOperand(0); 80 auto Zero = getScalarOrVectorConstantInt(Op->getType(), 0, false); 81 auto Cmp = new ICmpInst(&I, CmpInst::ICMP_NE, Op, Zero); 82 replace(&I, Cmp); 83 } 84 } 85 virtual void visitZExtInst(ZExtInst &I) { 86 auto Op = I.getOperand(0); 87 if (isBoolType(Op->getType())) { 88 auto Ty = I.getType(); 89 auto Zero = getScalarOrVectorConstantInt(Ty, 0, false); 90 auto One = getScalarOrVectorConstantInt(Ty, 1, false); 91 auto Sel = SelectInst::Create(Op, One, Zero, 
"", &I); 92 replace(&I, Sel); 93 } 94 } 95 virtual void visitSExtInst(SExtInst &I) { 96 auto Op = I.getOperand(0); 97 if (isBoolType(Op->getType())) { 98 auto Ty = I.getType(); 99 auto Zero = getScalarOrVectorConstantInt(Ty, 0, false); 100 auto One = getScalarOrVectorConstantInt(Ty, ~0, false); 101 auto Sel = SelectInst::Create(Op, One, Zero, "", &I); 102 replace(&I, Sel); 103 } 104 } 105 virtual bool runOnModule(Module &M) { 106 Context = &M.getContext(); 107 visit(M); 108 109 if (SPIRVLowerBoolValidate) { 110 DEBUG(dbgs() << "After SPIRVLowerBool:\n" << M); 111 std::string Err; 112 raw_string_ostream ErrorOS(Err); 113 if (verifyModule(M, &ErrorOS)){ 114 Err = std::string("Fails to verify module: ") + Err; 115 report_fatal_error(Err.c_str(), false); 116 } 117 } 118 return true; 119 } 120 121 static char ID; 122 private: 123 LLVMContext *Context; 124 }; 125 126 char SPIRVLowerBool::ID = 0; 127 } 128 129 INITIALIZE_PASS(SPIRVLowerBool, "spvbool", 130 "Lower instructions with bool operands", false, false) 131 132 ModulePass *llvm::createSPIRVLowerBool() { 133 return new SPIRVLowerBool(); 134 } 135