1 // Copyright (c) 2018 Google LLC 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #include "source/opt/const_folding_rules.h" 16 17 #include "source/opt/ir_context.h" 18 19 namespace spvtools { 20 namespace opt { 21 namespace { 22 23 const uint32_t kExtractCompositeIdInIdx = 0; 24 25 // Returns true if |type| is Float or a vector of Float. 26 bool HasFloatingPoint(const analysis::Type* type) { 27 if (type->AsFloat()) { 28 return true; 29 } else if (const analysis::Vector* vec_type = type->AsVector()) { 30 return vec_type->element_type()->AsFloat() != nullptr; 31 } 32 33 return false; 34 } 35 36 // Folds an OpcompositeExtract where input is a composite constant. 37 ConstantFoldingRule FoldExtractWithConstants() { 38 return [](IRContext* context, Instruction* inst, 39 const std::vector<const analysis::Constant*>& constants) 40 -> const analysis::Constant* { 41 const analysis::Constant* c = constants[kExtractCompositeIdInIdx]; 42 if (c == nullptr) { 43 return nullptr; 44 } 45 46 for (uint32_t i = 1; i < inst->NumInOperands(); ++i) { 47 uint32_t element_index = inst->GetSingleWordInOperand(i); 48 if (c->AsNullConstant()) { 49 // Return Null for the return type. 50 analysis::ConstantManager* const_mgr = context->get_constant_mgr(); 51 analysis::TypeManager* type_mgr = context->get_type_mgr(); 52 return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), {}); 53 } 54 55 auto cc = c->AsCompositeConstant(); 56 assert(cc != nullptr); 57 auto components = cc->GetComponents(); 58 c = components[element_index]; 59 } 60 return c; 61 }; 62 } 63 64 ConstantFoldingRule FoldVectorShuffleWithConstants() { 65 return [](IRContext* context, Instruction* inst, 66 const std::vector<const analysis::Constant*>& constants) 67 -> const analysis::Constant* { 68 assert(inst->opcode() == SpvOpVectorShuffle); 69 const analysis::Constant* c1 = constants[0]; 70 const analysis::Constant* c2 = constants[1]; 71 if (c1 == nullptr || c2 == nullptr) { 72 return nullptr; 73 } 74 75 analysis::ConstantManager* const_mgr = context->get_constant_mgr(); 76 const analysis::Type* element_type = c1->type()->AsVector()->element_type(); 77 78 std::vector<const analysis::Constant*> c1_components; 79 if (const analysis::VectorConstant* vec_const = c1->AsVectorConstant()) { 80 c1_components = vec_const->GetComponents(); 81 } else { 82 assert(c1->AsNullConstant()); 83 const analysis::Constant* element = 84 const_mgr->GetConstant(element_type, {}); 85 c1_components.resize(c1->type()->AsVector()->element_count(), element); 86 } 87 std::vector<const analysis::Constant*> c2_components; 88 if (const analysis::VectorConstant* vec_const = c2->AsVectorConstant()) { 89 c2_components = vec_const->GetComponents(); 90 } else { 91 assert(c2->AsNullConstant()); 92 const analysis::Constant* element = 93 const_mgr->GetConstant(element_type, {}); 94 c2_components.resize(c2->type()->AsVector()->element_count(), element); 95 } 96 97 std::vector<uint32_t> ids; 98 const uint32_t undef_literal_value = 0xffffffff; 99 for (uint32_t i = 2; i < inst->NumInOperands(); ++i) { 100 uint32_t index = inst->GetSingleWordInOperand(i); 101 if (index == undef_literal_value) { 102 // Don't fold shuffle with undef literal value. 103 return nullptr; 104 } else if (index < c1_components.size()) { 105 Instruction* member_inst = 106 const_mgr->GetDefiningInstruction(c1_components[index]); 107 ids.push_back(member_inst->result_id()); 108 } else { 109 Instruction* member_inst = const_mgr->GetDefiningInstruction( 110 c2_components[index - c1_components.size()]); 111 ids.push_back(member_inst->result_id()); 112 } 113 } 114 115 analysis::TypeManager* type_mgr = context->get_type_mgr(); 116 return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids); 117 }; 118 } 119 120 ConstantFoldingRule FoldVectorTimesScalar() { 121 return [](IRContext* context, Instruction* inst, 122 const std::vector<const analysis::Constant*>& constants) 123 -> const analysis::Constant* { 124 assert(inst->opcode() == SpvOpVectorTimesScalar); 125 analysis::ConstantManager* const_mgr = context->get_constant_mgr(); 126 analysis::TypeManager* type_mgr = context->get_type_mgr(); 127 128 if (!inst->IsFloatingPointFoldingAllowed()) { 129 if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) { 130 return nullptr; 131 } 132 } 133 134 const analysis::Constant* c1 = constants[0]; 135 const analysis::Constant* c2 = constants[1]; 136 137 if (c1 && c1->IsZero()) { 138 return c1; 139 } 140 141 if (c2 && c2->IsZero()) { 142 // Get or create the NullConstant for this type. 143 std::vector<uint32_t> ids; 144 return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids); 145 } 146 147 if (c1 == nullptr || c2 == nullptr) { 148 return nullptr; 149 } 150 151 // Check result type. 152 const analysis::Type* result_type = type_mgr->GetType(inst->type_id()); 153 const analysis::Vector* vector_type = result_type->AsVector(); 154 assert(vector_type != nullptr); 155 const analysis::Type* element_type = vector_type->element_type(); 156 assert(element_type != nullptr); 157 const analysis::Float* float_type = element_type->AsFloat(); 158 assert(float_type != nullptr); 159 160 // Check types of c1 and c2. 161 assert(c1->type()->AsVector() == vector_type); 162 assert(c1->type()->AsVector()->element_type() == element_type && 163 c2->type() == element_type); 164 165 // Get a float vector that is the result of vector-times-scalar. 166 std::vector<const analysis::Constant*> c1_components = 167 c1->GetVectorComponents(const_mgr); 168 std::vector<uint32_t> ids; 169 if (float_type->width() == 32) { 170 float scalar = c2->GetFloat(); 171 for (uint32_t i = 0; i < c1_components.size(); ++i) { 172 utils::FloatProxy<float> result(c1_components[i]->GetFloat() * scalar); 173 std::vector<uint32_t> words = result.GetWords(); 174 const analysis::Constant* new_elem = 175 const_mgr->GetConstant(float_type, words); 176 ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id()); 177 } 178 return const_mgr->GetConstant(vector_type, ids); 179 } else if (float_type->width() == 64) { 180 double scalar = c2->GetDouble(); 181 for (uint32_t i = 0; i < c1_components.size(); ++i) { 182 utils::FloatProxy<double> result(c1_components[i]->GetDouble() * 183 scalar); 184 std::vector<uint32_t> words = result.GetWords(); 185 const analysis::Constant* new_elem = 186 const_mgr->GetConstant(float_type, words); 187 ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id()); 188 } 189 return const_mgr->GetConstant(vector_type, ids); 190 } 191 return nullptr; 192 }; 193 } 194 195 ConstantFoldingRule FoldCompositeWithConstants() { 196 // Folds an OpCompositeConstruct where all of the inputs are constants to a 197 // constant. A new constant is created if necessary. 198 return [](IRContext* context, Instruction* inst, 199 const std::vector<const analysis::Constant*>& constants) 200 -> const analysis::Constant* { 201 analysis::ConstantManager* const_mgr = context->get_constant_mgr(); 202 analysis::TypeManager* type_mgr = context->get_type_mgr(); 203 const analysis::Type* new_type = type_mgr->GetType(inst->type_id()); 204 Instruction* type_inst = 205 context->get_def_use_mgr()->GetDef(inst->type_id()); 206 207 std::vector<uint32_t> ids; 208 for (uint32_t i = 0; i < constants.size(); ++i) { 209 const analysis::Constant* element_const = constants[i]; 210 if (element_const == nullptr) { 211 return nullptr; 212 } 213 214 uint32_t component_type_id = 0; 215 if (type_inst->opcode() == SpvOpTypeStruct) { 216 component_type_id = type_inst->GetSingleWordInOperand(i); 217 } else if (type_inst->opcode() == SpvOpTypeArray) { 218 component_type_id = type_inst->GetSingleWordInOperand(0); 219 } 220 221 uint32_t element_id = 222 const_mgr->FindDeclaredConstant(element_const, component_type_id); 223 if (element_id == 0) { 224 return nullptr; 225 } 226 ids.push_back(element_id); 227 } 228 return const_mgr->GetConstant(new_type, ids); 229 }; 230 } 231 232 // The interface for a function that returns the result of applying a scalar 233 // floating-point binary operation on |a| and |b|. The type of the return value 234 // will be |type|. The input constants must also be of type |type|. 235 using UnaryScalarFoldingRule = std::function<const analysis::Constant*( 236 const analysis::Type* result_type, const analysis::Constant* a, 237 analysis::ConstantManager*)>; 238 239 // The interface for a function that returns the result of applying a scalar 240 // floating-point binary operation on |a| and |b|. The type of the return value 241 // will be |type|. The input constants must also be of type |type|. 242 using BinaryScalarFoldingRule = std::function<const analysis::Constant*( 243 const analysis::Type* result_type, const analysis::Constant* a, 244 const analysis::Constant* b, analysis::ConstantManager*)>; 245 246 // Returns a |ConstantFoldingRule| that folds unary floating point scalar ops 247 // using |scalar_rule| and unary float point vectors ops by applying 248 // |scalar_rule| to the elements of the vector. The |ConstantFoldingRule| 249 // that is returned assumes that |constants| contains 1 entry. If they are 250 // not |nullptr|, then their type is either |Float| or |Integer| or a |Vector| 251 // whose element type is |Float| or |Integer|. 252 ConstantFoldingRule FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule) { 253 return [scalar_rule](IRContext* context, Instruction* inst, 254 const std::vector<const analysis::Constant*>& constants) 255 -> const analysis::Constant* { 256 analysis::ConstantManager* const_mgr = context->get_constant_mgr(); 257 analysis::TypeManager* type_mgr = context->get_type_mgr(); 258 const analysis::Type* result_type = type_mgr->GetType(inst->type_id()); 259 const analysis::Vector* vector_type = result_type->AsVector(); 260 261 if (!inst->IsFloatingPointFoldingAllowed()) { 262 return nullptr; 263 } 264 265 if (constants[0] == nullptr) { 266 return nullptr; 267 } 268 269 if (vector_type != nullptr) { 270 std::vector<const analysis::Constant*> a_components; 271 std::vector<const analysis::Constant*> results_components; 272 273 a_components = constants[0]->GetVectorComponents(const_mgr); 274 275 // Fold each component of the vector. 276 for (uint32_t i = 0; i < a_components.size(); ++i) { 277 results_components.push_back(scalar_rule(vector_type->element_type(), 278 a_components[i], const_mgr)); 279 if (results_components[i] == nullptr) { 280 return nullptr; 281 } 282 } 283 284 // Build the constant object and return it. 285 std::vector<uint32_t> ids; 286 for (const analysis::Constant* member : results_components) { 287 ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id()); 288 } 289 return const_mgr->GetConstant(vector_type, ids); 290 } else { 291 return scalar_rule(result_type, constants[0], const_mgr); 292 } 293 }; 294 } 295 296 // Returns a |ConstantFoldingRule| that folds floating point scalars using 297 // |scalar_rule| and vectors of floating point by applying |scalar_rule| to the 298 // elements of the vector. The |ConstantFoldingRule| that is returned assumes 299 // that |constants| contains 2 entries. If they are not |nullptr|, then their 300 // type is either |Float| or a |Vector| whose element type is |Float|. 301 ConstantFoldingRule FoldFPBinaryOp(BinaryScalarFoldingRule scalar_rule) { 302 return [scalar_rule](IRContext* context, Instruction* inst, 303 const std::vector<const analysis::Constant*>& constants) 304 -> const analysis::Constant* { 305 analysis::ConstantManager* const_mgr = context->get_constant_mgr(); 306 analysis::TypeManager* type_mgr = context->get_type_mgr(); 307 const analysis::Type* result_type = type_mgr->GetType(inst->type_id()); 308 const analysis::Vector* vector_type = result_type->AsVector(); 309 310 if (!inst->IsFloatingPointFoldingAllowed()) { 311 return nullptr; 312 } 313 314 if (constants[0] == nullptr || constants[1] == nullptr) { 315 return nullptr; 316 } 317 318 if (vector_type != nullptr) { 319 std::vector<const analysis::Constant*> a_components; 320 std::vector<const analysis::Constant*> b_components; 321 std::vector<const analysis::Constant*> results_components; 322 323 a_components = constants[0]->GetVectorComponents(const_mgr); 324 b_components = constants[1]->GetVectorComponents(const_mgr); 325 326 // Fold each component of the vector. 327 for (uint32_t i = 0; i < a_components.size(); ++i) { 328 results_components.push_back(scalar_rule(vector_type->element_type(), 329 a_components[i], 330 b_components[i], const_mgr)); 331 if (results_components[i] == nullptr) { 332 return nullptr; 333 } 334 } 335 336 // Build the constant object and return it. 337 std::vector<uint32_t> ids; 338 for (const analysis::Constant* member : results_components) { 339 ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id()); 340 } 341 return const_mgr->GetConstant(vector_type, ids); 342 } else { 343 return scalar_rule(result_type, constants[0], constants[1], const_mgr); 344 } 345 }; 346 } 347 348 // This macro defines a |UnaryScalarFoldingRule| that performs float to 349 // integer conversion. 350 // TODO(greg-lunarg): Support for 64-bit integer types. 351 UnaryScalarFoldingRule FoldFToIOp() { 352 return [](const analysis::Type* result_type, const analysis::Constant* a, 353 analysis::ConstantManager* const_mgr) -> const analysis::Constant* { 354 assert(result_type != nullptr && a != nullptr); 355 const analysis::Integer* integer_type = result_type->AsInteger(); 356 const analysis::Float* float_type = a->type()->AsFloat(); 357 assert(float_type != nullptr); 358 assert(integer_type != nullptr); 359 if (integer_type->width() != 32) return nullptr; 360 if (float_type->width() == 32) { 361 float fa = a->GetFloat(); 362 uint32_t result = integer_type->IsSigned() 363 ? static_cast<uint32_t>(static_cast<int32_t>(fa)) 364 : static_cast<uint32_t>(fa); 365 std::vector<uint32_t> words = {result}; 366 return const_mgr->GetConstant(result_type, words); 367 } else if (float_type->width() == 64) { 368 double fa = a->GetDouble(); 369 uint32_t result = integer_type->IsSigned() 370 ? static_cast<uint32_t>(static_cast<int32_t>(fa)) 371 : static_cast<uint32_t>(fa); 372 std::vector<uint32_t> words = {result}; 373 return const_mgr->GetConstant(result_type, words); 374 } 375 return nullptr; 376 }; 377 } 378 379 // This function defines a |UnaryScalarFoldingRule| that performs integer to 380 // float conversion. 381 // TODO(greg-lunarg): Support for 64-bit integer types. 382 UnaryScalarFoldingRule FoldIToFOp() { 383 return [](const analysis::Type* result_type, const analysis::Constant* a, 384 analysis::ConstantManager* const_mgr) -> const analysis::Constant* { 385 assert(result_type != nullptr && a != nullptr); 386 const analysis::Integer* integer_type = a->type()->AsInteger(); 387 const analysis::Float* float_type = result_type->AsFloat(); 388 assert(float_type != nullptr); 389 assert(integer_type != nullptr); 390 if (integer_type->width() != 32) return nullptr; 391 uint32_t ua = a->GetU32(); 392 if (float_type->width() == 32) { 393 float result_val = integer_type->IsSigned() 394 ? static_cast<float>(static_cast<int32_t>(ua)) 395 : static_cast<float>(ua); 396 utils::FloatProxy<float> result(result_val); 397 std::vector<uint32_t> words = {result.data()}; 398 return const_mgr->GetConstant(result_type, words); 399 } else if (float_type->width() == 64) { 400 double result_val = integer_type->IsSigned() 401 ? static_cast<double>(static_cast<int32_t>(ua)) 402 : static_cast<double>(ua); 403 utils::FloatProxy<double> result(result_val); 404 std::vector<uint32_t> words = result.GetWords(); 405 return const_mgr->GetConstant(result_type, words); 406 } 407 return nullptr; 408 }; 409 } 410 411 // This macro defines a |BinaryScalarFoldingRule| that applies |op|. The 412 // operator |op| must work for both float and double, and use syntax "f1 op f2". 413 #define FOLD_FPARITH_OP(op) \ 414 [](const analysis::Type* result_type, const analysis::Constant* a, \ 415 const analysis::Constant* b, \ 416 analysis::ConstantManager* const_mgr_in_macro) \ 417 -> const analysis::Constant* { \ 418 assert(result_type != nullptr && a != nullptr && b != nullptr); \ 419 assert(result_type == a->type() && result_type == b->type()); \ 420 const analysis::Float* float_type_in_macro = result_type->AsFloat(); \ 421 assert(float_type_in_macro != nullptr); \ 422 if (float_type_in_macro->width() == 32) { \ 423 float fa = a->GetFloat(); \ 424 float fb = b->GetFloat(); \ 425 utils::FloatProxy<float> result_in_macro(fa op fb); \ 426 std::vector<uint32_t> words_in_macro = result_in_macro.GetWords(); \ 427 return const_mgr_in_macro->GetConstant(result_type, words_in_macro); \ 428 } else if (float_type_in_macro->width() == 64) { \ 429 double fa = a->GetDouble(); \ 430 double fb = b->GetDouble(); \ 431 utils::FloatProxy<double> result_in_macro(fa op fb); \ 432 std::vector<uint32_t> words_in_macro = result_in_macro.GetWords(); \ 433 return const_mgr_in_macro->GetConstant(result_type, words_in_macro); \ 434 } \ 435 return nullptr; \ 436 } 437 438 // Define the folding rule for conversion between floating point and integer 439 ConstantFoldingRule FoldFToI() { return FoldFPUnaryOp(FoldFToIOp()); } 440 ConstantFoldingRule FoldIToF() { return FoldFPUnaryOp(FoldIToFOp()); } 441 442 // Define the folding rules for subtraction, addition, multiplication, and 443 // division for floating point values. 444 ConstantFoldingRule FoldFSub() { return FoldFPBinaryOp(FOLD_FPARITH_OP(-)); } 445 ConstantFoldingRule FoldFAdd() { return FoldFPBinaryOp(FOLD_FPARITH_OP(+)); } 446 ConstantFoldingRule FoldFMul() { return FoldFPBinaryOp(FOLD_FPARITH_OP(*)); } 447 ConstantFoldingRule FoldFDiv() { return FoldFPBinaryOp(FOLD_FPARITH_OP(/)); } 448 449 bool CompareFloatingPoint(bool op_result, bool op_unordered, 450 bool need_ordered) { 451 if (need_ordered) { 452 // operands are ordered and Operand 1 is |op| Operand 2 453 return !op_unordered && op_result; 454 } else { 455 // operands are unordered or Operand 1 is |op| Operand 2 456 return op_unordered || op_result; 457 } 458 } 459 460 // This macro defines a |BinaryScalarFoldingRule| that applies |op|. The 461 // operator |op| must work for both float and double, and use syntax "f1 op f2". 462 #define FOLD_FPCMP_OP(op, ord) \ 463 [](const analysis::Type* result_type, const analysis::Constant* a, \ 464 const analysis::Constant* b, \ 465 analysis::ConstantManager* const_mgr) -> const analysis::Constant* { \ 466 assert(result_type != nullptr && a != nullptr && b != nullptr); \ 467 assert(result_type->AsBool()); \ 468 assert(a->type() == b->type()); \ 469 const analysis::Float* float_type = a->type()->AsFloat(); \ 470 assert(float_type != nullptr); \ 471 if (float_type->width() == 32) { \ 472 float fa = a->GetFloat(); \ 473 float fb = b->GetFloat(); \ 474 bool result = CompareFloatingPoint( \ 475 fa op fb, std::isnan(fa) || std::isnan(fb), ord); \ 476 std::vector<uint32_t> words = {uint32_t(result)}; \ 477 return const_mgr->GetConstant(result_type, words); \ 478 } else if (float_type->width() == 64) { \ 479 double fa = a->GetDouble(); \ 480 double fb = b->GetDouble(); \ 481 bool result = CompareFloatingPoint( \ 482 fa op fb, std::isnan(fa) || std::isnan(fb), ord); \ 483 std::vector<uint32_t> words = {uint32_t(result)}; \ 484 return const_mgr->GetConstant(result_type, words); \ 485 } \ 486 return nullptr; \ 487 } 488 489 // Define the folding rules for ordered and unordered comparison for floating 490 // point values. 491 ConstantFoldingRule FoldFOrdEqual() { 492 return FoldFPBinaryOp(FOLD_FPCMP_OP(==, true)); 493 } 494 ConstantFoldingRule FoldFUnordEqual() { 495 return FoldFPBinaryOp(FOLD_FPCMP_OP(==, false)); 496 } 497 ConstantFoldingRule FoldFOrdNotEqual() { 498 return FoldFPBinaryOp(FOLD_FPCMP_OP(!=, true)); 499 } 500 ConstantFoldingRule FoldFUnordNotEqual() { 501 return FoldFPBinaryOp(FOLD_FPCMP_OP(!=, false)); 502 } 503 ConstantFoldingRule FoldFOrdLessThan() { 504 return FoldFPBinaryOp(FOLD_FPCMP_OP(<, true)); 505 } 506 ConstantFoldingRule FoldFUnordLessThan() { 507 return FoldFPBinaryOp(FOLD_FPCMP_OP(<, false)); 508 } 509 ConstantFoldingRule FoldFOrdGreaterThan() { 510 return FoldFPBinaryOp(FOLD_FPCMP_OP(>, true)); 511 } 512 ConstantFoldingRule FoldFUnordGreaterThan() { 513 return FoldFPBinaryOp(FOLD_FPCMP_OP(>, false)); 514 } 515 ConstantFoldingRule FoldFOrdLessThanEqual() { 516 return FoldFPBinaryOp(FOLD_FPCMP_OP(<=, true)); 517 } 518 ConstantFoldingRule FoldFUnordLessThanEqual() { 519 return FoldFPBinaryOp(FOLD_FPCMP_OP(<=, false)); 520 } 521 ConstantFoldingRule FoldFOrdGreaterThanEqual() { 522 return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, true)); 523 } 524 ConstantFoldingRule FoldFUnordGreaterThanEqual() { 525 return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, false)); 526 } 527 528 // Folds an OpDot where all of the inputs are constants to a 529 // constant. A new constant is created if necessary. 530 ConstantFoldingRule FoldOpDotWithConstants() { 531 return [](IRContext* context, Instruction* inst, 532 const std::vector<const analysis::Constant*>& constants) 533 -> const analysis::Constant* { 534 analysis::ConstantManager* const_mgr = context->get_constant_mgr(); 535 analysis::TypeManager* type_mgr = context->get_type_mgr(); 536 const analysis::Type* new_type = type_mgr->GetType(inst->type_id()); 537 assert(new_type->AsFloat() && "OpDot should have a float return type."); 538 const analysis::Float* float_type = new_type->AsFloat(); 539 540 if (!inst->IsFloatingPointFoldingAllowed()) { 541 return nullptr; 542 } 543 544 // If one of the operands is 0, then the result is 0. 545 bool has_zero_operand = false; 546 547 for (int i = 0; i < 2; ++i) { 548 if (constants[i]) { 549 if (constants[i]->AsNullConstant() || 550 constants[i]->AsVectorConstant()->IsZero()) { 551 has_zero_operand = true; 552 break; 553 } 554 } 555 } 556 557 if (has_zero_operand) { 558 if (float_type->width() == 32) { 559 utils::FloatProxy<float> result(0.0f); 560 std::vector<uint32_t> words = result.GetWords(); 561 return const_mgr->GetConstant(float_type, words); 562 } 563 if (float_type->width() == 64) { 564 utils::FloatProxy<double> result(0.0); 565 std::vector<uint32_t> words = result.GetWords(); 566 return const_mgr->GetConstant(float_type, words); 567 } 568 return nullptr; 569 } 570 571 if (constants[0] == nullptr || constants[1] == nullptr) { 572 return nullptr; 573 } 574 575 std::vector<const analysis::Constant*> a_components; 576 std::vector<const analysis::Constant*> b_components; 577 578 a_components = constants[0]->GetVectorComponents(const_mgr); 579 b_components = constants[1]->GetVectorComponents(const_mgr); 580 581 utils::FloatProxy<double> result(0.0); 582 std::vector<uint32_t> words = result.GetWords(); 583 const analysis::Constant* result_const = 584 const_mgr->GetConstant(float_type, words); 585 for (uint32_t i = 0; i < a_components.size(); ++i) { 586 if (a_components[i] == nullptr || b_components[i] == nullptr) { 587 return nullptr; 588 } 589 590 const analysis::Constant* component = FOLD_FPARITH_OP(*)( 591 new_type, a_components[i], b_components[i], const_mgr); 592 result_const = 593 FOLD_FPARITH_OP(+)(new_type, result_const, component, const_mgr); 594 } 595 return result_const; 596 }; 597 } 598 599 // This function defines a |UnaryScalarFoldingRule| that subtracts the constant 600 // from zero. 601 UnaryScalarFoldingRule FoldFNegateOp() { 602 return [](const analysis::Type* result_type, const analysis::Constant* a, 603 analysis::ConstantManager* const_mgr) -> const analysis::Constant* { 604 assert(result_type != nullptr && a != nullptr); 605 assert(result_type == a->type()); 606 const analysis::Float* float_type = result_type->AsFloat(); 607 assert(float_type != nullptr); 608 if (float_type->width() == 32) { 609 float fa = a->GetFloat(); 610 utils::FloatProxy<float> result(-fa); 611 std::vector<uint32_t> words = result.GetWords(); 612 return const_mgr->GetConstant(result_type, words); 613 } else if (float_type->width() == 64) { 614 double da = a->GetDouble(); 615 utils::FloatProxy<double> result(-da); 616 std::vector<uint32_t> words = result.GetWords(); 617 return const_mgr->GetConstant(result_type, words); 618 } 619 return nullptr; 620 }; 621 } 622 623 ConstantFoldingRule FoldFNegate() { return FoldFPUnaryOp(FoldFNegateOp()); } 624 625 ConstantFoldingRule FoldFClampFeedingCompare(uint32_t cmp_opcode) { 626 return [cmp_opcode](IRContext* context, Instruction* inst, 627 const std::vector<const analysis::Constant*>& constants) 628 -> const analysis::Constant* { 629 analysis::ConstantManager* const_mgr = context->get_constant_mgr(); 630 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); 631 632 if (!inst->IsFloatingPointFoldingAllowed()) { 633 return nullptr; 634 } 635 636 uint32_t non_const_idx = (constants[0] ? 1 : 0); 637 uint32_t operand_id = inst->GetSingleWordInOperand(non_const_idx); 638 Instruction* operand_inst = def_use_mgr->GetDef(operand_id); 639 640 analysis::TypeManager* type_mgr = context->get_type_mgr(); 641 const analysis::Type* operand_type = 642 type_mgr->GetType(operand_inst->type_id()); 643 644 if (!operand_type->AsFloat()) { 645 return nullptr; 646 } 647 648 if (operand_type->AsFloat()->width() != 32 && 649 operand_type->AsFloat()->width() != 64) { 650 return nullptr; 651 } 652 653 if (operand_inst->opcode() != SpvOpExtInst) { 654 return nullptr; 655 } 656 657 if (operand_inst->GetSingleWordInOperand(1) != GLSLstd450FClamp) { 658 return nullptr; 659 } 660 661 if (constants[1] == nullptr && constants[0] == nullptr) { 662 return nullptr; 663 } 664 665 uint32_t max_id = operand_inst->GetSingleWordInOperand(4); 666 const analysis::Constant* max_const = 667 const_mgr->FindDeclaredConstant(max_id); 668 669 uint32_t min_id = operand_inst->GetSingleWordInOperand(3); 670 const analysis::Constant* min_const = 671 const_mgr->FindDeclaredConstant(min_id); 672 673 bool found_result = false; 674 bool result = false; 675 676 switch (cmp_opcode) { 677 case SpvOpFOrdLessThan: 678 case SpvOpFUnordLessThan: 679 case SpvOpFOrdGreaterThanEqual: 680 case SpvOpFUnordGreaterThanEqual: 681 if (constants[0]) { 682 if (min_const) { 683 if (constants[0]->GetValueAsDouble() < 684 min_const->GetValueAsDouble()) { 685 found_result = true; 686 result = (cmp_opcode == SpvOpFOrdLessThan || 687 cmp_opcode == SpvOpFUnordLessThan); 688 } 689 } 690 if (max_const) { 691 if (constants[0]->GetValueAsDouble() >= 692 max_const->GetValueAsDouble()) { 693 found_result = true; 694 result = !(cmp_opcode == SpvOpFOrdLessThan || 695 cmp_opcode == SpvOpFUnordLessThan); 696 } 697 } 698 } 699 700 if (constants[1]) { 701 if (max_const) { 702 if (max_const->GetValueAsDouble() < 703 constants[1]->GetValueAsDouble()) { 704 found_result = true; 705 result = (cmp_opcode == SpvOpFOrdLessThan || 706 cmp_opcode == SpvOpFUnordLessThan); 707 } 708 } 709 710 if (min_const) { 711 if (min_const->GetValueAsDouble() >= 712 constants[1]->GetValueAsDouble()) { 713 found_result = true; 714 result = !(cmp_opcode == SpvOpFOrdLessThan || 715 cmp_opcode == SpvOpFUnordLessThan); 716 } 717 } 718 } 719 break; 720 case SpvOpFOrdGreaterThan: 721 case SpvOpFUnordGreaterThan: 722 case SpvOpFOrdLessThanEqual: 723 case SpvOpFUnordLessThanEqual: 724 if (constants[0]) { 725 if (min_const) { 726 if (constants[0]->GetValueAsDouble() <= 727 min_const->GetValueAsDouble()) { 728 found_result = true; 729 result = (cmp_opcode == SpvOpFOrdLessThanEqual || 730 cmp_opcode == SpvOpFUnordLessThanEqual); 731 } 732 } 733 if (max_const) { 734 if (constants[0]->GetValueAsDouble() > 735 max_const->GetValueAsDouble()) { 736 found_result = true; 737 result = !(cmp_opcode == SpvOpFOrdLessThanEqual || 738 cmp_opcode == SpvOpFUnordLessThanEqual); 739 } 740 } 741 } 742 743 if (constants[1]) { 744 if (max_const) { 745 if (max_const->GetValueAsDouble() <= 746 constants[1]->GetValueAsDouble()) { 747 found_result = true; 748 result = (cmp_opcode == SpvOpFOrdLessThanEqual || 749 cmp_opcode == SpvOpFUnordLessThanEqual); 750 } 751 } 752 753 if (min_const) { 754 if (min_const->GetValueAsDouble() > 755 constants[1]->GetValueAsDouble()) { 756 found_result = true; 757 result = !(cmp_opcode == SpvOpFOrdLessThanEqual || 758 cmp_opcode == SpvOpFUnordLessThanEqual); 759 } 760 } 761 } 762 break; 763 default: 764 return nullptr; 765 } 766 767 if (!found_result) { 768 return nullptr; 769 } 770 771 const analysis::Type* bool_type = 772 context->get_type_mgr()->GetType(inst->type_id()); 773 const analysis::Constant* result_const = 774 const_mgr->GetConstant(bool_type, {static_cast<uint32_t>(result)}); 775 assert(result_const); 776 return result_const; 777 }; 778 } 779 780 } // namespace 781 782 ConstantFoldingRules::ConstantFoldingRules() { 783 // Add all folding rules to the list for the opcodes to which they apply. 784 // Note that the order in which rules are added to the list matters. If a rule 785 // applies to the instruction, the rest of the rules will not be attempted. 786 // Take that into consideration. 787 788 rules_[SpvOpCompositeConstruct].push_back(FoldCompositeWithConstants()); 789 790 rules_[SpvOpCompositeExtract].push_back(FoldExtractWithConstants()); 791 792 rules_[SpvOpConvertFToS].push_back(FoldFToI()); 793 rules_[SpvOpConvertFToU].push_back(FoldFToI()); 794 rules_[SpvOpConvertSToF].push_back(FoldIToF()); 795 rules_[SpvOpConvertUToF].push_back(FoldIToF()); 796 797 rules_[SpvOpDot].push_back(FoldOpDotWithConstants()); 798 rules_[SpvOpFAdd].push_back(FoldFAdd()); 799 rules_[SpvOpFDiv].push_back(FoldFDiv()); 800 rules_[SpvOpFMul].push_back(FoldFMul()); 801 rules_[SpvOpFSub].push_back(FoldFSub()); 802 803 rules_[SpvOpFOrdEqual].push_back(FoldFOrdEqual()); 804 805 rules_[SpvOpFUnordEqual].push_back(FoldFUnordEqual()); 806 807 rules_[SpvOpFOrdNotEqual].push_back(FoldFOrdNotEqual()); 808 809 rules_[SpvOpFUnordNotEqual].push_back(FoldFUnordNotEqual()); 810 811 rules_[SpvOpFOrdLessThan].push_back(FoldFOrdLessThan()); 812 rules_[SpvOpFOrdLessThan].push_back( 813 FoldFClampFeedingCompare(SpvOpFOrdLessThan)); 814 815 rules_[SpvOpFUnordLessThan].push_back(FoldFUnordLessThan()); 816 rules_[SpvOpFUnordLessThan].push_back( 817 FoldFClampFeedingCompare(SpvOpFUnordLessThan)); 818 819 rules_[SpvOpFOrdGreaterThan].push_back(FoldFOrdGreaterThan()); 820 rules_[SpvOpFOrdGreaterThan].push_back( 821 FoldFClampFeedingCompare(SpvOpFOrdGreaterThan)); 822 823 rules_[SpvOpFUnordGreaterThan].push_back(FoldFUnordGreaterThan()); 824 rules_[SpvOpFUnordGreaterThan].push_back( 825 FoldFClampFeedingCompare(SpvOpFUnordGreaterThan)); 826 827 rules_[SpvOpFOrdLessThanEqual].push_back(FoldFOrdLessThanEqual()); 828 rules_[SpvOpFOrdLessThanEqual].push_back( 829 FoldFClampFeedingCompare(SpvOpFOrdLessThanEqual)); 830 831 rules_[SpvOpFUnordLessThanEqual].push_back(FoldFUnordLessThanEqual()); 832 rules_[SpvOpFUnordLessThanEqual].push_back( 833 FoldFClampFeedingCompare(SpvOpFUnordLessThanEqual)); 834 835 rules_[SpvOpFOrdGreaterThanEqual].push_back(FoldFOrdGreaterThanEqual()); 836 rules_[SpvOpFOrdGreaterThanEqual].push_back( 837 FoldFClampFeedingCompare(SpvOpFOrdGreaterThanEqual)); 838 839 rules_[SpvOpFUnordGreaterThanEqual].push_back(FoldFUnordGreaterThanEqual()); 840 rules_[SpvOpFUnordGreaterThanEqual].push_back( 841 FoldFClampFeedingCompare(SpvOpFUnordGreaterThanEqual)); 842 843 rules_[SpvOpVectorShuffle].push_back(FoldVectorShuffleWithConstants()); 844 rules_[SpvOpVectorTimesScalar].push_back(FoldVectorTimesScalar()); 845 846 rules_[SpvOpFNegate].push_back(FoldFNegate()); 847 } 848 } // namespace opt 849 } // namespace spvtools 850