Home | History | Annotate | Download | only in opt
      1 // Copyright (c) 2016 Google Inc.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
     15 #include "fold_spec_constant_op_and_composite_pass.h"
     16 
     17 #include <algorithm>
     18 #include <initializer_list>
     19 #include <tuple>
     20 
     21 #include "constants.h"
     22 #include "make_unique.h"
     23 
     24 namespace spvtools {
     25 namespace opt {
     26 
     27 namespace {
     28 // Returns the single-word result from performing the given unary operation on
     29 // the operand value which is passed in as a 32-bit word.
     30 uint32_t UnaryOperate(SpvOp opcode, uint32_t operand) {
     31   switch (opcode) {
     32     // Arthimetics
     33     case SpvOp::SpvOpSNegate:
     34       return -static_cast<int32_t>(operand);
     35     case SpvOp::SpvOpNot:
     36       return ~operand;
     37     case SpvOp::SpvOpLogicalNot:
     38       return !static_cast<bool>(operand);
     39     default:
     40       assert(false &&
     41              "Unsupported unary operation for OpSpecConstantOp instruction");
     42       return 0u;
     43   }
     44 }
     45 
     46 // Returns the single-word result from performing the given binary operation on
     47 // the operand values which are passed in as two 32-bit word.
     48 uint32_t BinaryOperate(SpvOp opcode, uint32_t a, uint32_t b) {
     49   switch (opcode) {
     50     // Arthimetics
     51     case SpvOp::SpvOpIAdd:
     52       return a + b;
     53     case SpvOp::SpvOpISub:
     54       return a - b;
     55     case SpvOp::SpvOpIMul:
     56       return a * b;
     57     case SpvOp::SpvOpUDiv:
     58       assert(b != 0);
     59       return a / b;
     60     case SpvOp::SpvOpSDiv:
     61       assert(b != 0u);
     62       return (static_cast<int32_t>(a)) / (static_cast<int32_t>(b));
     63     case SpvOp::SpvOpSRem: {
     64       // The sign of non-zero result comes from the first operand: a. This is
     65       // guaranteed by C++11 rules for integer division operator. The division
     66       // result is rounded toward zero, so the result of '%' has the sign of
     67       // the first operand.
     68       assert(b != 0u);
     69       return static_cast<int32_t>(a) % static_cast<int32_t>(b);
     70     }
     71     case SpvOp::SpvOpSMod: {
     72       // The sign of non-zero result comes from the second operand: b
     73       assert(b != 0u);
     74       int32_t rem = BinaryOperate(SpvOp::SpvOpSRem, a, b);
     75       int32_t b_prim = static_cast<int32_t>(b);
     76       return (rem + b_prim) % b_prim;
     77     }
     78     case SpvOp::SpvOpUMod:
     79       assert(b != 0u);
     80       return (a % b);
     81 
     82     // Shifting
     83     case SpvOp::SpvOpShiftRightLogical: {
     84       return a >> b;
     85     }
     86     case SpvOp::SpvOpShiftRightArithmetic:
     87       return (static_cast<int32_t>(a)) >> b;
     88     case SpvOp::SpvOpShiftLeftLogical:
     89       return a << b;
     90 
     91     // Bitwise operations
     92     case SpvOp::SpvOpBitwiseOr:
     93       return a | b;
     94     case SpvOp::SpvOpBitwiseAnd:
     95       return a & b;
     96     case SpvOp::SpvOpBitwiseXor:
     97       return a ^ b;
     98 
     99     // Logical
    100     case SpvOp::SpvOpLogicalEqual:
    101       return (static_cast<bool>(a)) == (static_cast<bool>(b));
    102     case SpvOp::SpvOpLogicalNotEqual:
    103       return (static_cast<bool>(a)) != (static_cast<bool>(b));
    104     case SpvOp::SpvOpLogicalOr:
    105       return (static_cast<bool>(a)) || (static_cast<bool>(b));
    106     case SpvOp::SpvOpLogicalAnd:
    107       return (static_cast<bool>(a)) && (static_cast<bool>(b));
    108 
    109     // Comparison
    110     case SpvOp::SpvOpIEqual:
    111       return a == b;
    112     case SpvOp::SpvOpINotEqual:
    113       return a != b;
    114     case SpvOp::SpvOpULessThan:
    115       return a < b;
    116     case SpvOp::SpvOpSLessThan:
    117       return (static_cast<int32_t>(a)) < (static_cast<int32_t>(b));
    118     case SpvOp::SpvOpUGreaterThan:
    119       return a > b;
    120     case SpvOp::SpvOpSGreaterThan:
    121       return (static_cast<int32_t>(a)) > (static_cast<int32_t>(b));
    122     case SpvOp::SpvOpULessThanEqual:
    123       return a <= b;
    124     case SpvOp::SpvOpSLessThanEqual:
    125       return (static_cast<int32_t>(a)) <= (static_cast<int32_t>(b));
    126     case SpvOp::SpvOpUGreaterThanEqual:
    127       return a >= b;
    128     case SpvOp::SpvOpSGreaterThanEqual:
    129       return (static_cast<int32_t>(a)) >= (static_cast<int32_t>(b));
    130     default:
    131       assert(false &&
    132              "Unsupported binary operation for OpSpecConstantOp instruction");
    133       return 0u;
    134   }
    135 }
    136 
    137 // Returns the single-word result from performing the given ternary operation
    138 // on the operand values which are passed in as three 32-bit word.
    139 uint32_t TernaryOperate(SpvOp opcode, uint32_t a, uint32_t b, uint32_t c) {
    140   switch (opcode) {
    141     case SpvOp::SpvOpSelect:
    142       return (static_cast<bool>(a)) ? b : c;
    143     default:
    144       assert(false &&
    145              "Unsupported ternary operation for OpSpecConstantOp instruction");
    146       return 0u;
    147   }
    148 }
    149 
    150 // Returns the single-word result from performing the given operation on the
    151 // operand words. This only works with 32-bit operations and uses boolean
    152 // convention that 0u is false, and anything else is boolean true.
    153 // TODO(qining): Support operands other than 32-bit wide.
    154 uint32_t OperateWords(SpvOp opcode,
    155                       const std::vector<uint32_t>& operand_words) {
    156   switch (operand_words.size()) {
    157     case 1:
    158       return UnaryOperate(opcode, operand_words.front());
    159     case 2:
    160       return BinaryOperate(opcode, operand_words.front(), operand_words.back());
    161     case 3:
    162       return TernaryOperate(opcode, operand_words[0], operand_words[1],
    163                             operand_words[2]);
    164     default:
    165       assert(false && "Invalid number of operands");
    166       return 0;
    167   }
    168 }
    169 
    170 // Returns the result of performing an operation on scalar constant operands.
    171 // This function extracts the operand values as 32 bit words and returns the
    172 // result in 32 bit word. Scalar constants with longer than 32-bit width are
    173 // not accepted in this function.
    174 uint32_t OperateScalars(SpvOp opcode,
    175                         const std::vector<analysis::Constant*>& operands) {
    176   std::vector<uint32_t> operand_values_in_raw_words;
    177   for (analysis::Constant* operand : operands) {
    178     if (analysis::ScalarConstant* scalar = operand->AsScalarConstant()) {
    179       const auto& scalar_words = scalar->words();
    180       assert(scalar_words.size() == 1 &&
    181              "Scalar constants with longer than 32-bit width are not allowed "
    182              "in OperateScalars()");
    183       operand_values_in_raw_words.push_back(scalar_words.front());
    184     } else if (operand->AsNullConstant()) {
    185       operand_values_in_raw_words.push_back(0u);
    186     } else {
    187       assert(false &&
    188              "OperateScalars() only accepts ScalarConst or NullConst type of "
    189              "constant");
    190     }
    191   }
    192   return OperateWords(opcode, operand_values_in_raw_words);
    193 }
    194 
    195 // Returns the result of performing an operation over constant vectors. This
    196 // function iterates through the given vector type constant operands and
    197 // calculates the result for each element of the result vector to return.
    198 // Vectors with longer than 32-bit scalar components are not accepted in this
    199 // function.
    200 std::vector<uint32_t> OperateVectors(
    201     SpvOp opcode, uint32_t num_dims,
    202     const std::vector<analysis::Constant*>& operands) {
    203   std::vector<uint32_t> result;
    204   for (uint32_t d = 0; d < num_dims; d++) {
    205     std::vector<uint32_t> operand_values_for_one_dimension;
    206     for (analysis::Constant* operand : operands) {
    207       if (analysis::VectorConstant* vector_operand =
    208               operand->AsVectorConstant()) {
    209         // Extract the raw value of the scalar component constants
    210         // in 32-bit words here. The reason of not using OperateScalars() here
    211         // is that we do not create temporary null constants as components
    212         // when the vector operand is a NullConstant because Constant creation
    213         // may need extra checks for the validity and that is not manageed in
    214         // here.
    215         if (const analysis::ScalarConstant* scalar_component =
    216                 vector_operand->GetComponents().at(d)->AsScalarConstant()) {
    217           const auto& scalar_words = scalar_component->words();
    218           assert(
    219               scalar_words.size() == 1 &&
    220               "Vector components with longer than 32-bit width are not allowed "
    221               "in OperateVectors()");
    222           operand_values_for_one_dimension.push_back(scalar_words.front());
    223         } else if (operand->AsNullConstant()) {
    224           operand_values_for_one_dimension.push_back(0u);
    225         } else {
    226           assert(false &&
    227                  "VectorConst should only has ScalarConst or NullConst as "
    228                  "components");
    229         }
    230       } else if (operand->AsNullConstant()) {
    231         operand_values_for_one_dimension.push_back(0u);
    232       } else {
    233         assert(false &&
    234                "OperateVectors() only accepts VectorConst or NullConst type of "
    235                "constant");
    236       }
    237     }
    238     result.push_back(OperateWords(opcode, operand_values_for_one_dimension));
    239   }
    240   return result;
    241 }
    242 }  // anonymous namespace
    243 
    244 FoldSpecConstantOpAndCompositePass::FoldSpecConstantOpAndCompositePass()
    245     : max_id_(0),
    246       module_(nullptr),
    247       def_use_mgr_(nullptr),
    248       type_mgr_(nullptr),
    249       id_to_const_val_() {}
    250 
    251 Pass::Status FoldSpecConstantOpAndCompositePass::Process(ir::Module* module) {
    252   Initialize(module);
    253   return ProcessImpl(module);
    254 }
    255 
    256 void FoldSpecConstantOpAndCompositePass::Initialize(ir::Module* module) {
    257   type_mgr_.reset(new analysis::TypeManager(consumer(), *module));
    258   def_use_mgr_.reset(new analysis::DefUseManager(consumer(), module));
    259   for (const auto& id_def : def_use_mgr_->id_to_defs()) {
    260     max_id_ = std::max(max_id_, id_def.first);
    261   }
    262   module_ = module;
    263 };
    264 
    265 Pass::Status FoldSpecConstantOpAndCompositePass::ProcessImpl(
    266     ir::Module* module) {
    267   bool modified = false;
    268   // Traverse through all the constant defining instructions. For Normal
    269   // Constants whose values are determined and do not depend on OpUndef
    270   // instructions, records their values in two internal maps: id_to_const_val_
    271   // and const_val_to_id_ so that we can use them to infer the value of Spec
    272   // Constants later.
    273   // For Spec Constants defined with OpSpecConstantComposite instructions, if
    274   // all of their components are Normal Constants, they will be turned into
    275   // Normal Constants too. For Spec Constants defined with OpSpecConstantOp
    276   // instructions, we check if they only depends on Normal Constants and fold
    277   // them when possible. The two maps for Normal Constants: id_to_const_val_
    278   // and const_val_to_id_ will be updated along the traversal so that the new
    279   // Normal Constants generated from folding can be used to fold following Spec
    280   // Constants.
    281   // This algorithm depends on the SSA property of SPIR-V when
    282   // defining constants. The dependent constants must be defined before the
    283   // dependee constants. So a dependent Spec Constant must be defined and
    284   // will be processed before its dependee Spec Constant. When we encounter
    285   // the dependee Spec Constants, all its dependent constants must have been
    286   // processed and all its dependent Spec Constants should have been folded if
    287   // possible.
    288   for (ir::Module::inst_iterator inst_iter = module->types_values_begin();
    289        // Need to re-evaluate the end iterator since we may modify the list of
    290        // instructions in this section of the module as the process goes.
    291        inst_iter != module->types_values_end(); ++inst_iter) {
    292     ir::Instruction* inst = &*inst_iter;
    293     // Collect constant values of normal constants and process the
    294     // OpSpecConstantOp and OpSpecConstantComposite instructions if possible.
    295     // The constant values will be stored in analysis::Constant instances.
    296     // OpConstantSampler instruction is not collected here because it cannot be
    297     // used in OpSpecConstant{Composite|Op} instructions.
    298     // TODO(qining): If the constant or its type has decoration, we may need
    299     // to skip it.
    300     if (GetType(inst) && !GetType(inst)->decoration_empty()) continue;
    301     switch (SpvOp opcode = inst->opcode()) {
    302       // Records the values of Normal Constants.
    303       case SpvOp::SpvOpConstantTrue:
    304       case SpvOp::SpvOpConstantFalse:
    305       case SpvOp::SpvOpConstant:
    306       case SpvOp::SpvOpConstantNull:
    307       case SpvOp::SpvOpConstantComposite:
    308       case SpvOp::SpvOpSpecConstantComposite: {
    309         // A Constant instance will be created if the given instruction is a
    310         // Normal Constant whose value(s) are fixed. Note that for a composite
    311         // Spec Constant defined with OpSpecConstantComposite instruction, if
    312         // all of its components are Normal Constants already, the Spec
    313         // Constant will be turned in to a Normal Constant. In that case, a
    314         // Constant instance should also be created successfully and recorded
    315         // in the id_to_const_val_ and const_val_to_id_ mapps.
    316         if (auto const_value = CreateConstFromInst(inst)) {
    317           // Need to replace the OpSpecConstantComposite instruction with a
    318           // corresponding OpConstantComposite instruction.
    319           if (opcode == SpvOp::SpvOpSpecConstantComposite) {
    320             inst->SetOpcode(SpvOp::SpvOpConstantComposite);
    321             modified = true;
    322           }
    323           const_val_to_id_[const_value.get()] = inst->result_id();
    324           id_to_const_val_[inst->result_id()] = std::move(const_value);
    325         }
    326         break;
    327       }
    328       // For a Spec Constants defined with OpSpecConstantOp instruction, check
    329       // if it only depends on Normal Constants. If so, the Spec Constant will
    330       // be folded. The original Spec Constant defining instruction will be
    331       // replaced by Normal Constant defining instructions, and the new Normal
    332       // Constants will be added to id_to_const_val_ and const_val_to_id_ so
    333       // that we can use the new Normal Constants when folding following Spec
    334       // Constants.
    335       case SpvOp::SpvOpSpecConstantOp:
    336         modified |= ProcessOpSpecConstantOp(&inst_iter);
    337         break;
    338       default:
    339         break;
    340     }
    341   }
    342   return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
    343 }
    344 
    345 bool FoldSpecConstantOpAndCompositePass::ProcessOpSpecConstantOp(
    346     ir::Module::inst_iterator* pos) {
    347   ir::Instruction* inst = &**pos;
    348   ir::Instruction* folded_inst = nullptr;
    349   assert(inst->GetInOperand(0).type ==
    350              SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER &&
    351          "The first in-operand of OpSpecContantOp instruction must be of "
    352          "SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER type");
    353 
    354   switch (static_cast<SpvOp>(inst->GetSingleWordInOperand(0))) {
    355     case SpvOp::SpvOpCompositeExtract:
    356       folded_inst = DoCompositeExtract(pos);
    357       break;
    358     case SpvOp::SpvOpVectorShuffle:
    359       folded_inst = DoVectorShuffle(pos);
    360       break;
    361 
    362     case SpvOp::SpvOpCompositeInsert:
    363       // Current Glslang does not generate code with OpSpecConstantOp
    364       // CompositeInsert instruction, so this is not implmented so far.
    365       // TODO(qining): Implement CompositeInsert case.
    366       return false;
    367 
    368     default:
    369       // Component-wise operations.
    370       folded_inst = DoComponentWiseOperation(pos);
    371       break;
    372   }
    373   if (!folded_inst) return false;
    374 
    375   // Replace the original constant with the new folded constant, kill the
    376   // original constant.
    377   uint32_t new_id = folded_inst->result_id();
    378   uint32_t old_id = inst->result_id();
    379   def_use_mgr_->ReplaceAllUsesWith(old_id, new_id);
    380   def_use_mgr_->KillDef(old_id);
    381   return true;
    382 }
    383 
    384 ir::Instruction* FoldSpecConstantOpAndCompositePass::DoCompositeExtract(
    385     ir::Module::inst_iterator* pos) {
    386   ir::Instruction* inst = &**pos;
    387   assert(inst->NumInOperands() - 1 >= 2 &&
    388          "OpSpecConstantOp CompositeExtract requires at least two non-type "
    389          "non-opcode operands.");
    390   assert(inst->GetInOperand(1).type == SPV_OPERAND_TYPE_ID &&
    391          "The vector operand must have a SPV_OPERAND_TYPE_ID type");
    392   assert(
    393       inst->GetInOperand(2).type == SPV_OPERAND_TYPE_LITERAL_INTEGER &&
    394       "The literal operand must have a SPV_OPERAND_TYPE_LITERAL_INTEGER type");
    395 
    396   // Note that for OpSpecConstantOp, the second in-operand is the first id
    397   // operand. The first in-operand is the spec opcode.
    398   analysis::Constant* first_operand_const =
    399       FindRecordedConst(inst->GetSingleWordInOperand(1));
    400   if (!first_operand_const) return nullptr;
    401 
    402   const analysis::Constant* current_const = first_operand_const;
    403   for (uint32_t i = 2; i < inst->NumInOperands(); i++) {
    404     uint32_t literal = inst->GetSingleWordInOperand(i);
    405     if (const analysis::CompositeConstant* composite_const =
    406             current_const->AsCompositeConstant()) {
    407       // Case 1: current constant is a non-null composite type constant.
    408       assert(literal < composite_const->GetComponents().size() &&
    409              "Literal index out of bound of the composite constant");
    410       current_const = composite_const->GetComponents().at(literal);
    411     } else if (current_const->AsNullConstant()) {
    412       // Case 2: current constant is a constant created with OpConstantNull.
    413       // Because components of a NullConstant are always NullConstants, we can
    414       // return early with a NullConstant in the result type.
    415       return BuildInstructionAndAddToModule(CreateConst(GetType(inst), {}),
    416                                             pos);
    417     } else {
    418       // Dereferencing a non-composite constant. Invalid case.
    419       return nullptr;
    420     }
    421   }
    422   return BuildInstructionAndAddToModule(current_const->Copy(), pos);
    423 }
    424 
    425 ir::Instruction* FoldSpecConstantOpAndCompositePass::DoVectorShuffle(
    426     ir::Module::inst_iterator* pos) {
    427   ir::Instruction* inst = &**pos;
    428   analysis::Vector* result_vec_type = GetType(inst)->AsVector();
    429   assert(inst->NumInOperands() - 1 > 2 &&
    430          "OpSpecConstantOp DoVectorShuffle instruction requires more than 2 "
    431          "operands (2 vector ids and at least one literal operand");
    432   assert(result_vec_type &&
    433          "The result of VectorShuffle must be of type vector");
    434 
    435   // A temporary null constants that can be used as the components fo the
    436   // result vector. This is needed when any one of the vector operands are null
    437   // constant.
    438   std::unique_ptr<analysis::Constant> null_component_constants;
    439 
    440   // Get a concatenated vector of scalar constants. The vector should be built
    441   // with the components from the first and the second operand of VectorShuffle.
    442   std::vector<const analysis::Constant*> concatenated_components;
    443   // Note that for OpSpecConstantOp, the second in-operand is the first id
    444   // operand. The first in-operand is the spec opcode.
    445   for (uint32_t i : {1, 2}) {
    446     assert(inst->GetInOperand(i).type == SPV_OPERAND_TYPE_ID &&
    447            "The vector operand must have a SPV_OPERAND_TYPE_ID type");
    448     uint32_t operand_id = inst->GetSingleWordInOperand(i);
    449     analysis::Constant* operand_const = FindRecordedConst(operand_id);
    450     if (!operand_const) return nullptr;
    451     const analysis::Type* operand_type = operand_const->type();
    452     assert(operand_type->AsVector() &&
    453            "The first two operand of VectorShuffle must be of vector type");
    454     if (analysis::VectorConstant* vec_const =
    455             operand_const->AsVectorConstant()) {
    456       // case 1: current operand is a non-null vector constant.
    457       concatenated_components.insert(concatenated_components.end(),
    458                                      vec_const->GetComponents().begin(),
    459                                      vec_const->GetComponents().end());
    460     } else if (operand_const->AsNullConstant()) {
    461       // case 2: current operand is a null vector constant. Create a temporary
    462       // null scalar constant as the component.
    463       if (!null_component_constants) {
    464         const analysis::Type* component_type =
    465             operand_type->AsVector()->element_type();
    466         null_component_constants = CreateConst(component_type, {});
    467       }
    468       // Append the null scalar consts to the concatenated components
    469       // vector.
    470       concatenated_components.insert(concatenated_components.end(),
    471                                      operand_type->AsVector()->element_count(),
    472                                      null_component_constants.get());
    473     } else {
    474       // no other valid cases
    475       return nullptr;
    476     }
    477   }
    478   // Create null component constants if there are any. The component constants
    479   // must be added to the module before the dependee composite constants to
    480   // satisfy SSA def-use dominance.
    481   if (null_component_constants) {
    482     BuildInstructionAndAddToModule(std::move(null_component_constants), pos);
    483   }
    484   // Create the new vector constant with the selected components.
    485   std::vector<const analysis::Constant*> selected_components;
    486   for (uint32_t i = 3; i < inst->NumInOperands(); i++) {
    487     assert(inst->GetInOperand(i).type == SPV_OPERAND_TYPE_LITERAL_INTEGER &&
    488            "The literal operand must of type SPV_OPERAND_TYPE_LITERAL_INTEGER");
    489     uint32_t literal = inst->GetSingleWordInOperand(i);
    490     assert(literal < concatenated_components.size() &&
    491            "Literal index out of bound of the concatenated vector");
    492     selected_components.push_back(concatenated_components[literal]);
    493   }
    494   auto new_vec_const = MakeUnique<analysis::VectorConstant>(
    495       result_vec_type, selected_components);
    496   return BuildInstructionAndAddToModule(std::move(new_vec_const), pos);
    497 }
    498 
    499 namespace {
    500 // A helper function to check the type for component wise operations. Returns
    501 // true if the type:
    502 //  1) is bool type;
    503 //  2) is 32-bit int type;
    504 //  3) is vector of bool type;
    505 //  4) is vector of 32-bit integer type.
    506 // Otherwise returns false.
    507 bool IsValidTypeForComponentWiseOperation(const analysis::Type* type) {
    508   if (type->AsBool()) {
    509     return true;
    510   } else if (auto* it = type->AsInteger()) {
    511     if (it->width() == 32) return true;
    512   } else if (auto* vt = type->AsVector()) {
    513     if (vt->element_type()->AsBool())
    514       return true;
    515     else if (auto* vit = vt->element_type()->AsInteger()) {
    516       if (vit->width() == 32) return true;
    517     }
    518   }
    519   return false;
    520 }
    521 }
    522 
    523 ir::Instruction* FoldSpecConstantOpAndCompositePass::DoComponentWiseOperation(
    524     ir::Module::inst_iterator* pos) {
    525   const ir::Instruction* inst = &**pos;
    526   const analysis::Type* result_type = GetType(inst);
    527   SpvOp spec_opcode = static_cast<SpvOp>(inst->GetSingleWordInOperand(0));
    528   // Check and collect operands.
    529   std::vector<analysis::Constant*> operands;
    530 
    531   if (!std::all_of(inst->cbegin(), inst->cend(),
    532                    [&operands, this](const ir::Operand& o) {
    533                      // skip the operands that is not an id.
    534                      if (o.type != spv_operand_type_t::SPV_OPERAND_TYPE_ID)
    535                        return true;
    536                      uint32_t id = o.words.front();
    537                      if (analysis::Constant* c = FindRecordedConst(id)) {
    538                        if (IsValidTypeForComponentWiseOperation(c->type())) {
    539                          operands.push_back(c);
    540                          return true;
    541                        }
    542                      }
    543                      return false;
    544                    }))
    545     return nullptr;
    546 
    547   if (result_type->AsInteger() || result_type->AsBool()) {
    548     // Scalar operation
    549     uint32_t result_val = OperateScalars(spec_opcode, operands);
    550     auto result_const = CreateConst(result_type, {result_val});
    551     return BuildInstructionAndAddToModule(std::move(result_const), pos);
    552   } else if (result_type->AsVector()) {
    553     // Vector operation
    554     const analysis::Type* element_type =
    555         result_type->AsVector()->element_type();
    556     uint32_t num_dims = result_type->AsVector()->element_count();
    557     std::vector<uint32_t> result_vec =
    558         OperateVectors(spec_opcode, num_dims, operands);
    559     std::vector<const analysis::Constant*> result_vector_components;
    560     for (uint32_t r : result_vec) {
    561       if (auto rc = CreateConst(element_type, {r})) {
    562         result_vector_components.push_back(rc.get());
    563         if (!BuildInstructionAndAddToModule(std::move(rc), pos)) {
    564           assert(false &&
    565                  "Failed to build and insert constant declaring instruction "
    566                  "for the given vector component constant");
    567         }
    568       } else {
    569         assert(false && "Failed to create constants with 32-bit word");
    570       }
    571     }
    572     auto new_vec_const = MakeUnique<analysis::VectorConstant>(
    573         result_type->AsVector(), result_vector_components);
    574     return BuildInstructionAndAddToModule(std::move(new_vec_const), pos);
    575   } else {
    576     // Cannot process invalid component wise operation. The result of component
    577     // wise operation must be of integer or bool scalar or vector of
    578     // integer/bool type.
    579     return nullptr;
    580   }
    581 }
    582 
    583 ir::Instruction*
    584 FoldSpecConstantOpAndCompositePass::BuildInstructionAndAddToModule(
    585     std::unique_ptr<analysis::Constant> c, ir::Module::inst_iterator* pos) {
    586   analysis::Constant* new_const = c.get();
    587   uint32_t new_id = ++max_id_;
    588   module_->SetIdBound(new_id + 1);
    589   const_val_to_id_[new_const] = new_id;
    590   id_to_const_val_[new_id] = std::move(c);
    591   auto new_inst = CreateInstruction(new_id, new_const);
    592   if (!new_inst) return nullptr;
    593   auto* new_inst_ptr = new_inst.get();
    594   *pos = pos->InsertBefore(std::move(new_inst));
    595   (*pos)++;
    596   def_use_mgr_->AnalyzeInstDefUse(new_inst_ptr);
    597   return new_inst_ptr;
    598 }
    599 
    600 std::unique_ptr<analysis::Constant>
    601 FoldSpecConstantOpAndCompositePass::CreateConstFromInst(ir::Instruction* inst) {
    602   std::vector<uint32_t> literal_words_or_ids;
    603   std::unique_ptr<analysis::Constant> new_const;
    604   // Collect the constant defining literals or component ids.
    605   for (uint32_t i = 0; i < inst->NumInOperands(); i++) {
    606     literal_words_or_ids.insert(literal_words_or_ids.end(),
    607                                 inst->GetInOperand(i).words.begin(),
    608                                 inst->GetInOperand(i).words.end());
    609   }
    610   switch (inst->opcode()) {
    611     // OpConstant{True|Flase} have the value embedded in the opcode. So they
    612     // are not handled by the for-loop above. Here we add the value explicitly.
    613     case SpvOp::SpvOpConstantTrue:
    614       literal_words_or_ids.push_back(true);
    615       break;
    616     case SpvOp::SpvOpConstantFalse:
    617       literal_words_or_ids.push_back(false);
    618       break;
    619     case SpvOp::SpvOpConstantNull:
    620     case SpvOp::SpvOpConstant:
    621     case SpvOp::SpvOpConstantComposite:
    622     case SpvOp::SpvOpSpecConstantComposite:
    623       break;
    624     default:
    625       return nullptr;
    626   }
    627   return CreateConst(GetType(inst), literal_words_or_ids);
    628 }
    629 
    630 analysis::Constant* FoldSpecConstantOpAndCompositePass::FindRecordedConst(
    631     uint32_t id) {
    632   auto iter = id_to_const_val_.find(id);
    633   if (iter == id_to_const_val_.end()) {
    634     return nullptr;
    635   } else {
    636     return iter->second.get();
    637   }
    638 }
    639 
    640 uint32_t FoldSpecConstantOpAndCompositePass::FindRecordedConst(
    641     const analysis::Constant* c) {
    642   auto iter = const_val_to_id_.find(c);
    643   if (iter == const_val_to_id_.end()) {
    644     return 0;
    645   } else {
    646     return iter->second;
    647   }
    648 }
    649 
    650 std::vector<const analysis::Constant*>
    651 FoldSpecConstantOpAndCompositePass::GetConstsFromIds(
    652     const std::vector<uint32_t>& ids) {
    653   std::vector<const analysis::Constant*> constants;
    654   for (uint32_t id : ids) {
    655     if (analysis::Constant* c = FindRecordedConst(id)) {
    656       constants.push_back(c);
    657     } else {
    658       return {};
    659     }
    660   }
    661   return constants;
    662 }
    663 
    664 std::unique_ptr<analysis::Constant>
    665 FoldSpecConstantOpAndCompositePass::CreateConst(
    666     const analysis::Type* type,
    667     const std::vector<uint32_t>& literal_words_or_ids) {
    668   std::unique_ptr<analysis::Constant> new_const;
    669   if (literal_words_or_ids.size() == 0) {
    670     // Constant declared with OpConstantNull
    671     return MakeUnique<analysis::NullConstant>(type);
    672   } else if (auto* bt = type->AsBool()) {
    673     assert(literal_words_or_ids.size() == 1 &&
    674            "Bool constant should be declared with one operand");
    675     return MakeUnique<analysis::BoolConstant>(bt, literal_words_or_ids.front());
    676   } else if (auto* it = type->AsInteger()) {
    677     return MakeUnique<analysis::IntConstant>(it, literal_words_or_ids);
    678   } else if (auto* ft = type->AsFloat()) {
    679     return MakeUnique<analysis::FloatConstant>(ft, literal_words_or_ids);
    680   } else if (auto* vt = type->AsVector()) {
    681     auto components = GetConstsFromIds(literal_words_or_ids);
    682     if (components.empty()) return nullptr;
    683     // All components of VectorConstant must be of type Bool, Integer or Float.
    684     if (!std::all_of(components.begin(), components.end(),
    685                      [](const analysis::Constant* c) {
    686                        if (c->type()->AsBool() || c->type()->AsInteger() ||
    687                            c->type()->AsFloat()) {
    688                          return true;
    689                        } else {
    690                          return false;
    691                        }
    692                      }))
    693       return nullptr;
    694     // All components of VectorConstant must be in the same type.
    695     const auto* component_type = components.front()->type();
    696     if (!std::all_of(components.begin(), components.end(),
    697                      [&component_type](const analysis::Constant* c) {
    698                        if (c->type() == component_type) return true;
    699                        return false;
    700                      }))
    701       return nullptr;
    702     return MakeUnique<analysis::VectorConstant>(vt, components);
    703   } else if (auto* st = type->AsStruct()) {
    704     auto components = GetConstsFromIds(literal_words_or_ids);
    705     if (components.empty()) return nullptr;
    706     return MakeUnique<analysis::StructConstant>(st, components);
    707   } else if (auto* at = type->AsArray()) {
    708     auto components = GetConstsFromIds(literal_words_or_ids);
    709     if (components.empty()) return nullptr;
    710     return MakeUnique<analysis::ArrayConstant>(at, components);
    711   } else {
    712     return nullptr;
    713   }
    714 }
    715 
    716 std::vector<ir::Operand> BuildOperandsFromIds(
    717     const std::vector<uint32_t>& ids) {
    718   std::vector<ir::Operand> operands;
    719   for (uint32_t id : ids) {
    720     operands.emplace_back(spv_operand_type_t::SPV_OPERAND_TYPE_ID,
    721                           std::initializer_list<uint32_t>{id});
    722   }
    723   return operands;
    724 }
    725 
    726 std::unique_ptr<ir::Instruction>
    727 FoldSpecConstantOpAndCompositePass::CreateInstruction(uint32_t id,
    728                                                       analysis::Constant* c) {
    729   if (c->AsNullConstant()) {
    730     return MakeUnique<ir::Instruction>(SpvOp::SpvOpConstantNull,
    731                                        type_mgr_->GetId(c->type()), id,
    732                                        std::initializer_list<ir::Operand>{});
    733   } else if (analysis::BoolConstant* bc = c->AsBoolConstant()) {
    734     return MakeUnique<ir::Instruction>(
    735         bc->value() ? SpvOp::SpvOpConstantTrue : SpvOp::SpvOpConstantFalse,
    736         type_mgr_->GetId(c->type()), id, std::initializer_list<ir::Operand>{});
    737   } else if (analysis::IntConstant* ic = c->AsIntConstant()) {
    738     return MakeUnique<ir::Instruction>(
    739         SpvOp::SpvOpConstant, type_mgr_->GetId(c->type()), id,
    740         std::initializer_list<ir::Operand>{ir::Operand(
    741             spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER,
    742             ic->words())});
    743   } else if (analysis::FloatConstant* fc = c->AsFloatConstant()) {
    744     return MakeUnique<ir::Instruction>(
    745         SpvOp::SpvOpConstant, type_mgr_->GetId(c->type()), id,
    746         std::initializer_list<ir::Operand>{ir::Operand(
    747             spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER,
    748             fc->words())});
    749   } else if (analysis::CompositeConstant* cc = c->AsCompositeConstant()) {
    750     return CreateCompositeInstruction(id, cc);
    751   } else {
    752     return nullptr;
    753   }
    754 }
    755 
    756 std::unique_ptr<ir::Instruction>
    757 FoldSpecConstantOpAndCompositePass::CreateCompositeInstruction(
    758     uint32_t result_id, analysis::CompositeConstant* cc) {
    759   std::vector<ir::Operand> operands;
    760   for (const analysis::Constant* component_const : cc->GetComponents()) {
    761     uint32_t id = FindRecordedConst(component_const);
    762     if (id == 0) {
    763       // Cannot get the id of the component constant, while all components
    764       // should have been added to the module prior to the composite constant.
    765       // Cannot create OpConstantComposite instruction in this case.
    766       return nullptr;
    767     }
    768     operands.emplace_back(spv_operand_type_t::SPV_OPERAND_TYPE_ID,
    769                           std::initializer_list<uint32_t>{id});
    770   }
    771   return MakeUnique<ir::Instruction>(SpvOp::SpvOpConstantComposite,
    772                                      type_mgr_->GetId(cc->type()), result_id,
    773                                      std::move(operands));
    774 }
    775 
    776 }  // namespace opt
    777 }  // namespace spvtools
    778