Home | History | Annotate | Download | only in opt
      1 // Copyright (c) 2016 Google Inc.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
     15 #include "source/opt/fold_spec_constant_op_and_composite_pass.h"
     16 
     17 #include <algorithm>
     18 #include <initializer_list>
     19 #include <tuple>
     20 
     21 #include "source/opt/constants.h"
     22 #include "source/opt/fold.h"
     23 #include "source/opt/ir_context.h"
     24 #include "source/util/make_unique.h"
     25 
     26 namespace spvtools {
     27 namespace opt {
     28 
     29 Pass::Status FoldSpecConstantOpAndCompositePass::Process() {
     30   bool modified = false;
     31   // Traverse through all the constant defining instructions. For Normal
     32   // Constants whose values are determined and do not depend on OpUndef
     33   // instructions, records their values in two internal maps: id_to_const_val_
     34   // and const_val_to_id_ so that we can use them to infer the value of Spec
     35   // Constants later.
     36   // For Spec Constants defined with OpSpecConstantComposite instructions, if
     37   // all of their components are Normal Constants, they will be turned into
     38   // Normal Constants too. For Spec Constants defined with OpSpecConstantOp
     39   // instructions, we check if they only depends on Normal Constants and fold
     40   // them when possible. The two maps for Normal Constants: id_to_const_val_
     41   // and const_val_to_id_ will be updated along the traversal so that the new
     42   // Normal Constants generated from folding can be used to fold following Spec
     43   // Constants.
     44   // This algorithm depends on the SSA property of SPIR-V when
     45   // defining constants. The dependent constants must be defined before the
     46   // dependee constants. So a dependent Spec Constant must be defined and
     47   // will be processed before its dependee Spec Constant. When we encounter
     48   // the dependee Spec Constants, all its dependent constants must have been
     49   // processed and all its dependent Spec Constants should have been folded if
     50   // possible.
     51   Module::inst_iterator next_inst = context()->types_values_begin();
     52   for (Module::inst_iterator inst_iter = next_inst;
     53        // Need to re-evaluate the end iterator since we may modify the list of
     54        // instructions in this section of the module as the process goes.
     55        inst_iter != context()->types_values_end(); inst_iter = next_inst) {
     56     ++next_inst;
     57     Instruction* inst = &*inst_iter;
     58     // Collect constant values of normal constants and process the
     59     // OpSpecConstantOp and OpSpecConstantComposite instructions if possible.
     60     // The constant values will be stored in analysis::Constant instances.
     61     // OpConstantSampler instruction is not collected here because it cannot be
     62     // used in OpSpecConstant{Composite|Op} instructions.
     63     // TODO(qining): If the constant or its type has decoration, we may need
     64     // to skip it.
     65     if (context()->get_constant_mgr()->GetType(inst) &&
     66         !context()->get_constant_mgr()->GetType(inst)->decoration_empty())
     67       continue;
     68     switch (SpvOp opcode = inst->opcode()) {
     69       // Records the values of Normal Constants.
     70       case SpvOp::SpvOpConstantTrue:
     71       case SpvOp::SpvOpConstantFalse:
     72       case SpvOp::SpvOpConstant:
     73       case SpvOp::SpvOpConstantNull:
     74       case SpvOp::SpvOpConstantComposite:
     75       case SpvOp::SpvOpSpecConstantComposite: {
     76         // A Constant instance will be created if the given instruction is a
     77         // Normal Constant whose value(s) are fixed. Note that for a composite
     78         // Spec Constant defined with OpSpecConstantComposite instruction, if
     79         // all of its components are Normal Constants already, the Spec
     80         // Constant will be turned in to a Normal Constant. In that case, a
     81         // Constant instance should also be created successfully and recorded
     82         // in the id_to_const_val_ and const_val_to_id_ mapps.
     83         if (auto const_value =
     84                 context()->get_constant_mgr()->GetConstantFromInst(inst)) {
     85           // Need to replace the OpSpecConstantComposite instruction with a
     86           // corresponding OpConstantComposite instruction.
     87           if (opcode == SpvOp::SpvOpSpecConstantComposite) {
     88             inst->SetOpcode(SpvOp::SpvOpConstantComposite);
     89             modified = true;
     90           }
     91           context()->get_constant_mgr()->MapConstantToInst(const_value, inst);
     92         }
     93         break;
     94       }
     95       // For a Spec Constants defined with OpSpecConstantOp instruction, check
     96       // if it only depends on Normal Constants. If so, the Spec Constant will
     97       // be folded. The original Spec Constant defining instruction will be
     98       // replaced by Normal Constant defining instructions, and the new Normal
     99       // Constants will be added to id_to_const_val_ and const_val_to_id_ so
    100       // that we can use the new Normal Constants when folding following Spec
    101       // Constants.
    102       case SpvOp::SpvOpSpecConstantOp:
    103         modified |= ProcessOpSpecConstantOp(&inst_iter);
    104         break;
    105       default:
    106         break;
    107     }
    108   }
    109   return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
    110 }
    111 
    112 bool FoldSpecConstantOpAndCompositePass::ProcessOpSpecConstantOp(
    113     Module::inst_iterator* pos) {
    114   Instruction* inst = &**pos;
    115   Instruction* folded_inst = nullptr;
    116   assert(inst->GetInOperand(0).type ==
    117              SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER &&
    118          "The first in-operand of OpSpecContantOp instruction must be of "
    119          "SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER type");
    120 
    121   switch (static_cast<SpvOp>(inst->GetSingleWordInOperand(0))) {
    122     case SpvOp::SpvOpCompositeExtract:
    123       folded_inst = DoCompositeExtract(pos);
    124       break;
    125     case SpvOp::SpvOpVectorShuffle:
    126       folded_inst = DoVectorShuffle(pos);
    127       break;
    128 
    129     case SpvOp::SpvOpCompositeInsert:
    130       // Current Glslang does not generate code with OpSpecConstantOp
    131       // CompositeInsert instruction, so this is not implmented so far.
    132       // TODO(qining): Implement CompositeInsert case.
    133       return false;
    134 
    135     default:
    136       // Component-wise operations.
    137       folded_inst = DoComponentWiseOperation(pos);
    138       break;
    139   }
    140   if (!folded_inst) return false;
    141 
    142   // Replace the original constant with the new folded constant, kill the
    143   // original constant.
    144   uint32_t new_id = folded_inst->result_id();
    145   uint32_t old_id = inst->result_id();
    146   context()->ReplaceAllUsesWith(old_id, new_id);
    147   context()->KillDef(old_id);
    148   return true;
    149 }
    150 
    151 uint32_t FoldSpecConstantOpAndCompositePass::GetTypeComponent(
    152     uint32_t typeId, uint32_t element) const {
    153   Instruction* type = context()->get_def_use_mgr()->GetDef(typeId);
    154   uint32_t subtype = type->GetTypeComponent(element);
    155   assert(subtype != 0);
    156 
    157   return subtype;
    158 }
    159 
    160 Instruction* FoldSpecConstantOpAndCompositePass::DoCompositeExtract(
    161     Module::inst_iterator* pos) {
    162   Instruction* inst = &**pos;
    163   assert(inst->NumInOperands() - 1 >= 2 &&
    164          "OpSpecConstantOp CompositeExtract requires at least two non-type "
    165          "non-opcode operands.");
    166   assert(inst->GetInOperand(1).type == SPV_OPERAND_TYPE_ID &&
    167          "The composite operand must have a SPV_OPERAND_TYPE_ID type");
    168   assert(
    169       inst->GetInOperand(2).type == SPV_OPERAND_TYPE_LITERAL_INTEGER &&
    170       "The literal operand must have a SPV_OPERAND_TYPE_LITERAL_INTEGER type");
    171 
    172   // Note that for OpSpecConstantOp, the second in-operand is the first id
    173   // operand. The first in-operand is the spec opcode.
    174   uint32_t source = inst->GetSingleWordInOperand(1);
    175   uint32_t type = context()->get_def_use_mgr()->GetDef(source)->type_id();
    176   const analysis::Constant* first_operand_const =
    177       context()->get_constant_mgr()->FindDeclaredConstant(source);
    178   if (!first_operand_const) return nullptr;
    179 
    180   const analysis::Constant* current_const = first_operand_const;
    181   for (uint32_t i = 2; i < inst->NumInOperands(); i++) {
    182     uint32_t literal = inst->GetSingleWordInOperand(i);
    183     type = GetTypeComponent(type, literal);
    184   }
    185   for (uint32_t i = 2; i < inst->NumInOperands(); i++) {
    186     uint32_t literal = inst->GetSingleWordInOperand(i);
    187     if (const analysis::CompositeConstant* composite_const =
    188             current_const->AsCompositeConstant()) {
    189       // Case 1: current constant is a non-null composite type constant.
    190       assert(literal < composite_const->GetComponents().size() &&
    191              "Literal index out of bound of the composite constant");
    192       current_const = composite_const->GetComponents().at(literal);
    193     } else if (current_const->AsNullConstant()) {
    194       // Case 2: current constant is a constant created with OpConstantNull.
    195       // Because components of a NullConstant are always NullConstants, we can
    196       // return early with a NullConstant in the result type.
    197       return context()->get_constant_mgr()->BuildInstructionAndAddToModule(
    198           context()->get_constant_mgr()->GetConstant(
    199               context()->get_constant_mgr()->GetType(inst), {}),
    200           pos, type);
    201     } else {
    202       // Dereferencing a non-composite constant. Invalid case.
    203       return nullptr;
    204     }
    205   }
    206   return context()->get_constant_mgr()->BuildInstructionAndAddToModule(
    207       current_const, pos);
    208 }
    209 
    210 Instruction* FoldSpecConstantOpAndCompositePass::DoVectorShuffle(
    211     Module::inst_iterator* pos) {
    212   Instruction* inst = &**pos;
    213   analysis::Vector* result_vec_type =
    214       context()->get_constant_mgr()->GetType(inst)->AsVector();
    215   assert(inst->NumInOperands() - 1 > 2 &&
    216          "OpSpecConstantOp DoVectorShuffle instruction requires more than 2 "
    217          "operands (2 vector ids and at least one literal operand");
    218   assert(result_vec_type &&
    219          "The result of VectorShuffle must be of type vector");
    220 
    221   // A temporary null constants that can be used as the components of the result
    222   // vector. This is needed when any one of the vector operands are null
    223   // constant.
    224   const analysis::Constant* null_component_constants = nullptr;
    225 
    226   // Get a concatenated vector of scalar constants. The vector should be built
    227   // with the components from the first and the second operand of VectorShuffle.
    228   std::vector<const analysis::Constant*> concatenated_components;
    229   // Note that for OpSpecConstantOp, the second in-operand is the first id
    230   // operand. The first in-operand is the spec opcode.
    231   for (uint32_t i : {1, 2}) {
    232     assert(inst->GetInOperand(i).type == SPV_OPERAND_TYPE_ID &&
    233            "The vector operand must have a SPV_OPERAND_TYPE_ID type");
    234     uint32_t operand_id = inst->GetSingleWordInOperand(i);
    235     auto operand_const =
    236         context()->get_constant_mgr()->FindDeclaredConstant(operand_id);
    237     if (!operand_const) return nullptr;
    238     const analysis::Type* operand_type = operand_const->type();
    239     assert(operand_type->AsVector() &&
    240            "The first two operand of VectorShuffle must be of vector type");
    241     if (auto vec_const = operand_const->AsVectorConstant()) {
    242       // case 1: current operand is a non-null vector constant.
    243       concatenated_components.insert(concatenated_components.end(),
    244                                      vec_const->GetComponents().begin(),
    245                                      vec_const->GetComponents().end());
    246     } else if (operand_const->AsNullConstant()) {
    247       // case 2: current operand is a null vector constant. Create a temporary
    248       // null scalar constant as the component.
    249       if (!null_component_constants) {
    250         const analysis::Type* component_type =
    251             operand_type->AsVector()->element_type();
    252         null_component_constants =
    253             context()->get_constant_mgr()->GetConstant(component_type, {});
    254       }
    255       // Append the null scalar consts to the concatenated components
    256       // vector.
    257       concatenated_components.insert(concatenated_components.end(),
    258                                      operand_type->AsVector()->element_count(),
    259                                      null_component_constants);
    260     } else {
    261       // no other valid cases
    262       return nullptr;
    263     }
    264   }
    265   // Create null component constants if there are any. The component constants
    266   // must be added to the module before the dependee composite constants to
    267   // satisfy SSA def-use dominance.
    268   if (null_component_constants) {
    269     context()->get_constant_mgr()->BuildInstructionAndAddToModule(
    270         null_component_constants, pos);
    271   }
    272   // Create the new vector constant with the selected components.
    273   std::vector<const analysis::Constant*> selected_components;
    274   for (uint32_t i = 3; i < inst->NumInOperands(); i++) {
    275     assert(inst->GetInOperand(i).type == SPV_OPERAND_TYPE_LITERAL_INTEGER &&
    276            "The literal operand must of type SPV_OPERAND_TYPE_LITERAL_INTEGER");
    277     uint32_t literal = inst->GetSingleWordInOperand(i);
    278     assert(literal < concatenated_components.size() &&
    279            "Literal index out of bound of the concatenated vector");
    280     selected_components.push_back(concatenated_components[literal]);
    281   }
    282   auto new_vec_const = MakeUnique<analysis::VectorConstant>(
    283       result_vec_type, selected_components);
    284   auto reg_vec_const =
    285       context()->get_constant_mgr()->RegisterConstant(std::move(new_vec_const));
    286   return context()->get_constant_mgr()->BuildInstructionAndAddToModule(
    287       reg_vec_const, pos);
    288 }
    289 
    290 namespace {
    291 // A helper function to check the type for component wise operations. Returns
    292 // true if the type:
    293 //  1) is bool type;
    294 //  2) is 32-bit int type;
    295 //  3) is vector of bool type;
    296 //  4) is vector of 32-bit integer type.
    297 // Otherwise returns false.
    298 bool IsValidTypeForComponentWiseOperation(const analysis::Type* type) {
    299   if (type->AsBool()) {
    300     return true;
    301   } else if (auto* it = type->AsInteger()) {
    302     if (it->width() == 32) return true;
    303   } else if (auto* vt = type->AsVector()) {
    304     if (vt->element_type()->AsBool()) {
    305       return true;
    306     } else if (auto* vit = vt->element_type()->AsInteger()) {
    307       if (vit->width() == 32) return true;
    308     }
    309   }
    310   return false;
    311 }
    312 }  // namespace
    313 
    314 Instruction* FoldSpecConstantOpAndCompositePass::DoComponentWiseOperation(
    315     Module::inst_iterator* pos) {
    316   const Instruction* inst = &**pos;
    317   const analysis::Type* result_type =
    318       context()->get_constant_mgr()->GetType(inst);
    319   SpvOp spec_opcode = static_cast<SpvOp>(inst->GetSingleWordInOperand(0));
    320   // Check and collect operands.
    321   std::vector<const analysis::Constant*> operands;
    322 
    323   if (!std::all_of(
    324           inst->cbegin(), inst->cend(), [&operands, this](const Operand& o) {
    325             // skip the operands that is not an id.
    326             if (o.type != spv_operand_type_t::SPV_OPERAND_TYPE_ID) return true;
    327             uint32_t id = o.words.front();
    328             if (auto c =
    329                     context()->get_constant_mgr()->FindDeclaredConstant(id)) {
    330               if (IsValidTypeForComponentWiseOperation(c->type())) {
    331                 operands.push_back(c);
    332                 return true;
    333               }
    334             }
    335             return false;
    336           }))
    337     return nullptr;
    338 
    339   if (result_type->AsInteger() || result_type->AsBool()) {
    340     // Scalar operation
    341     uint32_t result_val =
    342         context()->get_instruction_folder().FoldScalars(spec_opcode, operands);
    343     auto result_const =
    344         context()->get_constant_mgr()->GetConstant(result_type, {result_val});
    345     return context()->get_constant_mgr()->BuildInstructionAndAddToModule(
    346         result_const, pos);
    347   } else if (result_type->AsVector()) {
    348     // Vector operation
    349     const analysis::Type* element_type =
    350         result_type->AsVector()->element_type();
    351     uint32_t num_dims = result_type->AsVector()->element_count();
    352     std::vector<uint32_t> result_vec =
    353         context()->get_instruction_folder().FoldVectors(spec_opcode, num_dims,
    354                                                         operands);
    355     std::vector<const analysis::Constant*> result_vector_components;
    356     for (uint32_t r : result_vec) {
    357       if (auto rc =
    358               context()->get_constant_mgr()->GetConstant(element_type, {r})) {
    359         result_vector_components.push_back(rc);
    360         if (!context()->get_constant_mgr()->BuildInstructionAndAddToModule(
    361                 rc, pos)) {
    362           assert(false &&
    363                  "Failed to build and insert constant declaring instruction "
    364                  "for the given vector component constant");
    365         }
    366       } else {
    367         assert(false && "Failed to create constants with 32-bit word");
    368       }
    369     }
    370     auto new_vec_const = MakeUnique<analysis::VectorConstant>(
    371         result_type->AsVector(), result_vector_components);
    372     auto reg_vec_const = context()->get_constant_mgr()->RegisterConstant(
    373         std::move(new_vec_const));
    374     return context()->get_constant_mgr()->BuildInstructionAndAddToModule(
    375         reg_vec_const, pos);
    376   } else {
    377     // Cannot process invalid component wise operation. The result of component
    378     // wise operation must be of integer or bool scalar or vector of
    379     // integer/bool type.
    380     return nullptr;
    381   }
    382 }
    383 
    384 }  // namespace opt
    385 }  // namespace spvtools
    386