1 // Copyright (c) 2017 Google Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #include "source/opt/constants.h" 16 17 #include <unordered_map> 18 #include <vector> 19 20 #include "source/opt/ir_context.h" 21 22 namespace spvtools { 23 namespace opt { 24 namespace analysis { 25 26 float Constant::GetFloat() const { 27 assert(type()->AsFloat() != nullptr && type()->AsFloat()->width() == 32); 28 29 if (const FloatConstant* fc = AsFloatConstant()) { 30 return fc->GetFloatValue(); 31 } else { 32 assert(AsNullConstant() && "Must be a floating point constant."); 33 return 0.0f; 34 } 35 } 36 37 double Constant::GetDouble() const { 38 assert(type()->AsFloat() != nullptr && type()->AsFloat()->width() == 64); 39 40 if (const FloatConstant* fc = AsFloatConstant()) { 41 return fc->GetDoubleValue(); 42 } else { 43 assert(AsNullConstant() && "Must be a floating point constant."); 44 return 0.0; 45 } 46 } 47 48 double Constant::GetValueAsDouble() const { 49 assert(type()->AsFloat() != nullptr); 50 if (type()->AsFloat()->width() == 32) { 51 return GetFloat(); 52 } else { 53 assert(type()->AsFloat()->width() == 64); 54 return GetDouble(); 55 } 56 } 57 58 uint32_t Constant::GetU32() const { 59 assert(type()->AsInteger() != nullptr); 60 assert(type()->AsInteger()->width() == 32); 61 62 if (const IntConstant* ic = AsIntConstant()) { 63 return ic->GetU32BitValue(); 64 } else { 65 assert(AsNullConstant() && "Must be an integer constant."); 66 return 0u; 67 } 68 } 69 70 uint64_t Constant::GetU64() const { 71 assert(type()->AsInteger() != nullptr); 72 assert(type()->AsInteger()->width() == 64); 73 74 if (const IntConstant* ic = AsIntConstant()) { 75 return ic->GetU64BitValue(); 76 } else { 77 assert(AsNullConstant() && "Must be an integer constant."); 78 return 0u; 79 } 80 } 81 82 int32_t Constant::GetS32() const { 83 assert(type()->AsInteger() != nullptr); 84 assert(type()->AsInteger()->width() == 32); 85 86 if (const IntConstant* ic = AsIntConstant()) { 87 return ic->GetS32BitValue(); 88 } else { 89 assert(AsNullConstant() && "Must be an integer constant."); 90 return 0; 91 } 92 } 93 94 int64_t Constant::GetS64() const { 95 assert(type()->AsInteger() != nullptr); 96 assert(type()->AsInteger()->width() == 64); 97 98 if (const IntConstant* ic = AsIntConstant()) { 99 return ic->GetS64BitValue(); 100 } else { 101 assert(AsNullConstant() && "Must be an integer constant."); 102 return 0; 103 } 104 } 105 106 ConstantManager::ConstantManager(IRContext* ctx) : ctx_(ctx) { 107 // Populate the constant table with values from constant declarations in the 108 // module. The values of each OpConstant declaration is the identity 109 // assignment (i.e., each constant is its own value). 110 for (const auto& inst : ctx_->module()->GetConstants()) { 111 MapInst(inst); 112 } 113 } 114 115 Type* ConstantManager::GetType(const Instruction* inst) const { 116 return context()->get_type_mgr()->GetType(inst->type_id()); 117 } 118 119 std::vector<const Constant*> ConstantManager::GetOperandConstants( 120 Instruction* inst) const { 121 std::vector<const Constant*> constants; 122 for (uint32_t i = 0; i < inst->NumInOperands(); i++) { 123 const Operand* operand = &inst->GetInOperand(i); 124 if (operand->type != SPV_OPERAND_TYPE_ID) { 125 constants.push_back(nullptr); 126 } else { 127 uint32_t id = operand->words[0]; 128 const analysis::Constant* constant = FindDeclaredConstant(id); 129 constants.push_back(constant); 130 } 131 } 132 return constants; 133 } 134 135 uint32_t ConstantManager::FindDeclaredConstant(const Constant* c, 136 uint32_t type_id) const { 137 c = FindConstant(c); 138 if (c == nullptr) { 139 return 0; 140 } 141 142 for (auto range = const_val_to_id_.equal_range(c); 143 range.first != range.second; ++range.first) { 144 Instruction* const_def = 145 context()->get_def_use_mgr()->GetDef(range.first->second); 146 if (type_id == 0 || const_def->type_id() == type_id) { 147 return range.first->second; 148 } 149 } 150 return 0; 151 } 152 153 std::vector<const Constant*> ConstantManager::GetConstantsFromIds( 154 const std::vector<uint32_t>& ids) const { 155 std::vector<const Constant*> constants; 156 for (uint32_t id : ids) { 157 if (const Constant* c = FindDeclaredConstant(id)) { 158 constants.push_back(c); 159 } else { 160 return {}; 161 } 162 } 163 return constants; 164 } 165 166 Instruction* ConstantManager::BuildInstructionAndAddToModule( 167 const Constant* new_const, Module::inst_iterator* pos, uint32_t type_id) { 168 // TODO(1841): Handle id overflow. 169 uint32_t new_id = context()->TakeNextId(); 170 auto new_inst = CreateInstruction(new_id, new_const, type_id); 171 if (!new_inst) { 172 return nullptr; 173 } 174 auto* new_inst_ptr = new_inst.get(); 175 *pos = pos->InsertBefore(std::move(new_inst)); 176 ++(*pos); 177 context()->get_def_use_mgr()->AnalyzeInstDefUse(new_inst_ptr); 178 MapConstantToInst(new_const, new_inst_ptr); 179 return new_inst_ptr; 180 } 181 182 Instruction* ConstantManager::GetDefiningInstruction( 183 const Constant* c, uint32_t type_id, Module::inst_iterator* pos) { 184 assert(type_id == 0 || 185 context()->get_type_mgr()->GetType(type_id) == c->type()); 186 uint32_t decl_id = FindDeclaredConstant(c, type_id); 187 if (decl_id == 0) { 188 auto iter = context()->types_values_end(); 189 if (pos == nullptr) pos = &iter; 190 return BuildInstructionAndAddToModule(c, pos, type_id); 191 } else { 192 auto def = context()->get_def_use_mgr()->GetDef(decl_id); 193 assert(def != nullptr); 194 assert((type_id == 0 || def->type_id() == type_id) && 195 "This constant already has an instruction with a different type."); 196 return def; 197 } 198 } 199 200 std::unique_ptr<Constant> ConstantManager::CreateConstant( 201 const Type* type, const std::vector<uint32_t>& literal_words_or_ids) const { 202 if (literal_words_or_ids.size() == 0) { 203 // Constant declared with OpConstantNull 204 return MakeUnique<NullConstant>(type); 205 } else if (auto* bt = type->AsBool()) { 206 assert(literal_words_or_ids.size() == 1 && 207 "Bool constant should be declared with one operand"); 208 return MakeUnique<BoolConstant>(bt, literal_words_or_ids.front()); 209 } else if (auto* it = type->AsInteger()) { 210 return MakeUnique<IntConstant>(it, literal_words_or_ids); 211 } else if (auto* ft = type->AsFloat()) { 212 return MakeUnique<FloatConstant>(ft, literal_words_or_ids); 213 } else if (auto* vt = type->AsVector()) { 214 auto components = GetConstantsFromIds(literal_words_or_ids); 215 if (components.empty()) return nullptr; 216 // All components of VectorConstant must be of type Bool, Integer or Float. 217 if (!std::all_of(components.begin(), components.end(), 218 [](const Constant* c) { 219 if (c->type()->AsBool() || c->type()->AsInteger() || 220 c->type()->AsFloat()) { 221 return true; 222 } else { 223 return false; 224 } 225 })) 226 return nullptr; 227 // All components of VectorConstant must be in the same type. 228 const auto* component_type = components.front()->type(); 229 if (!std::all_of(components.begin(), components.end(), 230 [&component_type](const Constant* c) { 231 if (c->type() == component_type) return true; 232 return false; 233 })) 234 return nullptr; 235 return MakeUnique<VectorConstant>(vt, components); 236 } else if (auto* mt = type->AsMatrix()) { 237 auto components = GetConstantsFromIds(literal_words_or_ids); 238 if (components.empty()) return nullptr; 239 return MakeUnique<MatrixConstant>(mt, components); 240 } else if (auto* st = type->AsStruct()) { 241 auto components = GetConstantsFromIds(literal_words_or_ids); 242 if (components.empty()) return nullptr; 243 return MakeUnique<StructConstant>(st, components); 244 } else if (auto* at = type->AsArray()) { 245 auto components = GetConstantsFromIds(literal_words_or_ids); 246 if (components.empty()) return nullptr; 247 return MakeUnique<ArrayConstant>(at, components); 248 } else { 249 return nullptr; 250 } 251 } 252 253 const Constant* ConstantManager::GetConstantFromInst(Instruction* inst) { 254 std::vector<uint32_t> literal_words_or_ids; 255 256 // Collect the constant defining literals or component ids. 257 for (uint32_t i = 0; i < inst->NumInOperands(); i++) { 258 literal_words_or_ids.insert(literal_words_or_ids.end(), 259 inst->GetInOperand(i).words.begin(), 260 inst->GetInOperand(i).words.end()); 261 } 262 263 switch (inst->opcode()) { 264 // OpConstant{True|False} have the value embedded in the opcode. So they 265 // are not handled by the for-loop above. Here we add the value explicitly. 266 case SpvOp::SpvOpConstantTrue: 267 literal_words_or_ids.push_back(true); 268 break; 269 case SpvOp::SpvOpConstantFalse: 270 literal_words_or_ids.push_back(false); 271 break; 272 case SpvOp::SpvOpConstantNull: 273 case SpvOp::SpvOpConstant: 274 case SpvOp::SpvOpConstantComposite: 275 case SpvOp::SpvOpSpecConstantComposite: 276 break; 277 default: 278 return nullptr; 279 } 280 281 return GetConstant(GetType(inst), literal_words_or_ids); 282 } 283 284 std::unique_ptr<Instruction> ConstantManager::CreateInstruction( 285 uint32_t id, const Constant* c, uint32_t type_id) const { 286 uint32_t type = 287 (type_id == 0) ? context()->get_type_mgr()->GetId(c->type()) : type_id; 288 if (c->AsNullConstant()) { 289 return MakeUnique<Instruction>(context(), SpvOp::SpvOpConstantNull, type, 290 id, std::initializer_list<Operand>{}); 291 } else if (const BoolConstant* bc = c->AsBoolConstant()) { 292 return MakeUnique<Instruction>( 293 context(), 294 bc->value() ? SpvOp::SpvOpConstantTrue : SpvOp::SpvOpConstantFalse, 295 type, id, std::initializer_list<Operand>{}); 296 } else if (const IntConstant* ic = c->AsIntConstant()) { 297 return MakeUnique<Instruction>( 298 context(), SpvOp::SpvOpConstant, type, id, 299 std::initializer_list<Operand>{ 300 Operand(spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER, 301 ic->words())}); 302 } else if (const FloatConstant* fc = c->AsFloatConstant()) { 303 return MakeUnique<Instruction>( 304 context(), SpvOp::SpvOpConstant, type, id, 305 std::initializer_list<Operand>{ 306 Operand(spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER, 307 fc->words())}); 308 } else if (const CompositeConstant* cc = c->AsCompositeConstant()) { 309 return CreateCompositeInstruction(id, cc, type_id); 310 } else { 311 return nullptr; 312 } 313 } 314 315 std::unique_ptr<Instruction> ConstantManager::CreateCompositeInstruction( 316 uint32_t result_id, const CompositeConstant* cc, uint32_t type_id) const { 317 std::vector<Operand> operands; 318 Instruction* type_inst = context()->get_def_use_mgr()->GetDef(type_id); 319 uint32_t component_index = 0; 320 for (const Constant* component_const : cc->GetComponents()) { 321 uint32_t component_type_id = 0; 322 if (type_inst && type_inst->opcode() == SpvOpTypeStruct) { 323 component_type_id = type_inst->GetSingleWordInOperand(component_index); 324 } else if (type_inst && type_inst->opcode() == SpvOpTypeArray) { 325 component_type_id = type_inst->GetSingleWordInOperand(0); 326 } 327 uint32_t id = FindDeclaredConstant(component_const, component_type_id); 328 329 if (id == 0) { 330 // Cannot get the id of the component constant, while all components 331 // should have been added to the module prior to the composite constant. 332 // Cannot create OpConstantComposite instruction in this case. 333 return nullptr; 334 } 335 operands.emplace_back(spv_operand_type_t::SPV_OPERAND_TYPE_ID, 336 std::initializer_list<uint32_t>{id}); 337 component_index++; 338 } 339 uint32_t type = 340 (type_id == 0) ? context()->get_type_mgr()->GetId(cc->type()) : type_id; 341 return MakeUnique<Instruction>(context(), SpvOp::SpvOpConstantComposite, type, 342 result_id, std::move(operands)); 343 } 344 345 const Constant* ConstantManager::GetConstant( 346 const Type* type, const std::vector<uint32_t>& literal_words_or_ids) { 347 auto cst = CreateConstant(type, literal_words_or_ids); 348 return cst ? RegisterConstant(std::move(cst)) : nullptr; 349 } 350 351 std::vector<const analysis::Constant*> Constant::GetVectorComponents( 352 analysis::ConstantManager* const_mgr) const { 353 std::vector<const analysis::Constant*> components; 354 const analysis::VectorConstant* a = this->AsVectorConstant(); 355 const analysis::Vector* vector_type = this->type()->AsVector(); 356 assert(vector_type != nullptr); 357 if (a != nullptr) { 358 for (uint32_t i = 0; i < vector_type->element_count(); ++i) { 359 components.push_back(a->GetComponents()[i]); 360 } 361 } else { 362 const analysis::Type* element_type = vector_type->element_type(); 363 const analysis::Constant* element_null_const = 364 const_mgr->GetConstant(element_type, {}); 365 for (uint32_t i = 0; i < vector_type->element_count(); ++i) { 366 components.push_back(element_null_const); 367 } 368 } 369 return components; 370 } 371 372 } // namespace analysis 373 } // namespace opt 374 } // namespace spvtools 375