Home | History | Annotate | Download | only in opt
      1 // Copyright (c) 2017 Google Inc.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
     15 #include "source/opt/constants.h"
     16 
     17 #include <unordered_map>
     18 #include <vector>
     19 
     20 #include "source/opt/ir_context.h"
     21 
     22 namespace spvtools {
     23 namespace opt {
     24 namespace analysis {
     25 
     26 float Constant::GetFloat() const {
     27   assert(type()->AsFloat() != nullptr && type()->AsFloat()->width() == 32);
     28 
     29   if (const FloatConstant* fc = AsFloatConstant()) {
     30     return fc->GetFloatValue();
     31   } else {
     32     assert(AsNullConstant() && "Must be a floating point constant.");
     33     return 0.0f;
     34   }
     35 }
     36 
     37 double Constant::GetDouble() const {
     38   assert(type()->AsFloat() != nullptr && type()->AsFloat()->width() == 64);
     39 
     40   if (const FloatConstant* fc = AsFloatConstant()) {
     41     return fc->GetDoubleValue();
     42   } else {
     43     assert(AsNullConstant() && "Must be a floating point constant.");
     44     return 0.0;
     45   }
     46 }
     47 
     48 double Constant::GetValueAsDouble() const {
     49   assert(type()->AsFloat() != nullptr);
     50   if (type()->AsFloat()->width() == 32) {
     51     return GetFloat();
     52   } else {
     53     assert(type()->AsFloat()->width() == 64);
     54     return GetDouble();
     55   }
     56 }
     57 
     58 uint32_t Constant::GetU32() const {
     59   assert(type()->AsInteger() != nullptr);
     60   assert(type()->AsInteger()->width() == 32);
     61 
     62   if (const IntConstant* ic = AsIntConstant()) {
     63     return ic->GetU32BitValue();
     64   } else {
     65     assert(AsNullConstant() && "Must be an integer constant.");
     66     return 0u;
     67   }
     68 }
     69 
     70 uint64_t Constant::GetU64() const {
     71   assert(type()->AsInteger() != nullptr);
     72   assert(type()->AsInteger()->width() == 64);
     73 
     74   if (const IntConstant* ic = AsIntConstant()) {
     75     return ic->GetU64BitValue();
     76   } else {
     77     assert(AsNullConstant() && "Must be an integer constant.");
     78     return 0u;
     79   }
     80 }
     81 
     82 int32_t Constant::GetS32() const {
     83   assert(type()->AsInteger() != nullptr);
     84   assert(type()->AsInteger()->width() == 32);
     85 
     86   if (const IntConstant* ic = AsIntConstant()) {
     87     return ic->GetS32BitValue();
     88   } else {
     89     assert(AsNullConstant() && "Must be an integer constant.");
     90     return 0;
     91   }
     92 }
     93 
     94 int64_t Constant::GetS64() const {
     95   assert(type()->AsInteger() != nullptr);
     96   assert(type()->AsInteger()->width() == 64);
     97 
     98   if (const IntConstant* ic = AsIntConstant()) {
     99     return ic->GetS64BitValue();
    100   } else {
    101     assert(AsNullConstant() && "Must be an integer constant.");
    102     return 0;
    103   }
    104 }
    105 
    106 ConstantManager::ConstantManager(IRContext* ctx) : ctx_(ctx) {
    107   // Populate the constant table with values from constant declarations in the
    108   // module.  The values of each OpConstant declaration is the identity
    109   // assignment (i.e., each constant is its own value).
    110   for (const auto& inst : ctx_->module()->GetConstants()) {
    111     MapInst(inst);
    112   }
    113 }
    114 
    115 Type* ConstantManager::GetType(const Instruction* inst) const {
    116   return context()->get_type_mgr()->GetType(inst->type_id());
    117 }
    118 
    119 std::vector<const Constant*> ConstantManager::GetOperandConstants(
    120     Instruction* inst) const {
    121   std::vector<const Constant*> constants;
    122   for (uint32_t i = 0; i < inst->NumInOperands(); i++) {
    123     const Operand* operand = &inst->GetInOperand(i);
    124     if (operand->type != SPV_OPERAND_TYPE_ID) {
    125       constants.push_back(nullptr);
    126     } else {
    127       uint32_t id = operand->words[0];
    128       const analysis::Constant* constant = FindDeclaredConstant(id);
    129       constants.push_back(constant);
    130     }
    131   }
    132   return constants;
    133 }
    134 
    135 uint32_t ConstantManager::FindDeclaredConstant(const Constant* c,
    136                                                uint32_t type_id) const {
    137   c = FindConstant(c);
    138   if (c == nullptr) {
    139     return 0;
    140   }
    141 
    142   for (auto range = const_val_to_id_.equal_range(c);
    143        range.first != range.second; ++range.first) {
    144     Instruction* const_def =
    145         context()->get_def_use_mgr()->GetDef(range.first->second);
    146     if (type_id == 0 || const_def->type_id() == type_id) {
    147       return range.first->second;
    148     }
    149   }
    150   return 0;
    151 }
    152 
    153 std::vector<const Constant*> ConstantManager::GetConstantsFromIds(
    154     const std::vector<uint32_t>& ids) const {
    155   std::vector<const Constant*> constants;
    156   for (uint32_t id : ids) {
    157     if (const Constant* c = FindDeclaredConstant(id)) {
    158       constants.push_back(c);
    159     } else {
    160       return {};
    161     }
    162   }
    163   return constants;
    164 }
    165 
    166 Instruction* ConstantManager::BuildInstructionAndAddToModule(
    167     const Constant* new_const, Module::inst_iterator* pos, uint32_t type_id) {
    168   // TODO(1841): Handle id overflow.
    169   uint32_t new_id = context()->TakeNextId();
    170   auto new_inst = CreateInstruction(new_id, new_const, type_id);
    171   if (!new_inst) {
    172     return nullptr;
    173   }
    174   auto* new_inst_ptr = new_inst.get();
    175   *pos = pos->InsertBefore(std::move(new_inst));
    176   ++(*pos);
    177   context()->get_def_use_mgr()->AnalyzeInstDefUse(new_inst_ptr);
    178   MapConstantToInst(new_const, new_inst_ptr);
    179   return new_inst_ptr;
    180 }
    181 
    182 Instruction* ConstantManager::GetDefiningInstruction(
    183     const Constant* c, uint32_t type_id, Module::inst_iterator* pos) {
    184   assert(type_id == 0 ||
    185          context()->get_type_mgr()->GetType(type_id) == c->type());
    186   uint32_t decl_id = FindDeclaredConstant(c, type_id);
    187   if (decl_id == 0) {
    188     auto iter = context()->types_values_end();
    189     if (pos == nullptr) pos = &iter;
    190     return BuildInstructionAndAddToModule(c, pos, type_id);
    191   } else {
    192     auto def = context()->get_def_use_mgr()->GetDef(decl_id);
    193     assert(def != nullptr);
    194     assert((type_id == 0 || def->type_id() == type_id) &&
    195            "This constant already has an instruction with a different type.");
    196     return def;
    197   }
    198 }
    199 
    200 std::unique_ptr<Constant> ConstantManager::CreateConstant(
    201     const Type* type, const std::vector<uint32_t>& literal_words_or_ids) const {
    202   if (literal_words_or_ids.size() == 0) {
    203     // Constant declared with OpConstantNull
    204     return MakeUnique<NullConstant>(type);
    205   } else if (auto* bt = type->AsBool()) {
    206     assert(literal_words_or_ids.size() == 1 &&
    207            "Bool constant should be declared with one operand");
    208     return MakeUnique<BoolConstant>(bt, literal_words_or_ids.front());
    209   } else if (auto* it = type->AsInteger()) {
    210     return MakeUnique<IntConstant>(it, literal_words_or_ids);
    211   } else if (auto* ft = type->AsFloat()) {
    212     return MakeUnique<FloatConstant>(ft, literal_words_or_ids);
    213   } else if (auto* vt = type->AsVector()) {
    214     auto components = GetConstantsFromIds(literal_words_or_ids);
    215     if (components.empty()) return nullptr;
    216     // All components of VectorConstant must be of type Bool, Integer or Float.
    217     if (!std::all_of(components.begin(), components.end(),
    218                      [](const Constant* c) {
    219                        if (c->type()->AsBool() || c->type()->AsInteger() ||
    220                            c->type()->AsFloat()) {
    221                          return true;
    222                        } else {
    223                          return false;
    224                        }
    225                      }))
    226       return nullptr;
    227     // All components of VectorConstant must be in the same type.
    228     const auto* component_type = components.front()->type();
    229     if (!std::all_of(components.begin(), components.end(),
    230                      [&component_type](const Constant* c) {
    231                        if (c->type() == component_type) return true;
    232                        return false;
    233                      }))
    234       return nullptr;
    235     return MakeUnique<VectorConstant>(vt, components);
    236   } else if (auto* mt = type->AsMatrix()) {
    237     auto components = GetConstantsFromIds(literal_words_or_ids);
    238     if (components.empty()) return nullptr;
    239     return MakeUnique<MatrixConstant>(mt, components);
    240   } else if (auto* st = type->AsStruct()) {
    241     auto components = GetConstantsFromIds(literal_words_or_ids);
    242     if (components.empty()) return nullptr;
    243     return MakeUnique<StructConstant>(st, components);
    244   } else if (auto* at = type->AsArray()) {
    245     auto components = GetConstantsFromIds(literal_words_or_ids);
    246     if (components.empty()) return nullptr;
    247     return MakeUnique<ArrayConstant>(at, components);
    248   } else {
    249     return nullptr;
    250   }
    251 }
    252 
    253 const Constant* ConstantManager::GetConstantFromInst(Instruction* inst) {
    254   std::vector<uint32_t> literal_words_or_ids;
    255 
    256   // Collect the constant defining literals or component ids.
    257   for (uint32_t i = 0; i < inst->NumInOperands(); i++) {
    258     literal_words_or_ids.insert(literal_words_or_ids.end(),
    259                                 inst->GetInOperand(i).words.begin(),
    260                                 inst->GetInOperand(i).words.end());
    261   }
    262 
    263   switch (inst->opcode()) {
    264     // OpConstant{True|False} have the value embedded in the opcode. So they
    265     // are not handled by the for-loop above. Here we add the value explicitly.
    266     case SpvOp::SpvOpConstantTrue:
    267       literal_words_or_ids.push_back(true);
    268       break;
    269     case SpvOp::SpvOpConstantFalse:
    270       literal_words_or_ids.push_back(false);
    271       break;
    272     case SpvOp::SpvOpConstantNull:
    273     case SpvOp::SpvOpConstant:
    274     case SpvOp::SpvOpConstantComposite:
    275     case SpvOp::SpvOpSpecConstantComposite:
    276       break;
    277     default:
    278       return nullptr;
    279   }
    280 
    281   return GetConstant(GetType(inst), literal_words_or_ids);
    282 }
    283 
    284 std::unique_ptr<Instruction> ConstantManager::CreateInstruction(
    285     uint32_t id, const Constant* c, uint32_t type_id) const {
    286   uint32_t type =
    287       (type_id == 0) ? context()->get_type_mgr()->GetId(c->type()) : type_id;
    288   if (c->AsNullConstant()) {
    289     return MakeUnique<Instruction>(context(), SpvOp::SpvOpConstantNull, type,
    290                                    id, std::initializer_list<Operand>{});
    291   } else if (const BoolConstant* bc = c->AsBoolConstant()) {
    292     return MakeUnique<Instruction>(
    293         context(),
    294         bc->value() ? SpvOp::SpvOpConstantTrue : SpvOp::SpvOpConstantFalse,
    295         type, id, std::initializer_list<Operand>{});
    296   } else if (const IntConstant* ic = c->AsIntConstant()) {
    297     return MakeUnique<Instruction>(
    298         context(), SpvOp::SpvOpConstant, type, id,
    299         std::initializer_list<Operand>{
    300             Operand(spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER,
    301                     ic->words())});
    302   } else if (const FloatConstant* fc = c->AsFloatConstant()) {
    303     return MakeUnique<Instruction>(
    304         context(), SpvOp::SpvOpConstant, type, id,
    305         std::initializer_list<Operand>{
    306             Operand(spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER,
    307                     fc->words())});
    308   } else if (const CompositeConstant* cc = c->AsCompositeConstant()) {
    309     return CreateCompositeInstruction(id, cc, type_id);
    310   } else {
    311     return nullptr;
    312   }
    313 }
    314 
    315 std::unique_ptr<Instruction> ConstantManager::CreateCompositeInstruction(
    316     uint32_t result_id, const CompositeConstant* cc, uint32_t type_id) const {
    317   std::vector<Operand> operands;
    318   Instruction* type_inst = context()->get_def_use_mgr()->GetDef(type_id);
    319   uint32_t component_index = 0;
    320   for (const Constant* component_const : cc->GetComponents()) {
    321     uint32_t component_type_id = 0;
    322     if (type_inst && type_inst->opcode() == SpvOpTypeStruct) {
    323       component_type_id = type_inst->GetSingleWordInOperand(component_index);
    324     } else if (type_inst && type_inst->opcode() == SpvOpTypeArray) {
    325       component_type_id = type_inst->GetSingleWordInOperand(0);
    326     }
    327     uint32_t id = FindDeclaredConstant(component_const, component_type_id);
    328 
    329     if (id == 0) {
    330       // Cannot get the id of the component constant, while all components
    331       // should have been added to the module prior to the composite constant.
    332       // Cannot create OpConstantComposite instruction in this case.
    333       return nullptr;
    334     }
    335     operands.emplace_back(spv_operand_type_t::SPV_OPERAND_TYPE_ID,
    336                           std::initializer_list<uint32_t>{id});
    337     component_index++;
    338   }
    339   uint32_t type =
    340       (type_id == 0) ? context()->get_type_mgr()->GetId(cc->type()) : type_id;
    341   return MakeUnique<Instruction>(context(), SpvOp::SpvOpConstantComposite, type,
    342                                  result_id, std::move(operands));
    343 }
    344 
    345 const Constant* ConstantManager::GetConstant(
    346     const Type* type, const std::vector<uint32_t>& literal_words_or_ids) {
    347   auto cst = CreateConstant(type, literal_words_or_ids);
    348   return cst ? RegisterConstant(std::move(cst)) : nullptr;
    349 }
    350 
    351 std::vector<const analysis::Constant*> Constant::GetVectorComponents(
    352     analysis::ConstantManager* const_mgr) const {
    353   std::vector<const analysis::Constant*> components;
    354   const analysis::VectorConstant* a = this->AsVectorConstant();
    355   const analysis::Vector* vector_type = this->type()->AsVector();
    356   assert(vector_type != nullptr);
    357   if (a != nullptr) {
    358     for (uint32_t i = 0; i < vector_type->element_count(); ++i) {
    359       components.push_back(a->GetComponents()[i]);
    360     }
    361   } else {
    362     const analysis::Type* element_type = vector_type->element_type();
    363     const analysis::Constant* element_null_const =
    364         const_mgr->GetConstant(element_type, {});
    365     for (uint32_t i = 0; i < vector_type->element_count(); ++i) {
    366       components.push_back(element_null_const);
    367     }
    368   }
    369   return components;
    370 }
    371 
    372 }  // namespace analysis
    373 }  // namespace opt
    374 }  // namespace spvtools
    375