Home | History | Annotate | Download | only in SPIRV
      1 //===- SPIRVLowerBool.cpp  Lower instructions with bool operands ----------===//
      2 //
      3 //                     The LLVM/SPIRV Translator
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 // Copyright (c) 2014 Advanced Micro Devices, Inc. All rights reserved.
      9 //
     10 // Permission is hereby granted, free of charge, to any person obtaining a
     11 // copy of this software and associated documentation files (the "Software"),
     12 // to deal with the Software without restriction, including without limitation
     13 // the rights to use, copy, modify, merge, publish, distribute, sublicense,
     14 // and/or sell copies of the Software, and to permit persons to whom the
     15 // Software is furnished to do so, subject to the following conditions:
     16 //
     17 // Redistributions of source code must retain the above copyright notice,
     18 // this list of conditions and the following disclaimers.
     19 // Redistributions in binary form must reproduce the above copyright notice,
     20 // this list of conditions and the following disclaimers in the documentation
     21 // and/or other materials provided with the distribution.
     22 // Neither the names of Advanced Micro Devices, Inc., nor the names of its
     23 // contributors may be used to endorse or promote products derived from this
     24 // Software without specific prior written permission.
     25 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
     26 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
     27 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
     28 // CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
     29 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
     30 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS WITH
     31 // THE SOFTWARE.
     32 //
     33 //===----------------------------------------------------------------------===//
     34 //
     35 // This file implements lowering instructions with bool operands.
     36 //
     37 //===----------------------------------------------------------------------===//
     38 #define DEBUG_TYPE "spvbool"
     39 
     40 #include "SPIRVInternal.h"
     41 #include "llvm/IR/InstVisitor.h"
     42 #include "llvm/IR/Instructions.h"
     43 #include "llvm/IR/IRBuilder.h"
     44 #include "llvm/IR/Verifier.h"
     45 #include "llvm/Pass.h"
     46 #include "llvm/PassSupport.h"
     47 #include "llvm/Support/CommandLine.h"
     48 #include "llvm/Support/Debug.h"
     49 #include "llvm/Support/raw_ostream.h"
     50 
     51 using namespace llvm;
     52 using namespace SPIRV;
     53 
     54 namespace SPIRV {
     55 cl::opt<bool> SPIRVLowerBoolValidate("spvbool-validate",
     56     cl::desc("Validate module after lowering boolean instructions for SPIR-V"));
     57 
     58 class SPIRVLowerBool: public ModulePass,
     59   public InstVisitor<SPIRVLowerBool> {
     60 public:
     61   SPIRVLowerBool():ModulePass(ID), Context(nullptr) {
     62     initializeSPIRVLowerBoolPass(*PassRegistry::getPassRegistry());
     63   }
     64   void replace(Instruction *I, Instruction *NewI) {
     65     NewI->takeName(I);
     66     I->replaceAllUsesWith(NewI);
     67     I->dropAllReferences();
     68     I->eraseFromParent();
     69   }
     70   bool isBoolType(Type *Ty) {
     71     if (Ty->isIntegerTy(1))
     72       return true;
     73     if (auto VT = dyn_cast<VectorType>(Ty))
     74       return isBoolType(VT->getElementType());
     75     return false;
     76   }
     77   virtual void visitTruncInst(TruncInst &I) {
     78     if (isBoolType(I.getType())) {
     79       auto Op = I.getOperand(0);
     80       auto Zero = getScalarOrVectorConstantInt(Op->getType(), 0, false);
     81       auto Cmp = new ICmpInst(&I, CmpInst::ICMP_NE, Op, Zero);
     82       replace(&I, Cmp);
     83     }
     84   }
     85   virtual void visitZExtInst(ZExtInst &I) {
     86     auto Op = I.getOperand(0);
     87     if (isBoolType(Op->getType())) {
     88       auto Ty = I.getType();
     89       auto Zero = getScalarOrVectorConstantInt(Ty, 0, false);
     90       auto One = getScalarOrVectorConstantInt(Ty, 1, false);
     91       auto Sel = SelectInst::Create(Op, One, Zero, "", &I);
     92       replace(&I, Sel);
     93     }
     94   }
     95   virtual void visitSExtInst(SExtInst &I) {
     96     auto Op = I.getOperand(0);
     97     if (isBoolType(Op->getType())) {
     98       auto Ty = I.getType();
     99       auto Zero = getScalarOrVectorConstantInt(Ty, 0, false);
    100       auto One = getScalarOrVectorConstantInt(Ty, ~0, false);
    101       auto Sel = SelectInst::Create(Op, One, Zero, "", &I);
    102       replace(&I, Sel);
    103     }
    104   }
    105   virtual bool runOnModule(Module &M) {
    106     Context = &M.getContext();
    107     visit(M);
    108 
    109     if (SPIRVLowerBoolValidate) {
    110       DEBUG(dbgs() << "After SPIRVLowerBool:\n" << M);
    111       std::string Err;
    112       raw_string_ostream ErrorOS(Err);
    113       if (verifyModule(M, &ErrorOS)){
    114         Err = std::string("Fails to verify module: ") + Err;
    115         report_fatal_error(Err.c_str(), false);
    116       }
    117     }
    118     return true;
    119   }
    120 
    121   static char ID;
    122 private:
    123   LLVMContext *Context;
    124 };
    125 
    126 char SPIRVLowerBool::ID = 0;
    127 }
    128 
    129 INITIALIZE_PASS(SPIRVLowerBool, "spvbool",
    130     "Lower instructions with bool operands", false, false)
    131 
    132 ModulePass *llvm::createSPIRVLowerBool() {
    133   return new SPIRVLowerBool();
    134 }
    135