Home | History | Annotate | Download | only in opt
      1 // Copyright (c) 2018 Google LLC
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
     15 #include "source/opt/const_folding_rules.h"
     16 
     17 #include "source/opt/ir_context.h"
     18 
     19 namespace spvtools {
     20 namespace opt {
     21 namespace {
     22 
     23 const uint32_t kExtractCompositeIdInIdx = 0;
     24 
     25 // Returns true if |type| is Float or a vector of Float.
     26 bool HasFloatingPoint(const analysis::Type* type) {
     27   if (type->AsFloat()) {
     28     return true;
     29   } else if (const analysis::Vector* vec_type = type->AsVector()) {
     30     return vec_type->element_type()->AsFloat() != nullptr;
     31   }
     32 
     33   return false;
     34 }
     35 
     36 // Folds an OpcompositeExtract where input is a composite constant.
     37 ConstantFoldingRule FoldExtractWithConstants() {
     38   return [](IRContext* context, Instruction* inst,
     39             const std::vector<const analysis::Constant*>& constants)
     40              -> const analysis::Constant* {
     41     const analysis::Constant* c = constants[kExtractCompositeIdInIdx];
     42     if (c == nullptr) {
     43       return nullptr;
     44     }
     45 
     46     for (uint32_t i = 1; i < inst->NumInOperands(); ++i) {
     47       uint32_t element_index = inst->GetSingleWordInOperand(i);
     48       if (c->AsNullConstant()) {
     49         // Return Null for the return type.
     50         analysis::ConstantManager* const_mgr = context->get_constant_mgr();
     51         analysis::TypeManager* type_mgr = context->get_type_mgr();
     52         return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), {});
     53       }
     54 
     55       auto cc = c->AsCompositeConstant();
     56       assert(cc != nullptr);
     57       auto components = cc->GetComponents();
     58       c = components[element_index];
     59     }
     60     return c;
     61   };
     62 }
     63 
     64 ConstantFoldingRule FoldVectorShuffleWithConstants() {
     65   return [](IRContext* context, Instruction* inst,
     66             const std::vector<const analysis::Constant*>& constants)
     67              -> const analysis::Constant* {
     68     assert(inst->opcode() == SpvOpVectorShuffle);
     69     const analysis::Constant* c1 = constants[0];
     70     const analysis::Constant* c2 = constants[1];
     71     if (c1 == nullptr || c2 == nullptr) {
     72       return nullptr;
     73     }
     74 
     75     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
     76     const analysis::Type* element_type = c1->type()->AsVector()->element_type();
     77 
     78     std::vector<const analysis::Constant*> c1_components;
     79     if (const analysis::VectorConstant* vec_const = c1->AsVectorConstant()) {
     80       c1_components = vec_const->GetComponents();
     81     } else {
     82       assert(c1->AsNullConstant());
     83       const analysis::Constant* element =
     84           const_mgr->GetConstant(element_type, {});
     85       c1_components.resize(c1->type()->AsVector()->element_count(), element);
     86     }
     87     std::vector<const analysis::Constant*> c2_components;
     88     if (const analysis::VectorConstant* vec_const = c2->AsVectorConstant()) {
     89       c2_components = vec_const->GetComponents();
     90     } else {
     91       assert(c2->AsNullConstant());
     92       const analysis::Constant* element =
     93           const_mgr->GetConstant(element_type, {});
     94       c2_components.resize(c2->type()->AsVector()->element_count(), element);
     95     }
     96 
     97     std::vector<uint32_t> ids;
     98     const uint32_t undef_literal_value = 0xffffffff;
     99     for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
    100       uint32_t index = inst->GetSingleWordInOperand(i);
    101       if (index == undef_literal_value) {
    102         // Don't fold shuffle with undef literal value.
    103         return nullptr;
    104       } else if (index < c1_components.size()) {
    105         Instruction* member_inst =
    106             const_mgr->GetDefiningInstruction(c1_components[index]);
    107         ids.push_back(member_inst->result_id());
    108       } else {
    109         Instruction* member_inst = const_mgr->GetDefiningInstruction(
    110             c2_components[index - c1_components.size()]);
    111         ids.push_back(member_inst->result_id());
    112       }
    113     }
    114 
    115     analysis::TypeManager* type_mgr = context->get_type_mgr();
    116     return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids);
    117   };
    118 }
    119 
    120 ConstantFoldingRule FoldVectorTimesScalar() {
    121   return [](IRContext* context, Instruction* inst,
    122             const std::vector<const analysis::Constant*>& constants)
    123              -> const analysis::Constant* {
    124     assert(inst->opcode() == SpvOpVectorTimesScalar);
    125     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
    126     analysis::TypeManager* type_mgr = context->get_type_mgr();
    127 
    128     if (!inst->IsFloatingPointFoldingAllowed()) {
    129       if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) {
    130         return nullptr;
    131       }
    132     }
    133 
    134     const analysis::Constant* c1 = constants[0];
    135     const analysis::Constant* c2 = constants[1];
    136 
    137     if (c1 && c1->IsZero()) {
    138       return c1;
    139     }
    140 
    141     if (c2 && c2->IsZero()) {
    142       // Get or create the NullConstant for this type.
    143       std::vector<uint32_t> ids;
    144       return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids);
    145     }
    146 
    147     if (c1 == nullptr || c2 == nullptr) {
    148       return nullptr;
    149     }
    150 
    151     // Check result type.
    152     const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
    153     const analysis::Vector* vector_type = result_type->AsVector();
    154     assert(vector_type != nullptr);
    155     const analysis::Type* element_type = vector_type->element_type();
    156     assert(element_type != nullptr);
    157     const analysis::Float* float_type = element_type->AsFloat();
    158     assert(float_type != nullptr);
    159 
    160     // Check types of c1 and c2.
    161     assert(c1->type()->AsVector() == vector_type);
    162     assert(c1->type()->AsVector()->element_type() == element_type &&
    163            c2->type() == element_type);
    164 
    165     // Get a float vector that is the result of vector-times-scalar.
    166     std::vector<const analysis::Constant*> c1_components =
    167         c1->GetVectorComponents(const_mgr);
    168     std::vector<uint32_t> ids;
    169     if (float_type->width() == 32) {
    170       float scalar = c2->GetFloat();
    171       for (uint32_t i = 0; i < c1_components.size(); ++i) {
    172         utils::FloatProxy<float> result(c1_components[i]->GetFloat() * scalar);
    173         std::vector<uint32_t> words = result.GetWords();
    174         const analysis::Constant* new_elem =
    175             const_mgr->GetConstant(float_type, words);
    176         ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
    177       }
    178       return const_mgr->GetConstant(vector_type, ids);
    179     } else if (float_type->width() == 64) {
    180       double scalar = c2->GetDouble();
    181       for (uint32_t i = 0; i < c1_components.size(); ++i) {
    182         utils::FloatProxy<double> result(c1_components[i]->GetDouble() *
    183                                          scalar);
    184         std::vector<uint32_t> words = result.GetWords();
    185         const analysis::Constant* new_elem =
    186             const_mgr->GetConstant(float_type, words);
    187         ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id());
    188       }
    189       return const_mgr->GetConstant(vector_type, ids);
    190     }
    191     return nullptr;
    192   };
    193 }
    194 
    195 ConstantFoldingRule FoldCompositeWithConstants() {
    196   // Folds an OpCompositeConstruct where all of the inputs are constants to a
    197   // constant.  A new constant is created if necessary.
    198   return [](IRContext* context, Instruction* inst,
    199             const std::vector<const analysis::Constant*>& constants)
    200              -> const analysis::Constant* {
    201     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
    202     analysis::TypeManager* type_mgr = context->get_type_mgr();
    203     const analysis::Type* new_type = type_mgr->GetType(inst->type_id());
    204     Instruction* type_inst =
    205         context->get_def_use_mgr()->GetDef(inst->type_id());
    206 
    207     std::vector<uint32_t> ids;
    208     for (uint32_t i = 0; i < constants.size(); ++i) {
    209       const analysis::Constant* element_const = constants[i];
    210       if (element_const == nullptr) {
    211         return nullptr;
    212       }
    213 
    214       uint32_t component_type_id = 0;
    215       if (type_inst->opcode() == SpvOpTypeStruct) {
    216         component_type_id = type_inst->GetSingleWordInOperand(i);
    217       } else if (type_inst->opcode() == SpvOpTypeArray) {
    218         component_type_id = type_inst->GetSingleWordInOperand(0);
    219       }
    220 
    221       uint32_t element_id =
    222           const_mgr->FindDeclaredConstant(element_const, component_type_id);
    223       if (element_id == 0) {
    224         return nullptr;
    225       }
    226       ids.push_back(element_id);
    227     }
    228     return const_mgr->GetConstant(new_type, ids);
    229   };
    230 }
    231 
// The interface for a function that returns the result of applying a scalar
// unary operation to |a|.  The type of the return value will be
// |result_type|.  The |ConstantManager| is used to obtain the constant that is
// returned.  A rule returns |nullptr| when it cannot fold its input.
using UnaryScalarFoldingRule = std::function<const analysis::Constant*(
    const analysis::Type* result_type, const analysis::Constant* a,
    analysis::ConstantManager*)>;
    238 
// The interface for a function that returns the result of applying a scalar
// floating-point binary operation on |a| and |b|.  The type of the return value
// will be |result_type|.  The input constants |a| and |b| must have the same
// type as each other.  The |ConstantManager| is used to obtain the constant
// that is returned.  A rule returns |nullptr| when it cannot fold its inputs.
using BinaryScalarFoldingRule = std::function<const analysis::Constant*(
    const analysis::Type* result_type, const analysis::Constant* a,
    const analysis::Constant* b, analysis::ConstantManager*)>;
    245 
    246 // Returns a |ConstantFoldingRule| that folds unary floating point scalar ops
    247 // using |scalar_rule| and unary float point vectors ops by applying
    248 // |scalar_rule| to the elements of the vector.  The |ConstantFoldingRule|
    249 // that is returned assumes that |constants| contains 1 entry.  If they are
    250 // not |nullptr|, then their type is either |Float| or |Integer| or a |Vector|
    251 // whose element type is |Float| or |Integer|.
    252 ConstantFoldingRule FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule) {
    253   return [scalar_rule](IRContext* context, Instruction* inst,
    254                        const std::vector<const analysis::Constant*>& constants)
    255              -> const analysis::Constant* {
    256     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
    257     analysis::TypeManager* type_mgr = context->get_type_mgr();
    258     const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
    259     const analysis::Vector* vector_type = result_type->AsVector();
    260 
    261     if (!inst->IsFloatingPointFoldingAllowed()) {
    262       return nullptr;
    263     }
    264 
    265     if (constants[0] == nullptr) {
    266       return nullptr;
    267     }
    268 
    269     if (vector_type != nullptr) {
    270       std::vector<const analysis::Constant*> a_components;
    271       std::vector<const analysis::Constant*> results_components;
    272 
    273       a_components = constants[0]->GetVectorComponents(const_mgr);
    274 
    275       // Fold each component of the vector.
    276       for (uint32_t i = 0; i < a_components.size(); ++i) {
    277         results_components.push_back(scalar_rule(vector_type->element_type(),
    278                                                  a_components[i], const_mgr));
    279         if (results_components[i] == nullptr) {
    280           return nullptr;
    281         }
    282       }
    283 
    284       // Build the constant object and return it.
    285       std::vector<uint32_t> ids;
    286       for (const analysis::Constant* member : results_components) {
    287         ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
    288       }
    289       return const_mgr->GetConstant(vector_type, ids);
    290     } else {
    291       return scalar_rule(result_type, constants[0], const_mgr);
    292     }
    293   };
    294 }
    295 
    296 // Returns a |ConstantFoldingRule| that folds floating point scalars using
    297 // |scalar_rule| and vectors of floating point by applying |scalar_rule| to the
    298 // elements of the vector.  The |ConstantFoldingRule| that is returned assumes
    299 // that |constants| contains 2 entries.  If they are not |nullptr|, then their
    300 // type is either |Float| or a |Vector| whose element type is |Float|.
    301 ConstantFoldingRule FoldFPBinaryOp(BinaryScalarFoldingRule scalar_rule) {
    302   return [scalar_rule](IRContext* context, Instruction* inst,
    303                        const std::vector<const analysis::Constant*>& constants)
    304              -> const analysis::Constant* {
    305     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
    306     analysis::TypeManager* type_mgr = context->get_type_mgr();
    307     const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
    308     const analysis::Vector* vector_type = result_type->AsVector();
    309 
    310     if (!inst->IsFloatingPointFoldingAllowed()) {
    311       return nullptr;
    312     }
    313 
    314     if (constants[0] == nullptr || constants[1] == nullptr) {
    315       return nullptr;
    316     }
    317 
    318     if (vector_type != nullptr) {
    319       std::vector<const analysis::Constant*> a_components;
    320       std::vector<const analysis::Constant*> b_components;
    321       std::vector<const analysis::Constant*> results_components;
    322 
    323       a_components = constants[0]->GetVectorComponents(const_mgr);
    324       b_components = constants[1]->GetVectorComponents(const_mgr);
    325 
    326       // Fold each component of the vector.
    327       for (uint32_t i = 0; i < a_components.size(); ++i) {
    328         results_components.push_back(scalar_rule(vector_type->element_type(),
    329                                                  a_components[i],
    330                                                  b_components[i], const_mgr));
    331         if (results_components[i] == nullptr) {
    332           return nullptr;
    333         }
    334       }
    335 
    336       // Build the constant object and return it.
    337       std::vector<uint32_t> ids;
    338       for (const analysis::Constant* member : results_components) {
    339         ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
    340       }
    341       return const_mgr->GetConstant(vector_type, ids);
    342     } else {
    343       return scalar_rule(result_type, constants[0], constants[1], const_mgr);
    344     }
    345   };
    346 }
    347 
    348 // This macro defines a |UnaryScalarFoldingRule| that performs float to
    349 // integer conversion.
    350 // TODO(greg-lunarg): Support for 64-bit integer types.
    351 UnaryScalarFoldingRule FoldFToIOp() {
    352   return [](const analysis::Type* result_type, const analysis::Constant* a,
    353             analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
    354     assert(result_type != nullptr && a != nullptr);
    355     const analysis::Integer* integer_type = result_type->AsInteger();
    356     const analysis::Float* float_type = a->type()->AsFloat();
    357     assert(float_type != nullptr);
    358     assert(integer_type != nullptr);
    359     if (integer_type->width() != 32) return nullptr;
    360     if (float_type->width() == 32) {
    361       float fa = a->GetFloat();
    362       uint32_t result = integer_type->IsSigned()
    363                             ? static_cast<uint32_t>(static_cast<int32_t>(fa))
    364                             : static_cast<uint32_t>(fa);
    365       std::vector<uint32_t> words = {result};
    366       return const_mgr->GetConstant(result_type, words);
    367     } else if (float_type->width() == 64) {
    368       double fa = a->GetDouble();
    369       uint32_t result = integer_type->IsSigned()
    370                             ? static_cast<uint32_t>(static_cast<int32_t>(fa))
    371                             : static_cast<uint32_t>(fa);
    372       std::vector<uint32_t> words = {result};
    373       return const_mgr->GetConstant(result_type, words);
    374     }
    375     return nullptr;
    376   };
    377 }
    378 
    379 // This function defines a |UnaryScalarFoldingRule| that performs integer to
    380 // float conversion.
    381 // TODO(greg-lunarg): Support for 64-bit integer types.
    382 UnaryScalarFoldingRule FoldIToFOp() {
    383   return [](const analysis::Type* result_type, const analysis::Constant* a,
    384             analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
    385     assert(result_type != nullptr && a != nullptr);
    386     const analysis::Integer* integer_type = a->type()->AsInteger();
    387     const analysis::Float* float_type = result_type->AsFloat();
    388     assert(float_type != nullptr);
    389     assert(integer_type != nullptr);
    390     if (integer_type->width() != 32) return nullptr;
    391     uint32_t ua = a->GetU32();
    392     if (float_type->width() == 32) {
    393       float result_val = integer_type->IsSigned()
    394                              ? static_cast<float>(static_cast<int32_t>(ua))
    395                              : static_cast<float>(ua);
    396       utils::FloatProxy<float> result(result_val);
    397       std::vector<uint32_t> words = {result.data()};
    398       return const_mgr->GetConstant(result_type, words);
    399     } else if (float_type->width() == 64) {
    400       double result_val = integer_type->IsSigned()
    401                               ? static_cast<double>(static_cast<int32_t>(ua))
    402                               : static_cast<double>(ua);
    403       utils::FloatProxy<double> result(result_val);
    404       std::vector<uint32_t> words = result.GetWords();
    405       return const_mgr->GetConstant(result_type, words);
    406     }
    407     return nullptr;
    408   };
    409 }
    410 
    411 // This macro defines a |BinaryScalarFoldingRule| that applies |op|.  The
    412 // operator |op| must work for both float and double, and use syntax "f1 op f2".
    413 #define FOLD_FPARITH_OP(op)                                                \
    414   [](const analysis::Type* result_type, const analysis::Constant* a,       \
    415      const analysis::Constant* b,                                          \
    416      analysis::ConstantManager* const_mgr_in_macro)                        \
    417       -> const analysis::Constant* {                                       \
    418     assert(result_type != nullptr && a != nullptr && b != nullptr);        \
    419     assert(result_type == a->type() && result_type == b->type());          \
    420     const analysis::Float* float_type_in_macro = result_type->AsFloat();   \
    421     assert(float_type_in_macro != nullptr);                                \
    422     if (float_type_in_macro->width() == 32) {                              \
    423       float fa = a->GetFloat();                                            \
    424       float fb = b->GetFloat();                                            \
    425       utils::FloatProxy<float> result_in_macro(fa op fb);                  \
    426       std::vector<uint32_t> words_in_macro = result_in_macro.GetWords();   \
    427       return const_mgr_in_macro->GetConstant(result_type, words_in_macro); \
    428     } else if (float_type_in_macro->width() == 64) {                       \
    429       double fa = a->GetDouble();                                          \
    430       double fb = b->GetDouble();                                          \
    431       utils::FloatProxy<double> result_in_macro(fa op fb);                 \
    432       std::vector<uint32_t> words_in_macro = result_in_macro.GetWords();   \
    433       return const_mgr_in_macro->GetConstant(result_type, words_in_macro); \
    434     }                                                                      \
    435     return nullptr;                                                        \
    436   }
    437 
    438 // Define the folding rule for conversion between floating point and integer
    439 ConstantFoldingRule FoldFToI() { return FoldFPUnaryOp(FoldFToIOp()); }
    440 ConstantFoldingRule FoldIToF() { return FoldFPUnaryOp(FoldIToFOp()); }
    441 
    442 // Define the folding rules for subtraction, addition, multiplication, and
    443 // division for floating point values.
    444 ConstantFoldingRule FoldFSub() { return FoldFPBinaryOp(FOLD_FPARITH_OP(-)); }
    445 ConstantFoldingRule FoldFAdd() { return FoldFPBinaryOp(FOLD_FPARITH_OP(+)); }
    446 ConstantFoldingRule FoldFMul() { return FoldFPBinaryOp(FOLD_FPARITH_OP(*)); }
    447 ConstantFoldingRule FoldFDiv() { return FoldFPBinaryOp(FOLD_FPARITH_OP(/)); }
    448 
    449 bool CompareFloatingPoint(bool op_result, bool op_unordered,
    450                           bool need_ordered) {
    451   if (need_ordered) {
    452     // operands are ordered and Operand 1 is |op| Operand 2
    453     return !op_unordered && op_result;
    454   } else {
    455     // operands are unordered or Operand 1 is |op| Operand 2
    456     return op_unordered || op_result;
    457   }
    458 }
    459 
    460 // This macro defines a |BinaryScalarFoldingRule| that applies |op|.  The
    461 // operator |op| must work for both float and double, and use syntax "f1 op f2".
    462 #define FOLD_FPCMP_OP(op, ord)                                            \
    463   [](const analysis::Type* result_type, const analysis::Constant* a,      \
    464      const analysis::Constant* b,                                         \
    465      analysis::ConstantManager* const_mgr) -> const analysis::Constant* { \
    466     assert(result_type != nullptr && a != nullptr && b != nullptr);       \
    467     assert(result_type->AsBool());                                        \
    468     assert(a->type() == b->type());                                       \
    469     const analysis::Float* float_type = a->type()->AsFloat();             \
    470     assert(float_type != nullptr);                                        \
    471     if (float_type->width() == 32) {                                      \
    472       float fa = a->GetFloat();                                           \
    473       float fb = b->GetFloat();                                           \
    474       bool result = CompareFloatingPoint(                                 \
    475           fa op fb, std::isnan(fa) || std::isnan(fb), ord);               \
    476       std::vector<uint32_t> words = {uint32_t(result)};                   \
    477       return const_mgr->GetConstant(result_type, words);                  \
    478     } else if (float_type->width() == 64) {                               \
    479       double fa = a->GetDouble();                                         \
    480       double fb = b->GetDouble();                                         \
    481       bool result = CompareFloatingPoint(                                 \
    482           fa op fb, std::isnan(fa) || std::isnan(fb), ord);               \
    483       std::vector<uint32_t> words = {uint32_t(result)};                   \
    484       return const_mgr->GetConstant(result_type, words);                  \
    485     }                                                                     \
    486     return nullptr;                                                       \
    487   }
    488 
    489 // Define the folding rules for ordered and unordered comparison for floating
    490 // point values.
    491 ConstantFoldingRule FoldFOrdEqual() {
    492   return FoldFPBinaryOp(FOLD_FPCMP_OP(==, true));
    493 }
    494 ConstantFoldingRule FoldFUnordEqual() {
    495   return FoldFPBinaryOp(FOLD_FPCMP_OP(==, false));
    496 }
    497 ConstantFoldingRule FoldFOrdNotEqual() {
    498   return FoldFPBinaryOp(FOLD_FPCMP_OP(!=, true));
    499 }
    500 ConstantFoldingRule FoldFUnordNotEqual() {
    501   return FoldFPBinaryOp(FOLD_FPCMP_OP(!=, false));
    502 }
    503 ConstantFoldingRule FoldFOrdLessThan() {
    504   return FoldFPBinaryOp(FOLD_FPCMP_OP(<, true));
    505 }
    506 ConstantFoldingRule FoldFUnordLessThan() {
    507   return FoldFPBinaryOp(FOLD_FPCMP_OP(<, false));
    508 }
    509 ConstantFoldingRule FoldFOrdGreaterThan() {
    510   return FoldFPBinaryOp(FOLD_FPCMP_OP(>, true));
    511 }
    512 ConstantFoldingRule FoldFUnordGreaterThan() {
    513   return FoldFPBinaryOp(FOLD_FPCMP_OP(>, false));
    514 }
    515 ConstantFoldingRule FoldFOrdLessThanEqual() {
    516   return FoldFPBinaryOp(FOLD_FPCMP_OP(<=, true));
    517 }
    518 ConstantFoldingRule FoldFUnordLessThanEqual() {
    519   return FoldFPBinaryOp(FOLD_FPCMP_OP(<=, false));
    520 }
    521 ConstantFoldingRule FoldFOrdGreaterThanEqual() {
    522   return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, true));
    523 }
    524 ConstantFoldingRule FoldFUnordGreaterThanEqual() {
    525   return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, false));
    526 }
    527 
    528 // Folds an OpDot where all of the inputs are constants to a
    529 // constant.  A new constant is created if necessary.
    530 ConstantFoldingRule FoldOpDotWithConstants() {
    531   return [](IRContext* context, Instruction* inst,
    532             const std::vector<const analysis::Constant*>& constants)
    533              -> const analysis::Constant* {
    534     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
    535     analysis::TypeManager* type_mgr = context->get_type_mgr();
    536     const analysis::Type* new_type = type_mgr->GetType(inst->type_id());
    537     assert(new_type->AsFloat() && "OpDot should have a float return type.");
    538     const analysis::Float* float_type = new_type->AsFloat();
    539 
    540     if (!inst->IsFloatingPointFoldingAllowed()) {
    541       return nullptr;
    542     }
    543 
    544     // If one of the operands is 0, then the result is 0.
    545     bool has_zero_operand = false;
    546 
    547     for (int i = 0; i < 2; ++i) {
    548       if (constants[i]) {
    549         if (constants[i]->AsNullConstant() ||
    550             constants[i]->AsVectorConstant()->IsZero()) {
    551           has_zero_operand = true;
    552           break;
    553         }
    554       }
    555     }
    556 
    557     if (has_zero_operand) {
    558       if (float_type->width() == 32) {
    559         utils::FloatProxy<float> result(0.0f);
    560         std::vector<uint32_t> words = result.GetWords();
    561         return const_mgr->GetConstant(float_type, words);
    562       }
    563       if (float_type->width() == 64) {
    564         utils::FloatProxy<double> result(0.0);
    565         std::vector<uint32_t> words = result.GetWords();
    566         return const_mgr->GetConstant(float_type, words);
    567       }
    568       return nullptr;
    569     }
    570 
    571     if (constants[0] == nullptr || constants[1] == nullptr) {
    572       return nullptr;
    573     }
    574 
    575     std::vector<const analysis::Constant*> a_components;
    576     std::vector<const analysis::Constant*> b_components;
    577 
    578     a_components = constants[0]->GetVectorComponents(const_mgr);
    579     b_components = constants[1]->GetVectorComponents(const_mgr);
    580 
    581     utils::FloatProxy<double> result(0.0);
    582     std::vector<uint32_t> words = result.GetWords();
    583     const analysis::Constant* result_const =
    584         const_mgr->GetConstant(float_type, words);
    585     for (uint32_t i = 0; i < a_components.size(); ++i) {
    586       if (a_components[i] == nullptr || b_components[i] == nullptr) {
    587         return nullptr;
    588       }
    589 
    590       const analysis::Constant* component = FOLD_FPARITH_OP(*)(
    591           new_type, a_components[i], b_components[i], const_mgr);
    592       result_const =
    593           FOLD_FPARITH_OP(+)(new_type, result_const, component, const_mgr);
    594     }
    595     return result_const;
    596   };
    597 }
    598 
    599 // This function defines a |UnaryScalarFoldingRule| that subtracts the constant
    600 // from zero.
    601 UnaryScalarFoldingRule FoldFNegateOp() {
    602   return [](const analysis::Type* result_type, const analysis::Constant* a,
    603             analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
    604     assert(result_type != nullptr && a != nullptr);
    605     assert(result_type == a->type());
    606     const analysis::Float* float_type = result_type->AsFloat();
    607     assert(float_type != nullptr);
    608     if (float_type->width() == 32) {
    609       float fa = a->GetFloat();
    610       utils::FloatProxy<float> result(-fa);
    611       std::vector<uint32_t> words = result.GetWords();
    612       return const_mgr->GetConstant(result_type, words);
    613     } else if (float_type->width() == 64) {
    614       double da = a->GetDouble();
    615       utils::FloatProxy<double> result(-da);
    616       std::vector<uint32_t> words = result.GetWords();
    617       return const_mgr->GetConstant(result_type, words);
    618     }
    619     return nullptr;
    620   };
    621 }
    622 
    623 ConstantFoldingRule FoldFNegate() { return FoldFPUnaryOp(FoldFNegateOp()); }
    624 
    625 ConstantFoldingRule FoldFClampFeedingCompare(uint32_t cmp_opcode) {
    626   return [cmp_opcode](IRContext* context, Instruction* inst,
    627                       const std::vector<const analysis::Constant*>& constants)
    628              -> const analysis::Constant* {
    629     analysis::ConstantManager* const_mgr = context->get_constant_mgr();
    630     analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
    631 
    632     if (!inst->IsFloatingPointFoldingAllowed()) {
    633       return nullptr;
    634     }
    635 
    636     uint32_t non_const_idx = (constants[0] ? 1 : 0);
    637     uint32_t operand_id = inst->GetSingleWordInOperand(non_const_idx);
    638     Instruction* operand_inst = def_use_mgr->GetDef(operand_id);
    639 
    640     analysis::TypeManager* type_mgr = context->get_type_mgr();
    641     const analysis::Type* operand_type =
    642         type_mgr->GetType(operand_inst->type_id());
    643 
    644     if (!operand_type->AsFloat()) {
    645       return nullptr;
    646     }
    647 
    648     if (operand_type->AsFloat()->width() != 32 &&
    649         operand_type->AsFloat()->width() != 64) {
    650       return nullptr;
    651     }
    652 
    653     if (operand_inst->opcode() != SpvOpExtInst) {
    654       return nullptr;
    655     }
    656 
    657     if (operand_inst->GetSingleWordInOperand(1) != GLSLstd450FClamp) {
    658       return nullptr;
    659     }
    660 
    661     if (constants[1] == nullptr && constants[0] == nullptr) {
    662       return nullptr;
    663     }
    664 
    665     uint32_t max_id = operand_inst->GetSingleWordInOperand(4);
    666     const analysis::Constant* max_const =
    667         const_mgr->FindDeclaredConstant(max_id);
    668 
    669     uint32_t min_id = operand_inst->GetSingleWordInOperand(3);
    670     const analysis::Constant* min_const =
    671         const_mgr->FindDeclaredConstant(min_id);
    672 
    673     bool found_result = false;
    674     bool result = false;
    675 
    676     switch (cmp_opcode) {
    677       case SpvOpFOrdLessThan:
    678       case SpvOpFUnordLessThan:
    679       case SpvOpFOrdGreaterThanEqual:
    680       case SpvOpFUnordGreaterThanEqual:
    681         if (constants[0]) {
    682           if (min_const) {
    683             if (constants[0]->GetValueAsDouble() <
    684                 min_const->GetValueAsDouble()) {
    685               found_result = true;
    686               result = (cmp_opcode == SpvOpFOrdLessThan ||
    687                         cmp_opcode == SpvOpFUnordLessThan);
    688             }
    689           }
    690           if (max_const) {
    691             if (constants[0]->GetValueAsDouble() >=
    692                 max_const->GetValueAsDouble()) {
    693               found_result = true;
    694               result = !(cmp_opcode == SpvOpFOrdLessThan ||
    695                          cmp_opcode == SpvOpFUnordLessThan);
    696             }
    697           }
    698         }
    699 
    700         if (constants[1]) {
    701           if (max_const) {
    702             if (max_const->GetValueAsDouble() <
    703                 constants[1]->GetValueAsDouble()) {
    704               found_result = true;
    705               result = (cmp_opcode == SpvOpFOrdLessThan ||
    706                         cmp_opcode == SpvOpFUnordLessThan);
    707             }
    708           }
    709 
    710           if (min_const) {
    711             if (min_const->GetValueAsDouble() >=
    712                 constants[1]->GetValueAsDouble()) {
    713               found_result = true;
    714               result = !(cmp_opcode == SpvOpFOrdLessThan ||
    715                          cmp_opcode == SpvOpFUnordLessThan);
    716             }
    717           }
    718         }
    719         break;
    720       case SpvOpFOrdGreaterThan:
    721       case SpvOpFUnordGreaterThan:
    722       case SpvOpFOrdLessThanEqual:
    723       case SpvOpFUnordLessThanEqual:
    724         if (constants[0]) {
    725           if (min_const) {
    726             if (constants[0]->GetValueAsDouble() <=
    727                 min_const->GetValueAsDouble()) {
    728               found_result = true;
    729               result = (cmp_opcode == SpvOpFOrdLessThanEqual ||
    730                         cmp_opcode == SpvOpFUnordLessThanEqual);
    731             }
    732           }
    733           if (max_const) {
    734             if (constants[0]->GetValueAsDouble() >
    735                 max_const->GetValueAsDouble()) {
    736               found_result = true;
    737               result = !(cmp_opcode == SpvOpFOrdLessThanEqual ||
    738                          cmp_opcode == SpvOpFUnordLessThanEqual);
    739             }
    740           }
    741         }
    742 
    743         if (constants[1]) {
    744           if (max_const) {
    745             if (max_const->GetValueAsDouble() <=
    746                 constants[1]->GetValueAsDouble()) {
    747               found_result = true;
    748               result = (cmp_opcode == SpvOpFOrdLessThanEqual ||
    749                         cmp_opcode == SpvOpFUnordLessThanEqual);
    750             }
    751           }
    752 
    753           if (min_const) {
    754             if (min_const->GetValueAsDouble() >
    755                 constants[1]->GetValueAsDouble()) {
    756               found_result = true;
    757               result = !(cmp_opcode == SpvOpFOrdLessThanEqual ||
    758                          cmp_opcode == SpvOpFUnordLessThanEqual);
    759             }
    760           }
    761         }
    762         break;
    763       default:
    764         return nullptr;
    765     }
    766 
    767     if (!found_result) {
    768       return nullptr;
    769     }
    770 
    771     const analysis::Type* bool_type =
    772         context->get_type_mgr()->GetType(inst->type_id());
    773     const analysis::Constant* result_const =
    774         const_mgr->GetConstant(bool_type, {static_cast<uint32_t>(result)});
    775     assert(result_const);
    776     return result_const;
    777   };
    778 }
    779 
    780 }  // namespace
    781 
    782 ConstantFoldingRules::ConstantFoldingRules() {
    783   // Add all folding rules to the list for the opcodes to which they apply.
    784   // Note that the order in which rules are added to the list matters. If a rule
    785   // applies to the instruction, the rest of the rules will not be attempted.
    786   // Take that into consideration.
    787 
    788   rules_[SpvOpCompositeConstruct].push_back(FoldCompositeWithConstants());
    789 
    790   rules_[SpvOpCompositeExtract].push_back(FoldExtractWithConstants());
    791 
    792   rules_[SpvOpConvertFToS].push_back(FoldFToI());
    793   rules_[SpvOpConvertFToU].push_back(FoldFToI());
    794   rules_[SpvOpConvertSToF].push_back(FoldIToF());
    795   rules_[SpvOpConvertUToF].push_back(FoldIToF());
    796 
    797   rules_[SpvOpDot].push_back(FoldOpDotWithConstants());
    798   rules_[SpvOpFAdd].push_back(FoldFAdd());
    799   rules_[SpvOpFDiv].push_back(FoldFDiv());
    800   rules_[SpvOpFMul].push_back(FoldFMul());
    801   rules_[SpvOpFSub].push_back(FoldFSub());
    802 
    803   rules_[SpvOpFOrdEqual].push_back(FoldFOrdEqual());
    804 
    805   rules_[SpvOpFUnordEqual].push_back(FoldFUnordEqual());
    806 
    807   rules_[SpvOpFOrdNotEqual].push_back(FoldFOrdNotEqual());
    808 
    809   rules_[SpvOpFUnordNotEqual].push_back(FoldFUnordNotEqual());
    810 
    811   rules_[SpvOpFOrdLessThan].push_back(FoldFOrdLessThan());
    812   rules_[SpvOpFOrdLessThan].push_back(
    813       FoldFClampFeedingCompare(SpvOpFOrdLessThan));
    814 
    815   rules_[SpvOpFUnordLessThan].push_back(FoldFUnordLessThan());
    816   rules_[SpvOpFUnordLessThan].push_back(
    817       FoldFClampFeedingCompare(SpvOpFUnordLessThan));
    818 
    819   rules_[SpvOpFOrdGreaterThan].push_back(FoldFOrdGreaterThan());
    820   rules_[SpvOpFOrdGreaterThan].push_back(
    821       FoldFClampFeedingCompare(SpvOpFOrdGreaterThan));
    822 
    823   rules_[SpvOpFUnordGreaterThan].push_back(FoldFUnordGreaterThan());
    824   rules_[SpvOpFUnordGreaterThan].push_back(
    825       FoldFClampFeedingCompare(SpvOpFUnordGreaterThan));
    826 
    827   rules_[SpvOpFOrdLessThanEqual].push_back(FoldFOrdLessThanEqual());
    828   rules_[SpvOpFOrdLessThanEqual].push_back(
    829       FoldFClampFeedingCompare(SpvOpFOrdLessThanEqual));
    830 
    831   rules_[SpvOpFUnordLessThanEqual].push_back(FoldFUnordLessThanEqual());
    832   rules_[SpvOpFUnordLessThanEqual].push_back(
    833       FoldFClampFeedingCompare(SpvOpFUnordLessThanEqual));
    834 
    835   rules_[SpvOpFOrdGreaterThanEqual].push_back(FoldFOrdGreaterThanEqual());
    836   rules_[SpvOpFOrdGreaterThanEqual].push_back(
    837       FoldFClampFeedingCompare(SpvOpFOrdGreaterThanEqual));
    838 
    839   rules_[SpvOpFUnordGreaterThanEqual].push_back(FoldFUnordGreaterThanEqual());
    840   rules_[SpvOpFUnordGreaterThanEqual].push_back(
    841       FoldFClampFeedingCompare(SpvOpFUnordGreaterThanEqual));
    842 
    843   rules_[SpvOpVectorShuffle].push_back(FoldVectorShuffleWithConstants());
    844   rules_[SpvOpVectorTimesScalar].push_back(FoldVectorTimesScalar());
    845 
    846   rules_[SpvOpFNegate].push_back(FoldFNegate());
    847 }
    848 }  // namespace opt
    849 }  // namespace spvtools
    850