1 // Copyright (c) 2016 Google Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #include "fold_spec_constant_op_and_composite_pass.h" 16 17 #include <algorithm> 18 #include <initializer_list> 19 #include <tuple> 20 21 #include "constants.h" 22 #include "make_unique.h" 23 24 namespace spvtools { 25 namespace opt { 26 27 namespace { 28 // Returns the single-word result from performing the given unary operation on 29 // the operand value which is passed in as a 32-bit word. 30 uint32_t UnaryOperate(SpvOp opcode, uint32_t operand) { 31 switch (opcode) { 32 // Arthimetics 33 case SpvOp::SpvOpSNegate: 34 return -static_cast<int32_t>(operand); 35 case SpvOp::SpvOpNot: 36 return ~operand; 37 case SpvOp::SpvOpLogicalNot: 38 return !static_cast<bool>(operand); 39 default: 40 assert(false && 41 "Unsupported unary operation for OpSpecConstantOp instruction"); 42 return 0u; 43 } 44 } 45 46 // Returns the single-word result from performing the given binary operation on 47 // the operand values which are passed in as two 32-bit word. 48 uint32_t BinaryOperate(SpvOp opcode, uint32_t a, uint32_t b) { 49 switch (opcode) { 50 // Arthimetics 51 case SpvOp::SpvOpIAdd: 52 return a + b; 53 case SpvOp::SpvOpISub: 54 return a - b; 55 case SpvOp::SpvOpIMul: 56 return a * b; 57 case SpvOp::SpvOpUDiv: 58 assert(b != 0); 59 return a / b; 60 case SpvOp::SpvOpSDiv: 61 assert(b != 0u); 62 return (static_cast<int32_t>(a)) / (static_cast<int32_t>(b)); 63 case SpvOp::SpvOpSRem: { 64 // The sign of non-zero result comes from the first operand: a. This is 65 // guaranteed by C++11 rules for integer division operator. The division 66 // result is rounded toward zero, so the result of '%' has the sign of 67 // the first operand. 68 assert(b != 0u); 69 return static_cast<int32_t>(a) % static_cast<int32_t>(b); 70 } 71 case SpvOp::SpvOpSMod: { 72 // The sign of non-zero result comes from the second operand: b 73 assert(b != 0u); 74 int32_t rem = BinaryOperate(SpvOp::SpvOpSRem, a, b); 75 int32_t b_prim = static_cast<int32_t>(b); 76 return (rem + b_prim) % b_prim; 77 } 78 case SpvOp::SpvOpUMod: 79 assert(b != 0u); 80 return (a % b); 81 82 // Shifting 83 case SpvOp::SpvOpShiftRightLogical: { 84 return a >> b; 85 } 86 case SpvOp::SpvOpShiftRightArithmetic: 87 return (static_cast<int32_t>(a)) >> b; 88 case SpvOp::SpvOpShiftLeftLogical: 89 return a << b; 90 91 // Bitwise operations 92 case SpvOp::SpvOpBitwiseOr: 93 return a | b; 94 case SpvOp::SpvOpBitwiseAnd: 95 return a & b; 96 case SpvOp::SpvOpBitwiseXor: 97 return a ^ b; 98 99 // Logical 100 case SpvOp::SpvOpLogicalEqual: 101 return (static_cast<bool>(a)) == (static_cast<bool>(b)); 102 case SpvOp::SpvOpLogicalNotEqual: 103 return (static_cast<bool>(a)) != (static_cast<bool>(b)); 104 case SpvOp::SpvOpLogicalOr: 105 return (static_cast<bool>(a)) || (static_cast<bool>(b)); 106 case SpvOp::SpvOpLogicalAnd: 107 return (static_cast<bool>(a)) && (static_cast<bool>(b)); 108 109 // Comparison 110 case SpvOp::SpvOpIEqual: 111 return a == b; 112 case SpvOp::SpvOpINotEqual: 113 return a != b; 114 case SpvOp::SpvOpULessThan: 115 return a < b; 116 case SpvOp::SpvOpSLessThan: 117 return (static_cast<int32_t>(a)) < (static_cast<int32_t>(b)); 118 case SpvOp::SpvOpUGreaterThan: 119 return a > b; 120 case SpvOp::SpvOpSGreaterThan: 121 return (static_cast<int32_t>(a)) > (static_cast<int32_t>(b)); 122 case SpvOp::SpvOpULessThanEqual: 123 return a <= b; 124 case SpvOp::SpvOpSLessThanEqual: 125 return (static_cast<int32_t>(a)) <= (static_cast<int32_t>(b)); 126 case SpvOp::SpvOpUGreaterThanEqual: 127 return a >= b; 128 case SpvOp::SpvOpSGreaterThanEqual: 129 return (static_cast<int32_t>(a)) >= (static_cast<int32_t>(b)); 130 default: 131 assert(false && 132 "Unsupported binary operation for OpSpecConstantOp instruction"); 133 return 0u; 134 } 135 } 136 137 // Returns the single-word result from performing the given ternary operation 138 // on the operand values which are passed in as three 32-bit word. 139 uint32_t TernaryOperate(SpvOp opcode, uint32_t a, uint32_t b, uint32_t c) { 140 switch (opcode) { 141 case SpvOp::SpvOpSelect: 142 return (static_cast<bool>(a)) ? b : c; 143 default: 144 assert(false && 145 "Unsupported ternary operation for OpSpecConstantOp instruction"); 146 return 0u; 147 } 148 } 149 150 // Returns the single-word result from performing the given operation on the 151 // operand words. This only works with 32-bit operations and uses boolean 152 // convention that 0u is false, and anything else is boolean true. 153 // TODO(qining): Support operands other than 32-bit wide. 154 uint32_t OperateWords(SpvOp opcode, 155 const std::vector<uint32_t>& operand_words) { 156 switch (operand_words.size()) { 157 case 1: 158 return UnaryOperate(opcode, operand_words.front()); 159 case 2: 160 return BinaryOperate(opcode, operand_words.front(), operand_words.back()); 161 case 3: 162 return TernaryOperate(opcode, operand_words[0], operand_words[1], 163 operand_words[2]); 164 default: 165 assert(false && "Invalid number of operands"); 166 return 0; 167 } 168 } 169 170 // Returns the result of performing an operation on scalar constant operands. 171 // This function extracts the operand values as 32 bit words and returns the 172 // result in 32 bit word. Scalar constants with longer than 32-bit width are 173 // not accepted in this function. 174 uint32_t OperateScalars(SpvOp opcode, 175 const std::vector<analysis::Constant*>& operands) { 176 std::vector<uint32_t> operand_values_in_raw_words; 177 for (analysis::Constant* operand : operands) { 178 if (analysis::ScalarConstant* scalar = operand->AsScalarConstant()) { 179 const auto& scalar_words = scalar->words(); 180 assert(scalar_words.size() == 1 && 181 "Scalar constants with longer than 32-bit width are not allowed " 182 "in OperateScalars()"); 183 operand_values_in_raw_words.push_back(scalar_words.front()); 184 } else if (operand->AsNullConstant()) { 185 operand_values_in_raw_words.push_back(0u); 186 } else { 187 assert(false && 188 "OperateScalars() only accepts ScalarConst or NullConst type of " 189 "constant"); 190 } 191 } 192 return OperateWords(opcode, operand_values_in_raw_words); 193 } 194 195 // Returns the result of performing an operation over constant vectors. This 196 // function iterates through the given vector type constant operands and 197 // calculates the result for each element of the result vector to return. 198 // Vectors with longer than 32-bit scalar components are not accepted in this 199 // function. 200 std::vector<uint32_t> OperateVectors( 201 SpvOp opcode, uint32_t num_dims, 202 const std::vector<analysis::Constant*>& operands) { 203 std::vector<uint32_t> result; 204 for (uint32_t d = 0; d < num_dims; d++) { 205 std::vector<uint32_t> operand_values_for_one_dimension; 206 for (analysis::Constant* operand : operands) { 207 if (analysis::VectorConstant* vector_operand = 208 operand->AsVectorConstant()) { 209 // Extract the raw value of the scalar component constants 210 // in 32-bit words here. The reason of not using OperateScalars() here 211 // is that we do not create temporary null constants as components 212 // when the vector operand is a NullConstant because Constant creation 213 // may need extra checks for the validity and that is not manageed in 214 // here. 215 if (const analysis::ScalarConstant* scalar_component = 216 vector_operand->GetComponents().at(d)->AsScalarConstant()) { 217 const auto& scalar_words = scalar_component->words(); 218 assert( 219 scalar_words.size() == 1 && 220 "Vector components with longer than 32-bit width are not allowed " 221 "in OperateVectors()"); 222 operand_values_for_one_dimension.push_back(scalar_words.front()); 223 } else if (operand->AsNullConstant()) { 224 operand_values_for_one_dimension.push_back(0u); 225 } else { 226 assert(false && 227 "VectorConst should only has ScalarConst or NullConst as " 228 "components"); 229 } 230 } else if (operand->AsNullConstant()) { 231 operand_values_for_one_dimension.push_back(0u); 232 } else { 233 assert(false && 234 "OperateVectors() only accepts VectorConst or NullConst type of " 235 "constant"); 236 } 237 } 238 result.push_back(OperateWords(opcode, operand_values_for_one_dimension)); 239 } 240 return result; 241 } 242 } // anonymous namespace 243 244 FoldSpecConstantOpAndCompositePass::FoldSpecConstantOpAndCompositePass() 245 : max_id_(0), 246 module_(nullptr), 247 def_use_mgr_(nullptr), 248 type_mgr_(nullptr), 249 id_to_const_val_() {} 250 251 Pass::Status FoldSpecConstantOpAndCompositePass::Process(ir::Module* module) { 252 Initialize(module); 253 return ProcessImpl(module); 254 } 255 256 void FoldSpecConstantOpAndCompositePass::Initialize(ir::Module* module) { 257 type_mgr_.reset(new analysis::TypeManager(consumer(), *module)); 258 def_use_mgr_.reset(new analysis::DefUseManager(consumer(), module)); 259 for (const auto& id_def : def_use_mgr_->id_to_defs()) { 260 max_id_ = std::max(max_id_, id_def.first); 261 } 262 module_ = module; 263 }; 264 265 Pass::Status FoldSpecConstantOpAndCompositePass::ProcessImpl( 266 ir::Module* module) { 267 bool modified = false; 268 // Traverse through all the constant defining instructions. For Normal 269 // Constants whose values are determined and do not depend on OpUndef 270 // instructions, records their values in two internal maps: id_to_const_val_ 271 // and const_val_to_id_ so that we can use them to infer the value of Spec 272 // Constants later. 273 // For Spec Constants defined with OpSpecConstantComposite instructions, if 274 // all of their components are Normal Constants, they will be turned into 275 // Normal Constants too. For Spec Constants defined with OpSpecConstantOp 276 // instructions, we check if they only depends on Normal Constants and fold 277 // them when possible. The two maps for Normal Constants: id_to_const_val_ 278 // and const_val_to_id_ will be updated along the traversal so that the new 279 // Normal Constants generated from folding can be used to fold following Spec 280 // Constants. 281 // This algorithm depends on the SSA property of SPIR-V when 282 // defining constants. The dependent constants must be defined before the 283 // dependee constants. So a dependent Spec Constant must be defined and 284 // will be processed before its dependee Spec Constant. When we encounter 285 // the dependee Spec Constants, all its dependent constants must have been 286 // processed and all its dependent Spec Constants should have been folded if 287 // possible. 288 for (ir::Module::inst_iterator inst_iter = module->types_values_begin(); 289 // Need to re-evaluate the end iterator since we may modify the list of 290 // instructions in this section of the module as the process goes. 291 inst_iter != module->types_values_end(); ++inst_iter) { 292 ir::Instruction* inst = &*inst_iter; 293 // Collect constant values of normal constants and process the 294 // OpSpecConstantOp and OpSpecConstantComposite instructions if possible. 295 // The constant values will be stored in analysis::Constant instances. 296 // OpConstantSampler instruction is not collected here because it cannot be 297 // used in OpSpecConstant{Composite|Op} instructions. 298 // TODO(qining): If the constant or its type has decoration, we may need 299 // to skip it. 300 if (GetType(inst) && !GetType(inst)->decoration_empty()) continue; 301 switch (SpvOp opcode = inst->opcode()) { 302 // Records the values of Normal Constants. 303 case SpvOp::SpvOpConstantTrue: 304 case SpvOp::SpvOpConstantFalse: 305 case SpvOp::SpvOpConstant: 306 case SpvOp::SpvOpConstantNull: 307 case SpvOp::SpvOpConstantComposite: 308 case SpvOp::SpvOpSpecConstantComposite: { 309 // A Constant instance will be created if the given instruction is a 310 // Normal Constant whose value(s) are fixed. Note that for a composite 311 // Spec Constant defined with OpSpecConstantComposite instruction, if 312 // all of its components are Normal Constants already, the Spec 313 // Constant will be turned in to a Normal Constant. In that case, a 314 // Constant instance should also be created successfully and recorded 315 // in the id_to_const_val_ and const_val_to_id_ mapps. 316 if (auto const_value = CreateConstFromInst(inst)) { 317 // Need to replace the OpSpecConstantComposite instruction with a 318 // corresponding OpConstantComposite instruction. 319 if (opcode == SpvOp::SpvOpSpecConstantComposite) { 320 inst->SetOpcode(SpvOp::SpvOpConstantComposite); 321 modified = true; 322 } 323 const_val_to_id_[const_value.get()] = inst->result_id(); 324 id_to_const_val_[inst->result_id()] = std::move(const_value); 325 } 326 break; 327 } 328 // For a Spec Constants defined with OpSpecConstantOp instruction, check 329 // if it only depends on Normal Constants. If so, the Spec Constant will 330 // be folded. The original Spec Constant defining instruction will be 331 // replaced by Normal Constant defining instructions, and the new Normal 332 // Constants will be added to id_to_const_val_ and const_val_to_id_ so 333 // that we can use the new Normal Constants when folding following Spec 334 // Constants. 335 case SpvOp::SpvOpSpecConstantOp: 336 modified |= ProcessOpSpecConstantOp(&inst_iter); 337 break; 338 default: 339 break; 340 } 341 } 342 return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; 343 } 344 345 bool FoldSpecConstantOpAndCompositePass::ProcessOpSpecConstantOp( 346 ir::Module::inst_iterator* pos) { 347 ir::Instruction* inst = &**pos; 348 ir::Instruction* folded_inst = nullptr; 349 assert(inst->GetInOperand(0).type == 350 SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER && 351 "The first in-operand of OpSpecContantOp instruction must be of " 352 "SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER type"); 353 354 switch (static_cast<SpvOp>(inst->GetSingleWordInOperand(0))) { 355 case SpvOp::SpvOpCompositeExtract: 356 folded_inst = DoCompositeExtract(pos); 357 break; 358 case SpvOp::SpvOpVectorShuffle: 359 folded_inst = DoVectorShuffle(pos); 360 break; 361 362 case SpvOp::SpvOpCompositeInsert: 363 // Current Glslang does not generate code with OpSpecConstantOp 364 // CompositeInsert instruction, so this is not implmented so far. 365 // TODO(qining): Implement CompositeInsert case. 366 return false; 367 368 default: 369 // Component-wise operations. 370 folded_inst = DoComponentWiseOperation(pos); 371 break; 372 } 373 if (!folded_inst) return false; 374 375 // Replace the original constant with the new folded constant, kill the 376 // original constant. 377 uint32_t new_id = folded_inst->result_id(); 378 uint32_t old_id = inst->result_id(); 379 def_use_mgr_->ReplaceAllUsesWith(old_id, new_id); 380 def_use_mgr_->KillDef(old_id); 381 return true; 382 } 383 384 ir::Instruction* FoldSpecConstantOpAndCompositePass::DoCompositeExtract( 385 ir::Module::inst_iterator* pos) { 386 ir::Instruction* inst = &**pos; 387 assert(inst->NumInOperands() - 1 >= 2 && 388 "OpSpecConstantOp CompositeExtract requires at least two non-type " 389 "non-opcode operands."); 390 assert(inst->GetInOperand(1).type == SPV_OPERAND_TYPE_ID && 391 "The vector operand must have a SPV_OPERAND_TYPE_ID type"); 392 assert( 393 inst->GetInOperand(2).type == SPV_OPERAND_TYPE_LITERAL_INTEGER && 394 "The literal operand must have a SPV_OPERAND_TYPE_LITERAL_INTEGER type"); 395 396 // Note that for OpSpecConstantOp, the second in-operand is the first id 397 // operand. The first in-operand is the spec opcode. 398 analysis::Constant* first_operand_const = 399 FindRecordedConst(inst->GetSingleWordInOperand(1)); 400 if (!first_operand_const) return nullptr; 401 402 const analysis::Constant* current_const = first_operand_const; 403 for (uint32_t i = 2; i < inst->NumInOperands(); i++) { 404 uint32_t literal = inst->GetSingleWordInOperand(i); 405 if (const analysis::CompositeConstant* composite_const = 406 current_const->AsCompositeConstant()) { 407 // Case 1: current constant is a non-null composite type constant. 408 assert(literal < composite_const->GetComponents().size() && 409 "Literal index out of bound of the composite constant"); 410 current_const = composite_const->GetComponents().at(literal); 411 } else if (current_const->AsNullConstant()) { 412 // Case 2: current constant is a constant created with OpConstantNull. 413 // Because components of a NullConstant are always NullConstants, we can 414 // return early with a NullConstant in the result type. 415 return BuildInstructionAndAddToModule(CreateConst(GetType(inst), {}), 416 pos); 417 } else { 418 // Dereferencing a non-composite constant. Invalid case. 419 return nullptr; 420 } 421 } 422 return BuildInstructionAndAddToModule(current_const->Copy(), pos); 423 } 424 425 ir::Instruction* FoldSpecConstantOpAndCompositePass::DoVectorShuffle( 426 ir::Module::inst_iterator* pos) { 427 ir::Instruction* inst = &**pos; 428 analysis::Vector* result_vec_type = GetType(inst)->AsVector(); 429 assert(inst->NumInOperands() - 1 > 2 && 430 "OpSpecConstantOp DoVectorShuffle instruction requires more than 2 " 431 "operands (2 vector ids and at least one literal operand"); 432 assert(result_vec_type && 433 "The result of VectorShuffle must be of type vector"); 434 435 // A temporary null constants that can be used as the components fo the 436 // result vector. This is needed when any one of the vector operands are null 437 // constant. 438 std::unique_ptr<analysis::Constant> null_component_constants; 439 440 // Get a concatenated vector of scalar constants. The vector should be built 441 // with the components from the first and the second operand of VectorShuffle. 442 std::vector<const analysis::Constant*> concatenated_components; 443 // Note that for OpSpecConstantOp, the second in-operand is the first id 444 // operand. The first in-operand is the spec opcode. 445 for (uint32_t i : {1, 2}) { 446 assert(inst->GetInOperand(i).type == SPV_OPERAND_TYPE_ID && 447 "The vector operand must have a SPV_OPERAND_TYPE_ID type"); 448 uint32_t operand_id = inst->GetSingleWordInOperand(i); 449 analysis::Constant* operand_const = FindRecordedConst(operand_id); 450 if (!operand_const) return nullptr; 451 const analysis::Type* operand_type = operand_const->type(); 452 assert(operand_type->AsVector() && 453 "The first two operand of VectorShuffle must be of vector type"); 454 if (analysis::VectorConstant* vec_const = 455 operand_const->AsVectorConstant()) { 456 // case 1: current operand is a non-null vector constant. 457 concatenated_components.insert(concatenated_components.end(), 458 vec_const->GetComponents().begin(), 459 vec_const->GetComponents().end()); 460 } else if (operand_const->AsNullConstant()) { 461 // case 2: current operand is a null vector constant. Create a temporary 462 // null scalar constant as the component. 463 if (!null_component_constants) { 464 const analysis::Type* component_type = 465 operand_type->AsVector()->element_type(); 466 null_component_constants = CreateConst(component_type, {}); 467 } 468 // Append the null scalar consts to the concatenated components 469 // vector. 470 concatenated_components.insert(concatenated_components.end(), 471 operand_type->AsVector()->element_count(), 472 null_component_constants.get()); 473 } else { 474 // no other valid cases 475 return nullptr; 476 } 477 } 478 // Create null component constants if there are any. The component constants 479 // must be added to the module before the dependee composite constants to 480 // satisfy SSA def-use dominance. 481 if (null_component_constants) { 482 BuildInstructionAndAddToModule(std::move(null_component_constants), pos); 483 } 484 // Create the new vector constant with the selected components. 485 std::vector<const analysis::Constant*> selected_components; 486 for (uint32_t i = 3; i < inst->NumInOperands(); i++) { 487 assert(inst->GetInOperand(i).type == SPV_OPERAND_TYPE_LITERAL_INTEGER && 488 "The literal operand must of type SPV_OPERAND_TYPE_LITERAL_INTEGER"); 489 uint32_t literal = inst->GetSingleWordInOperand(i); 490 assert(literal < concatenated_components.size() && 491 "Literal index out of bound of the concatenated vector"); 492 selected_components.push_back(concatenated_components[literal]); 493 } 494 auto new_vec_const = MakeUnique<analysis::VectorConstant>( 495 result_vec_type, selected_components); 496 return BuildInstructionAndAddToModule(std::move(new_vec_const), pos); 497 } 498 499 namespace { 500 // A helper function to check the type for component wise operations. Returns 501 // true if the type: 502 // 1) is bool type; 503 // 2) is 32-bit int type; 504 // 3) is vector of bool type; 505 // 4) is vector of 32-bit integer type. 506 // Otherwise returns false. 507 bool IsValidTypeForComponentWiseOperation(const analysis::Type* type) { 508 if (type->AsBool()) { 509 return true; 510 } else if (auto* it = type->AsInteger()) { 511 if (it->width() == 32) return true; 512 } else if (auto* vt = type->AsVector()) { 513 if (vt->element_type()->AsBool()) 514 return true; 515 else if (auto* vit = vt->element_type()->AsInteger()) { 516 if (vit->width() == 32) return true; 517 } 518 } 519 return false; 520 } 521 } 522 523 ir::Instruction* FoldSpecConstantOpAndCompositePass::DoComponentWiseOperation( 524 ir::Module::inst_iterator* pos) { 525 const ir::Instruction* inst = &**pos; 526 const analysis::Type* result_type = GetType(inst); 527 SpvOp spec_opcode = static_cast<SpvOp>(inst->GetSingleWordInOperand(0)); 528 // Check and collect operands. 529 std::vector<analysis::Constant*> operands; 530 531 if (!std::all_of(inst->cbegin(), inst->cend(), 532 [&operands, this](const ir::Operand& o) { 533 // skip the operands that is not an id. 534 if (o.type != spv_operand_type_t::SPV_OPERAND_TYPE_ID) 535 return true; 536 uint32_t id = o.words.front(); 537 if (analysis::Constant* c = FindRecordedConst(id)) { 538 if (IsValidTypeForComponentWiseOperation(c->type())) { 539 operands.push_back(c); 540 return true; 541 } 542 } 543 return false; 544 })) 545 return nullptr; 546 547 if (result_type->AsInteger() || result_type->AsBool()) { 548 // Scalar operation 549 uint32_t result_val = OperateScalars(spec_opcode, operands); 550 auto result_const = CreateConst(result_type, {result_val}); 551 return BuildInstructionAndAddToModule(std::move(result_const), pos); 552 } else if (result_type->AsVector()) { 553 // Vector operation 554 const analysis::Type* element_type = 555 result_type->AsVector()->element_type(); 556 uint32_t num_dims = result_type->AsVector()->element_count(); 557 std::vector<uint32_t> result_vec = 558 OperateVectors(spec_opcode, num_dims, operands); 559 std::vector<const analysis::Constant*> result_vector_components; 560 for (uint32_t r : result_vec) { 561 if (auto rc = CreateConst(element_type, {r})) { 562 result_vector_components.push_back(rc.get()); 563 if (!BuildInstructionAndAddToModule(std::move(rc), pos)) { 564 assert(false && 565 "Failed to build and insert constant declaring instruction " 566 "for the given vector component constant"); 567 } 568 } else { 569 assert(false && "Failed to create constants with 32-bit word"); 570 } 571 } 572 auto new_vec_const = MakeUnique<analysis::VectorConstant>( 573 result_type->AsVector(), result_vector_components); 574 return BuildInstructionAndAddToModule(std::move(new_vec_const), pos); 575 } else { 576 // Cannot process invalid component wise operation. The result of component 577 // wise operation must be of integer or bool scalar or vector of 578 // integer/bool type. 579 return nullptr; 580 } 581 } 582 583 ir::Instruction* 584 FoldSpecConstantOpAndCompositePass::BuildInstructionAndAddToModule( 585 std::unique_ptr<analysis::Constant> c, ir::Module::inst_iterator* pos) { 586 analysis::Constant* new_const = c.get(); 587 uint32_t new_id = ++max_id_; 588 module_->SetIdBound(new_id + 1); 589 const_val_to_id_[new_const] = new_id; 590 id_to_const_val_[new_id] = std::move(c); 591 auto new_inst = CreateInstruction(new_id, new_const); 592 if (!new_inst) return nullptr; 593 auto* new_inst_ptr = new_inst.get(); 594 *pos = pos->InsertBefore(std::move(new_inst)); 595 (*pos)++; 596 def_use_mgr_->AnalyzeInstDefUse(new_inst_ptr); 597 return new_inst_ptr; 598 } 599 600 std::unique_ptr<analysis::Constant> 601 FoldSpecConstantOpAndCompositePass::CreateConstFromInst(ir::Instruction* inst) { 602 std::vector<uint32_t> literal_words_or_ids; 603 std::unique_ptr<analysis::Constant> new_const; 604 // Collect the constant defining literals or component ids. 605 for (uint32_t i = 0; i < inst->NumInOperands(); i++) { 606 literal_words_or_ids.insert(literal_words_or_ids.end(), 607 inst->GetInOperand(i).words.begin(), 608 inst->GetInOperand(i).words.end()); 609 } 610 switch (inst->opcode()) { 611 // OpConstant{True|Flase} have the value embedded in the opcode. So they 612 // are not handled by the for-loop above. Here we add the value explicitly. 613 case SpvOp::SpvOpConstantTrue: 614 literal_words_or_ids.push_back(true); 615 break; 616 case SpvOp::SpvOpConstantFalse: 617 literal_words_or_ids.push_back(false); 618 break; 619 case SpvOp::SpvOpConstantNull: 620 case SpvOp::SpvOpConstant: 621 case SpvOp::SpvOpConstantComposite: 622 case SpvOp::SpvOpSpecConstantComposite: 623 break; 624 default: 625 return nullptr; 626 } 627 return CreateConst(GetType(inst), literal_words_or_ids); 628 } 629 630 analysis::Constant* FoldSpecConstantOpAndCompositePass::FindRecordedConst( 631 uint32_t id) { 632 auto iter = id_to_const_val_.find(id); 633 if (iter == id_to_const_val_.end()) { 634 return nullptr; 635 } else { 636 return iter->second.get(); 637 } 638 } 639 640 uint32_t FoldSpecConstantOpAndCompositePass::FindRecordedConst( 641 const analysis::Constant* c) { 642 auto iter = const_val_to_id_.find(c); 643 if (iter == const_val_to_id_.end()) { 644 return 0; 645 } else { 646 return iter->second; 647 } 648 } 649 650 std::vector<const analysis::Constant*> 651 FoldSpecConstantOpAndCompositePass::GetConstsFromIds( 652 const std::vector<uint32_t>& ids) { 653 std::vector<const analysis::Constant*> constants; 654 for (uint32_t id : ids) { 655 if (analysis::Constant* c = FindRecordedConst(id)) { 656 constants.push_back(c); 657 } else { 658 return {}; 659 } 660 } 661 return constants; 662 } 663 664 std::unique_ptr<analysis::Constant> 665 FoldSpecConstantOpAndCompositePass::CreateConst( 666 const analysis::Type* type, 667 const std::vector<uint32_t>& literal_words_or_ids) { 668 std::unique_ptr<analysis::Constant> new_const; 669 if (literal_words_or_ids.size() == 0) { 670 // Constant declared with OpConstantNull 671 return MakeUnique<analysis::NullConstant>(type); 672 } else if (auto* bt = type->AsBool()) { 673 assert(literal_words_or_ids.size() == 1 && 674 "Bool constant should be declared with one operand"); 675 return MakeUnique<analysis::BoolConstant>(bt, literal_words_or_ids.front()); 676 } else if (auto* it = type->AsInteger()) { 677 return MakeUnique<analysis::IntConstant>(it, literal_words_or_ids); 678 } else if (auto* ft = type->AsFloat()) { 679 return MakeUnique<analysis::FloatConstant>(ft, literal_words_or_ids); 680 } else if (auto* vt = type->AsVector()) { 681 auto components = GetConstsFromIds(literal_words_or_ids); 682 if (components.empty()) return nullptr; 683 // All components of VectorConstant must be of type Bool, Integer or Float. 684 if (!std::all_of(components.begin(), components.end(), 685 [](const analysis::Constant* c) { 686 if (c->type()->AsBool() || c->type()->AsInteger() || 687 c->type()->AsFloat()) { 688 return true; 689 } else { 690 return false; 691 } 692 })) 693 return nullptr; 694 // All components of VectorConstant must be in the same type. 695 const auto* component_type = components.front()->type(); 696 if (!std::all_of(components.begin(), components.end(), 697 [&component_type](const analysis::Constant* c) { 698 if (c->type() == component_type) return true; 699 return false; 700 })) 701 return nullptr; 702 return MakeUnique<analysis::VectorConstant>(vt, components); 703 } else if (auto* st = type->AsStruct()) { 704 auto components = GetConstsFromIds(literal_words_or_ids); 705 if (components.empty()) return nullptr; 706 return MakeUnique<analysis::StructConstant>(st, components); 707 } else if (auto* at = type->AsArray()) { 708 auto components = GetConstsFromIds(literal_words_or_ids); 709 if (components.empty()) return nullptr; 710 return MakeUnique<analysis::ArrayConstant>(at, components); 711 } else { 712 return nullptr; 713 } 714 } 715 716 std::vector<ir::Operand> BuildOperandsFromIds( 717 const std::vector<uint32_t>& ids) { 718 std::vector<ir::Operand> operands; 719 for (uint32_t id : ids) { 720 operands.emplace_back(spv_operand_type_t::SPV_OPERAND_TYPE_ID, 721 std::initializer_list<uint32_t>{id}); 722 } 723 return operands; 724 } 725 726 std::unique_ptr<ir::Instruction> 727 FoldSpecConstantOpAndCompositePass::CreateInstruction(uint32_t id, 728 analysis::Constant* c) { 729 if (c->AsNullConstant()) { 730 return MakeUnique<ir::Instruction>(SpvOp::SpvOpConstantNull, 731 type_mgr_->GetId(c->type()), id, 732 std::initializer_list<ir::Operand>{}); 733 } else if (analysis::BoolConstant* bc = c->AsBoolConstant()) { 734 return MakeUnique<ir::Instruction>( 735 bc->value() ? SpvOp::SpvOpConstantTrue : SpvOp::SpvOpConstantFalse, 736 type_mgr_->GetId(c->type()), id, std::initializer_list<ir::Operand>{}); 737 } else if (analysis::IntConstant* ic = c->AsIntConstant()) { 738 return MakeUnique<ir::Instruction>( 739 SpvOp::SpvOpConstant, type_mgr_->GetId(c->type()), id, 740 std::initializer_list<ir::Operand>{ir::Operand( 741 spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER, 742 ic->words())}); 743 } else if (analysis::FloatConstant* fc = c->AsFloatConstant()) { 744 return MakeUnique<ir::Instruction>( 745 SpvOp::SpvOpConstant, type_mgr_->GetId(c->type()), id, 746 std::initializer_list<ir::Operand>{ir::Operand( 747 spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER, 748 fc->words())}); 749 } else if (analysis::CompositeConstant* cc = c->AsCompositeConstant()) { 750 return CreateCompositeInstruction(id, cc); 751 } else { 752 return nullptr; 753 } 754 } 755 756 std::unique_ptr<ir::Instruction> 757 FoldSpecConstantOpAndCompositePass::CreateCompositeInstruction( 758 uint32_t result_id, analysis::CompositeConstant* cc) { 759 std::vector<ir::Operand> operands; 760 for (const analysis::Constant* component_const : cc->GetComponents()) { 761 uint32_t id = FindRecordedConst(component_const); 762 if (id == 0) { 763 // Cannot get the id of the component constant, while all components 764 // should have been added to the module prior to the composite constant. 765 // Cannot create OpConstantComposite instruction in this case. 766 return nullptr; 767 } 768 operands.emplace_back(spv_operand_type_t::SPV_OPERAND_TYPE_ID, 769 std::initializer_list<uint32_t>{id}); 770 } 771 return MakeUnique<ir::Instruction>(SpvOp::SpvOpConstantComposite, 772 type_mgr_->GetId(cc->type()), result_id, 773 std::move(operands)); 774 } 775 776 } // namespace opt 777 } // namespace spvtools 778