Home | History | Annotate | Download | only in opt
      1 // Copyright (c) 2016 Google Inc.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
     15 #include "source/opt/set_spec_constant_default_value_pass.h"
     16 
     17 #include <algorithm>
     18 #include <cctype>
     19 #include <cstring>
     20 #include <tuple>
     21 #include <vector>
     22 
     23 #include "source/opt/def_use_manager.h"
     24 #include "source/opt/ir_context.h"
     25 #include "source/opt/type_manager.h"
     26 #include "source/opt/types.h"
     27 #include "source/util/make_unique.h"
     28 #include "source/util/parse_number.h"
     29 #include "spirv-tools/libspirv.h"
     30 
     31 namespace spvtools {
     32 namespace opt {
     33 
     34 namespace {
     35 using utils::EncodeNumberStatus;
     36 using utils::NumberType;
     37 using utils::ParseAndEncodeNumber;
     38 using utils::ParseNumber;
     39 
     40 // Given a numeric value in a null-terminated c string and the expected type of
     41 // the value, parses the string and encodes it in a vector of words. If the
     42 // value is a scalar integer or floating point value, encodes the value in
     43 // SPIR-V encoding format. If the value is 'false' or 'true', returns a vector
     44 // with single word with value 0 or 1 respectively. Returns the vector
     45 // containing the encoded value on success. Otherwise returns an empty vector.
     46 std::vector<uint32_t> ParseDefaultValueStr(const char* text,
     47                                            const analysis::Type* type) {
     48   std::vector<uint32_t> result;
     49   if (!strcmp(text, "true") && type->AsBool()) {
     50     result.push_back(1u);
     51   } else if (!strcmp(text, "false") && type->AsBool()) {
     52     result.push_back(0u);
     53   } else {
     54     NumberType number_type = {32, SPV_NUMBER_UNSIGNED_INT};
     55     if (const auto* IT = type->AsInteger()) {
     56       number_type.bitwidth = IT->width();
     57       number_type.kind =
     58           IT->IsSigned() ? SPV_NUMBER_SIGNED_INT : SPV_NUMBER_UNSIGNED_INT;
     59     } else if (const auto* FT = type->AsFloat()) {
     60       number_type.bitwidth = FT->width();
     61       number_type.kind = SPV_NUMBER_FLOATING;
     62     } else {
     63       // Does not handle types other then boolean, integer or float. Returns
     64       // empty vector.
     65       result.clear();
     66       return result;
     67     }
     68     EncodeNumberStatus rc = ParseAndEncodeNumber(
     69         text, number_type, [&result](uint32_t word) { result.push_back(word); },
     70         nullptr);
     71     // Clear the result vector on failure.
     72     if (rc != EncodeNumberStatus::kSuccess) {
     73       result.clear();
     74     }
     75   }
     76   return result;
     77 }
     78 
     79 // Given a bit pattern and a type, checks if the bit pattern is compatible
     80 // with the type. If so, returns the bit pattern, otherwise returns an empty
     81 // bit pattern. If the given bit pattern is empty, returns an empty bit
     82 // pattern. If the given type represents a SPIR-V Boolean type, the bit pattern
     83 // to be returned is determined with the following standard:
     84 //   If any words in the input bit pattern are non zero, returns a bit pattern
     85 //   with 0x1, which represents a 'true'.
     86 //   If all words in the bit pattern are zero, returns a bit pattern with 0x0,
     87 //   which represents a 'false'.
     88 std::vector<uint32_t> ParseDefaultValueBitPattern(
     89     const std::vector<uint32_t>& input_bit_pattern,
     90     const analysis::Type* type) {
     91   std::vector<uint32_t> result;
     92   if (type->AsBool()) {
     93     if (std::any_of(input_bit_pattern.begin(), input_bit_pattern.end(),
     94                     [](uint32_t i) { return i != 0; })) {
     95       result.push_back(1u);
     96     } else {
     97       result.push_back(0u);
     98     }
     99     return result;
    100   } else if (const auto* IT = type->AsInteger()) {
    101     if (IT->width() == input_bit_pattern.size() * sizeof(uint32_t) * 8) {
    102       return std::vector<uint32_t>(input_bit_pattern);
    103     }
    104   } else if (const auto* FT = type->AsFloat()) {
    105     if (FT->width() == input_bit_pattern.size() * sizeof(uint32_t) * 8) {
    106       return std::vector<uint32_t>(input_bit_pattern);
    107     }
    108   }
    109   result.clear();
    110   return result;
    111 }
    112 
    113 // Returns true if the given instruction's result id could have a SpecId
    114 // decoration.
    115 bool CanHaveSpecIdDecoration(const Instruction& inst) {
    116   switch (inst.opcode()) {
    117     case SpvOp::SpvOpSpecConstant:
    118     case SpvOp::SpvOpSpecConstantFalse:
    119     case SpvOp::SpvOpSpecConstantTrue:
    120       return true;
    121     default:
    122       return false;
    123   }
    124 }
    125 
    126 // Given a decoration group defining instruction that is decorated with SpecId
    127 // decoration, finds the spec constant defining instruction which is the real
    128 // target of the SpecId decoration. Returns the spec constant defining
    129 // instruction if such an instruction is found, otherwise returns a nullptr.
    130 Instruction* GetSpecIdTargetFromDecorationGroup(
    131     const Instruction& decoration_group_defining_inst,
    132     analysis::DefUseManager* def_use_mgr) {
    133   // Find the OpGroupDecorate instruction which consumes the given decoration
    134   // group. Note that the given decoration group has SpecId decoration, which
    135   // is unique for different spec constants. So the decoration group cannot be
    136   // consumed by different OpGroupDecorate instructions. Therefore we only need
    137   // the first OpGroupDecoration instruction that uses the given decoration
    138   // group.
    139   Instruction* group_decorate_inst = nullptr;
    140   if (def_use_mgr->WhileEachUser(&decoration_group_defining_inst,
    141                                  [&group_decorate_inst](Instruction* user) {
    142                                    if (user->opcode() ==
    143                                        SpvOp::SpvOpGroupDecorate) {
    144                                      group_decorate_inst = user;
    145                                      return false;
    146                                    }
    147                                    return true;
    148                                  }))
    149     return nullptr;
    150 
    151   // Scan through the target ids of the OpGroupDecorate instruction. There
    152   // should be only one spec constant target consumes the SpecId decoration.
    153   // If multiple target ids are presented in the OpGroupDecorate instruction,
    154   // they must be the same one that defined by an eligible spec constant
    155   // instruction. If the OpGroupDecorate instruction has different target ids
    156   // or a target id is not defined by an eligible spec cosntant instruction,
    157   // returns a nullptr.
    158   Instruction* target_inst = nullptr;
    159   for (uint32_t i = 1; i < group_decorate_inst->NumInOperands(); i++) {
    160     // All the operands of a OpGroupDecorate instruction should be of type
    161     // SPV_OPERAND_TYPE_ID.
    162     uint32_t candidate_id = group_decorate_inst->GetSingleWordInOperand(i);
    163     Instruction* candidate_inst = def_use_mgr->GetDef(candidate_id);
    164 
    165     if (!candidate_inst) {
    166       continue;
    167     }
    168 
    169     if (!target_inst) {
    170       // If the spec constant target has not been found yet, check if the
    171       // candidate instruction is the target.
    172       if (CanHaveSpecIdDecoration(*candidate_inst)) {
    173         target_inst = candidate_inst;
    174       } else {
    175         // Spec id decoration should not be applied on other instructions.
    176         // TODO(qining): Emit an error message in the invalid case once the
    177         // error handling is done.
    178         return nullptr;
    179       }
    180     } else {
    181       // If the spec constant target has been found, check if the candidate
    182       // instruction is the same one as the target. The module is invalid if
    183       // the candidate instruction is different with the found target.
    184       // TODO(qining): Emit an error messaage in the invalid case once the
    185       // error handling is done.
    186       if (candidate_inst != target_inst) return nullptr;
    187     }
    188   }
    189   return target_inst;
    190 }
    191 }  // namespace
    192 
    193 Pass::Status SetSpecConstantDefaultValuePass::Process() {
    194   // The operand index of decoration target in an OpDecorate instruction.
    195   const uint32_t kTargetIdOperandIndex = 0;
    196   // The operand index of the decoration literal in an OpDecorate instruction.
    197   const uint32_t kDecorationOperandIndex = 1;
    198   // The operand index of Spec id literal value in an OpDecorate SpecId
    199   // instruction.
    200   const uint32_t kSpecIdLiteralOperandIndex = 2;
    201   // The number of operands in an OpDecorate SpecId instruction.
    202   const uint32_t kOpDecorateSpecIdNumOperands = 3;
    203   // The in-operand index of the default value in a OpSpecConstant instruction.
    204   const uint32_t kOpSpecConstantLiteralInOperandIndex = 0;
    205 
    206   bool modified = false;
    207   // Scan through all the annotation instructions to find 'OpDecorate SpecId'
    208   // instructions. Then extract the decoration target of those instructions.
    209   // The decoration targets should be spec constant defining instructions with
    210   // opcode: OpSpecConstant{|True|False}. The spec id of those spec constants
    211   // will be used to look up their new default values in the mapping from
    212   // spec id to new default value strings. Once a new default value string
    213   // is found for a spec id, the string will be parsed according to the target
    214   // spec constant type. The parsed value will be used to replace the original
    215   // default value of the target spec constant.
    216   for (Instruction& inst : context()->annotations()) {
    217     // Only process 'OpDecorate SpecId' instructions
    218     if (inst.opcode() != SpvOp::SpvOpDecorate) continue;
    219     if (inst.NumOperands() != kOpDecorateSpecIdNumOperands) continue;
    220     if (inst.GetSingleWordInOperand(kDecorationOperandIndex) !=
    221         uint32_t(SpvDecoration::SpvDecorationSpecId)) {
    222       continue;
    223     }
    224 
    225     // 'inst' is an OpDecorate SpecId instruction.
    226     uint32_t spec_id = inst.GetSingleWordOperand(kSpecIdLiteralOperandIndex);
    227     uint32_t target_id = inst.GetSingleWordOperand(kTargetIdOperandIndex);
    228 
    229     // Find the spec constant defining instruction. Note that the
    230     // target_id might be a decoration group id.
    231     Instruction* spec_inst = nullptr;
    232     if (Instruction* target_inst = get_def_use_mgr()->GetDef(target_id)) {
    233       if (target_inst->opcode() == SpvOp::SpvOpDecorationGroup) {
    234         spec_inst =
    235             GetSpecIdTargetFromDecorationGroup(*target_inst, get_def_use_mgr());
    236       } else {
    237         spec_inst = target_inst;
    238       }
    239     } else {
    240       continue;
    241     }
    242     if (!spec_inst) continue;
    243 
    244     // Get the default value bit pattern for this spec id.
    245     std::vector<uint32_t> bit_pattern;
    246 
    247     if (spec_id_to_value_str_.size() != 0) {
    248       // Search for the new string-form default value for this spec id.
    249       auto iter = spec_id_to_value_str_.find(spec_id);
    250       if (iter == spec_id_to_value_str_.end()) {
    251         continue;
    252       }
    253 
    254       // Gets the string of the default value and parses it to bit pattern
    255       // with the type of the spec constant.
    256       const std::string& default_value_str = iter->second;
    257       bit_pattern = ParseDefaultValueStr(
    258           default_value_str.c_str(),
    259           context()->get_type_mgr()->GetType(spec_inst->type_id()));
    260 
    261     } else {
    262       // Search for the new bit-pattern-form default value for this spec id.
    263       auto iter = spec_id_to_value_bit_pattern_.find(spec_id);
    264       if (iter == spec_id_to_value_bit_pattern_.end()) {
    265         continue;
    266       }
    267 
    268       // Gets the bit-pattern of the default value from the map directly.
    269       bit_pattern = ParseDefaultValueBitPattern(
    270           iter->second,
    271           context()->get_type_mgr()->GetType(spec_inst->type_id()));
    272     }
    273 
    274     if (bit_pattern.empty()) continue;
    275 
    276     // Update the operand bit patterns of the spec constant defining
    277     // instruction.
    278     switch (spec_inst->opcode()) {
    279       case SpvOp::SpvOpSpecConstant:
    280         // If the new value is the same with the original value, no
    281         // need to do anything. Otherwise update the operand words.
    282         if (spec_inst->GetInOperand(kOpSpecConstantLiteralInOperandIndex)
    283                 .words != bit_pattern) {
    284           spec_inst->SetInOperand(kOpSpecConstantLiteralInOperandIndex,
    285                                   std::move(bit_pattern));
    286           modified = true;
    287         }
    288         break;
    289       case SpvOp::SpvOpSpecConstantTrue:
    290         // If the new value is also 'true', no need to change anything.
    291         // Otherwise, set the opcode to OpSpecConstantFalse;
    292         if (!static_cast<bool>(bit_pattern.front())) {
    293           spec_inst->SetOpcode(SpvOp::SpvOpSpecConstantFalse);
    294           modified = true;
    295         }
    296         break;
    297       case SpvOp::SpvOpSpecConstantFalse:
    298         // If the new value is also 'false', no need to change anything.
    299         // Otherwise, set the opcode to OpSpecConstantTrue;
    300         if (static_cast<bool>(bit_pattern.front())) {
    301           spec_inst->SetOpcode(SpvOp::SpvOpSpecConstantTrue);
    302           modified = true;
    303         }
    304         break;
    305       default:
    306         break;
    307     }
    308     // No need to update the DefUse manager, as this pass does not change any
    309     // ids.
    310   }
    311   return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
    312 }
    313 
    314 // Returns true if the given char is ':', '\0' or considered as blank space
    315 // (i.e.: '\n', '\r', '\v', '\t', '\f' and ' ').
    316 bool IsSeparator(char ch) {
    317   return std::strchr(":\0", ch) || std::isspace(ch) != 0;
    318 }
    319 
    320 std::unique_ptr<SetSpecConstantDefaultValuePass::SpecIdToValueStrMap>
    321 SetSpecConstantDefaultValuePass::ParseDefaultValuesString(const char* str) {
    322   if (!str) return nullptr;
    323 
    324   auto spec_id_to_value = MakeUnique<SpecIdToValueStrMap>();
    325 
    326   // The parsing loop, break when points to the end.
    327   while (*str) {
    328     // Find the spec id.
    329     while (std::isspace(*str)) str++;  // skip leading spaces.
    330     const char* entry_begin = str;
    331     while (!IsSeparator(*str)) str++;
    332     const char* entry_end = str;
    333     std::string spec_id_str(entry_begin, entry_end - entry_begin);
    334     uint32_t spec_id = 0;
    335     if (!ParseNumber(spec_id_str.c_str(), &spec_id)) {
    336       // The spec id is not a valid uint32 number.
    337       return nullptr;
    338     }
    339     auto iter = spec_id_to_value->find(spec_id);
    340     if (iter != spec_id_to_value->end()) {
    341       // Same spec id has been defined before
    342       return nullptr;
    343     }
    344     // Find the ':', spaces between the spec id and the ':' are not allowed.
    345     if (*str++ != ':') {
    346       // ':' not found
    347       return nullptr;
    348     }
    349     // Find the value string
    350     const char* val_begin = str;
    351     while (!IsSeparator(*str)) str++;
    352     const char* val_end = str;
    353     if (val_end == val_begin) {
    354       // Value string is empty.
    355       return nullptr;
    356     }
    357     // Update the mapping with spec id and value string.
    358     (*spec_id_to_value)[spec_id] = std::string(val_begin, val_end - val_begin);
    359 
    360     // Skip trailing spaces.
    361     while (std::isspace(*str)) str++;
    362   }
    363 
    364   return spec_id_to_value;
    365 }
    366 
    367 }  // namespace opt
    368 }  // namespace spvtools
    369